Exponential growth is a fairly natural thing to expect here, roughly for the same reason that vanishing/exploding gradients happen (input/output sensitivity is directly related to param/output sensitivity). Based on this hypothesis, I’m preregistering the prediction that (all other things equal) the residual stream in post-LN transformers will exhibit exponentially shrinking norms, since it’s known that post-LN transformers are more sensitive to vanishing gradient problems compared to pre-LN ones.
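To spell out that parenthetical with a minimal worked case (my notation, a plain linear chain rather than an actual transformer): with $h_{l+1} = W_l h_l$ and a loss $\mathcal L(h_L)$,

$$\frac{\partial h_L}{\partial h_l} = W_{L-1}\cdots W_l, \qquad \frac{\partial \mathcal L}{\partial W_l} = \Big[(W_{L-1}\cdots W_{l+1})^{\top}\,\frac{\partial \mathcal L}{\partial h_L}\Big]\,h_l^{\top},$$

so the same product of layer Jacobians governs both how the output responds to an earlier activation and how it responds to that layer's parameters; when that product shrinks or grows exponentially with depth, you get vanishing or exploding gradients.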
Edit: On further thought, I still think this intuition is correct, but I expect the prediction is wrong—the notion of relative residual stream size in a post-LN transformer is a bit dubious, since the size of the residual stream is entirely determined by the layer norm constants, which are a bit arbitrary because they can be rolled into other weights. I think the proper prediction is more around something like Lyapunov exponents.
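One way to write down the kind of quantity I have in mind (my notation, nothing from the post): for the layer-to-layer map $x_{l+1} = F_l(x_l)$ on the residual stream, carry a small perturbation $v_l$ along the trajectory and measure its average per-layer stretching,

$$\lambda \;\approx\; \frac{1}{L}\sum_{l=1}^{L}\log\frac{\lVert v_{l+1}\rVert}{\lVert v_l\rVert}, \qquad v_{l+1} = J_{F_l}(x_l)\,v_l,$$

where $J_{F_l}$ is the Jacobian of block $l$. Unlike the raw norm of the stream, this is insensitive to an overall rescaling of the stream; $\lambda < 0$ would be the contracting regime I'd expect from the vanishing-gradient story for post-LN, $\lambda > 0$ the expanding one.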
Oh I hadn’t thought of this, thanks for the comment! I don’t think this apply to Pre-LN Transformers though?
In Pre-LN transformers every layer’s output is directly connected to the residual stream (and thus just one unembedding away from logits), wouldn’t this remove the vanishing gradient problem? I just checkout out the paper you linked, they claim exponentially vanishing gradients is a problem (only) in Post-LN, and how Pre-LN (and their new method) prevent the problem, right?
The residual stream norm curves seem to follow the exponential growth quite precisely, do vanishing gradient problems cause such a clean result? I would have intuitively expected the final weights to look somewhat pathological if they were caused by such a problem in training.
Re prediction: Isn’t the sign the other way around? Vanishing gradients imply growing norms, right? So vanishing gradients in Post-LN would cause gradients to grow exponentially towards later (closer to output) layers (they also plot something like this in Figure 3 in the linked paper).
I agree with the prediction that Post-LN will probably have even stronger exponential norm growth, but I think that this has a different cause to what we find here.
Yep, pre-LN transformers avoid the vanishing gradient problem.
Haven’t checked this myself, but the phenomenon seems to be fairly clean? See figure 3.b in the paper I linked, or figure 1 in this paper.
I actually wouldn’t think of vanishing/exploding gradients as a pathological training problem but a more general phenomenon about any dynamical system. Some dynamical systems (e.g. the sigmoid map) fall into equilibria over time, getting exponentially close to one. Other dynamical systems (e.g. the logistic map) become chaotic, and similar trajectories diverge exponentially over time. If you check, you’ll find the first kind leads to vanishing gradients (at each iteration of the map), and the second to exploding ones. This a forward pass perspective on the problem—the usual perspective on the problem considers only implications for the backward pass, since that’s where the problem usually shows up.
Notice above that the system with exponential decay in the forward pass had vanishing gradients in the backward pass (equivalently, gradient norms that grow exponentially towards the output); the relationship between the two is inverse. If you start with toy single-neuron networks, you can prove this to yourself pretty easily.
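Spelled out for the simplest such toy (my notation): a depth-$L$ chain of single linear neurons $h_{l+1} = a\,h_l$ with $|a| < 1$ gives

$$h_L = a^{\,L-l}\,h_l \quad\Longrightarrow\quad \frac{\partial \mathcal L}{\partial h_l} = a^{\,L-l}\,\frac{\partial \mathcal L}{\partial h_L},$$

so along the layer index the activations shrink exponentially while the gradient norms grow exponentially towards the output: the two curves are mirror images of each other.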
The predictions here are still complicated by a few facts—first, exponential divergence/convergence of trajectories doesn’t necessarily imply exponentially growing/shrinking norms. Second, the layer norm complicates things, confining some dynamics to a hypersphere (modulo the zero-mean part). Haven’t fully worked out the problem for myself yet, but still think there’s a relationship here.
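For the layer norm point, a quick check of what I mean by the hypersphere (assuming the standard LayerNorm with its learned affine part stripped off):

```python
import numpy as np

# The normalization step of LayerNorm maps any vector onto the sphere of radius
# sqrt(d) inside the zero-mean subspace, regardless of the input's scale or offset.
d = 512
x = 7.3 * np.random.randn(d) + 2.1              # arbitrary scale and offset
y = (x - x.mean()) / x.std()                    # normalization only, no gain/bias
print(y.mean(), np.linalg.norm(y), np.sqrt(d))  # ~0.0, ~22.6, 22.6...
```

So the vector coming out of the normalization step carries only directional information; the scale degree of freedom is fixed at that point.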