Yep, pre-LN transformers avoid the vanishing gradient problem.
Haven’t checked this myself, but the phenomenon seems to be fairly clean? See figure 3.b in the paper I linked, or figure 1 in this paper.
I actually wouldn’t think of vanishing/exploding gradients as a pathological training problem so much as a general phenomenon in any dynamical system. Some dynamical systems (e.g. the iterated sigmoid map) settle into equilibria over time, getting exponentially close to a fixed point. Other dynamical systems (e.g. the logistic map) become chaotic, and nearby trajectories diverge exponentially over time. If you check, you’ll find the first kind leads to vanishing gradients (each iteration of the map multiplies in a derivative factor smaller than one) and the second to exploding ones. This is a forward-pass perspective on the problem; the usual framing considers only the backward pass, since that’s where the problem usually shows up.
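If it helps, here’s a quick numerical sketch of that claim (my own toy setup, plain numpy, with the logistic map at r = 4): iterate each map and accumulate the product of per-step derivatives, which is exactly the factor backprop would multiply in at every iteration.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def iterate(f, dfdx, x0, steps):
    """Run x_{t+1} = f(x_t) and accumulate |dx_T/dx_0| = prod_t |f'(x_t)|."""
    x, grad = x0, 1.0
    for _ in range(steps):
        grad *= abs(dfdx(x))  # one chain-rule factor per iteration
        x = f(x)
    return x, grad

# Contracting system: the sigmoid map falls into a fixed point,
# so the accumulated derivative shrinks exponentially (vanishing gradient).
x, g = iterate(sigmoid, lambda x: sigmoid(x) * (1.0 - sigmoid(x)), x0=2.0, steps=50)
print(f"sigmoid map:  x_50 = {x:.4f}  |dx_50/dx_0| = {g:.2e}")

# Chaotic system: the logistic map at r = 4 has a positive Lyapunov exponent,
# so the accumulated derivative grows exponentially (exploding gradient).
r = 4.0
x, g = iterate(lambda x: r * x * (1.0 - x), lambda x: r * (1.0 - 2.0 * x), x0=0.2, steps=50)
print(f"logistic map: x_50 = {x:.4f}  |dx_50/dx_0| = {g:.2e}")
```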
Notice above that the system whose forward signal decays exponentially is the one with vanishing gradients: the gradient norm shrinks exponentially as you move backward toward the input, which is the mirror image of the forward decay (activation norms fall with depth exactly where gradient norms climb), so the relationship between the two passes is inverse. If you start with toy single-neuron networks, you can prove this to yourself pretty easily.
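Concretely, here’s the sort of single-neuron exercise I mean, using a linear chain x_{l+1} = w·x_l with |w| < 1 so the algebra comes out exact (my own toy example, not from the papers):

```python
# Depth-L chain of single neurons, x_{l+1} = w * x_l with |w| < 1:
# the forward signal decays exponentially with depth.
L, w, x0 = 20, 0.5, 1.0

acts = [x0]
for _ in range(L):
    acts.append(w * acts[-1])

# For a loss that reads out the final activation, dL/dx_l = w**(L - l):
# gradient norms grow with layer index, i.e. they shrink exponentially
# as you walk backward toward the input (the vanishing-gradient picture).
grads = [w ** (L - l) for l in range(L + 1)]

for l in (0, 5, 10, 15, 20):
    print(f"layer {l:2d}:  |x_l| = {acts[l]:.2e}   |dL/dx_l| = {grads[l]:.2e}")

# The inverse relationship in its cleanest form: the product is constant (= w**L).
assert all(abs(a * g - w ** L) < 1e-12 for a, g in zip(acts, grads))
```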
The predictions here are still complicated by a couple of facts. First, exponential divergence/convergence of nearby trajectories doesn’t necessarily imply that the norms themselves grow/shrink exponentially. Second, the layer norm complicates things by confining the normalized activations to a hypersphere (within the zero-mean subspace). Haven’t fully worked out the problem for myself yet, but I still think there’s a relationship here.
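The hypersphere part is at least easy to check numerically. A minimal sketch, assuming LayerNorm with the learned gain and bias stripped out: whatever the input scale, the output has zero mean and norm ≈ √d, so the normalized state always sits on a fixed sphere inside the zero-mean subspace.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # LayerNorm with gain = 1 and bias = 0, for simplicity.
    return (x - x.mean()) / np.sqrt(x.var() + eps)

d = 64
rng = np.random.default_rng(0)
for scale in (0.1, 1.0, 100.0):
    y = layer_norm(scale * rng.standard_normal(d))
    # Zero mean and ||y|| ~ sqrt(d), regardless of how large the input grew.
    print(f"scale {scale:6.1f}:  mean = {y.mean():+.1e}   ||y|| = {np.linalg.norm(y):.4f}   sqrt(d) = {d**0.5:.4f}")
```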