Oh, I hadn’t thought of this, thanks for the comment! I don’t think this applies to Pre-LN Transformers, though?
In Pre-LN transformers every layer’s output is directly connected to the residual stream (and thus just one unembedding away from the logits), so wouldn’t this remove the vanishing gradient problem? I just checked out the paper you linked: they claim exponentially vanishing gradients are a problem only in Post-LN, and that Pre-LN (and their new method) prevents the problem, right?
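For reference, the structural difference I have in mind, as a minimal sketch (my own toy numpy code with no learned parameters; not code from the post or the paper):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # LayerNorm without learned scale/bias, just for illustration.
    x = x - x.mean()
    return x / np.sqrt((x**2).mean() + eps)

def pre_ln_block(x, sublayer):
    # Pre-LN: the residual stream x passes through untouched and the sublayer
    # only *adds* to it, so every block's output reaches the logits through the
    # identity path (one unembedding away), with no LayerNorm on that path.
    return x + sublayer(layer_norm(x))

def post_ln_block(x, sublayer):
    # Post-LN (for contrast): the LayerNorm sits on the residual path itself,
    # which is what reintroduces the vanishing-gradient issue the paper discusses.
    return layer_norm(x + sublayer(x))

# Tiny usage example with a stand-in sublayer.
x = np.random.randn(16)
mlp = lambda h: np.tanh(h)
print(np.linalg.norm(pre_ln_block(x, mlp)), np.linalg.norm(post_ln_block(x, mlp)))
```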
The residual stream norm curves seem to follow the exponential growth quite precisely; do vanishing gradient problems cause such a clean result? I would have intuitively expected the final weights to look somewhat pathological if this pattern were caused by such a problem in training.
Re the prediction: isn’t the sign the other way around? Vanishing gradients imply growing norms, right? So vanishing gradients in Post-LN would cause gradient norms to grow exponentially towards later (closer to the output) layers (they also plot something like this in Figure 3 of the linked paper).
I agree with the prediction that Post-LN will probably have even stronger exponential norm growth, but I think that this has a different cause from what we find here.
Yep, Pre-LN transformers avoid the vanishing gradient problem.
Haven’t checked this myself, but the phenomenon seems to be fairly clean? See Figure 3b in the paper I linked, or Figure 1 in this paper.
I actually wouldn’t think of vanishing/exploding gradients as a pathological training problem but as a more general phenomenon in any dynamical system. Some dynamical systems (e.g. the sigmoid map) fall into equilibria over time, getting exponentially close to a fixed point. Other dynamical systems (e.g. the logistic map) become chaotic, and similar trajectories diverge exponentially over time. If you check, you’ll find the first kind leads to vanishing gradients (at each iteration of the map), and the second to exploding ones. This is a forward-pass perspective on the problem; the usual perspective considers only the implications for the backward pass, since that’s where the problem usually shows up.
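A quick numerical sketch of that claim (my own toy example in numpy; the starting point and depths are just illustrative choices, not anything from the papers). Iterating each map and multiplying the local derivatives along the trajectory gives the end-to-end gradient, which shrinks for the contracting sigmoid map and typically blows up for the chaotic logistic map:

```python
import numpy as np

def end_to_end_gradient(f, df, x0, n_steps):
    """Iterate x_{t+1} = f(x_t) and accumulate d x_T / d x_0 = prod_t f'(x_t)."""
    x, grad = x0, 1.0
    for _ in range(n_steps):
        grad *= df(x)  # chain rule: multiply in the local derivative
        x = f(x)
    return x, grad

sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))
d_sigmoid = lambda x: sigmoid(x) * (1.0 - sigmoid(x))  # |sigma'| <= 0.25: contracting
logistic = lambda x: 4.0 * x * (1.0 - x)               # chaotic regime (r = 4)
d_logistic = lambda x: 4.0 - 8.0 * x

for n in (10, 20, 40):
    _, g_sig = end_to_end_gradient(sigmoid, d_sigmoid, 0.3, n)
    _, g_log = end_to_end_gradient(logistic, d_logistic, 0.3, n)
    print(f"depth {n:2d}: |grad| sigmoid map {abs(g_sig):.2e}, logistic map {abs(g_log):.2e}")
# The contracting map's gradient vanishes exponentially with depth;
# the chaotic map's gradient typically explodes.
```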
Notice above that the system with exponential decay in the forward pass had vanishing gradients in the backward pass (equivalently, gradient norms that grow towards later layers); the relationship is inverse. If you start with toy single-neuron networks, you can prove this to yourself pretty easily.
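For concreteness, here’s the toy single-neuron version I have in mind (my own sketch; the linear weight w = 0.5 and the depth are arbitrary illustrative choices): the forward activations shrink geometrically, while the gradient of the final output with respect to each layer grows geometrically towards the output.

```python
# Toy single-neuron "network": x_{t+1} = w * x_t with |w| < 1, no nonlinearity.
w, T, x0 = 0.5, 12, 1.0

for t in range(0, T + 1, 3):
    activation = x0 * w**t   # forward pass: decays exponentially with depth
    grad = w**(T - t)        # d x_T / d x_t: grows exponentially towards the output
    print(f"layer {t:2d}: activation {activation:.6f}   d x_T/d x_t {grad:.6f}")
# Exponential decay forwards corresponds to per-layer gradient norms that grow
# towards later layers (i.e. gradients that vanish as you propagate back to
# early layers).
```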
The predictions here are still complicated by a few facts. First, exponential divergence/convergence of trajectories doesn’t necessarily imply exponentially growing/shrinking norms. Second, the layer norm complicates things, confining some of the dynamics to a hypersphere (modulo the zero-mean part). Haven’t fully worked out the problem for myself yet, but still think there’s a relationship here.
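On the layer norm point, a tiny check (my own sketch, using LayerNorm without the learned scale/bias): whatever the input scale, the normalized vector lands on the same sphere of radius sqrt(d) inside the zero-mean subspace, which is the confinement I mean above.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # LayerNorm without learned scale/bias: subtract the mean, divide by the std.
    x = x - x.mean()
    return x / np.sqrt((x**2).mean() + eps)

d = 16
rng = np.random.default_rng(0)
for scale in (0.1, 1.0, 100.0):
    y = layer_norm(scale * rng.normal(size=d))
    print(f"input scale {scale:6.1f}: output mean {y.mean():+.1e}, norm {np.linalg.norm(y):.4f}")
# Output norm is ~sqrt(16) = 4 regardless of input scale: the dynamics after each
# LayerNorm live on a fixed hypersphere within the zero-mean subspace.
```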