Effective layer horizon of transformer circuits. The residual stream norm grows exponentially over the forward pass, with a growth rate of about 1.05. Consider the residual stream at layer 0, with norm (say) 100. Suppose the MLP at layer 0 has outputs of norm (say) 5. Then after 30 layers, the residual stream norm will be $100 \cdot 1.05^{30} \approx 432.2$. So the MLP-0 outputs of norm 5 should have a significantly reduced effect on the computations of MLP-30, due to their smaller relative norm.
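Concretely, the relative contribution of those fixed-norm outputs shrinks by roughly the same exponential factor:

$$\frac{5}{100} = 5\% \ \text{at layer } 0 \qquad \text{vs.} \qquad \frac{5}{100 \cdot 1.05^{30}} \approx \frac{5}{432.2} \approx 1.2\% \ \text{at layer } 30.$$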
On input tokens $x$, let $\text{Attn}_i(x)$ and $\text{MLP}_i(x)$ be the original model’s sublayer outputs at layer $i$. I want to think about what happens when the later sublayers can only “see” the last few layers’ worth of outputs.
Definition: Layer-truncated residual stream. A truncated residual stream from layer $n_1$ to layer $n_2$ is formed by the original sublayer outputs from those layers:
$$h_{n_1:n_2}(x) := \sum_{i=n_1}^{n_2} \text{Attn}_i(x) + \text{MLP}_i(x).$$

Definition: Effective layer horizon. Let $k > 0$ be an integer. Suppose that for all $n \geq k$, we patch in $h_{(n-k):n}(x)$ for the usual residual stream inputs $h_{0:n}(x)$.[1] Let the effective layer horizon be the smallest $k$ for which the model’s outputs and/or capabilities are “qualitatively unchanged.”
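As a minimal sketch of the truncated stream (not from the post; it assumes TransformerLens and its standard cached activation names `attn_out` and `mlp_out`, with GPT-2 Small purely as an example):

```python
# Minimal sketch: compute h_{n1:n2}(x) from one clean forward pass,
# assuming TransformerLens's standard activation names.
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")  # GPT-2 Small, 12 layers
tokens = model.to_tokens("The quick brown fox jumps over the lazy dog.")

# Cache the original sublayer outputs Attn_i(x) and MLP_i(x).
_, cache = model.run_with_cache(tokens)

def truncated_resid(n1: int, n2: int) -> torch.Tensor:
    """h_{n1:n2}(x): sum of the original sublayer outputs of layers n1..n2 (inclusive)."""
    return sum(cache["attn_out", i] + cache["mlp_out", i] for i in range(n1, n2 + 1))
```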
Effective layer horizons (if they exist) would greatly simplify searches for circuits within models. Additionally, they would be further (though not conclusive[2]) evidence for hypotheses like Residual Networks Behave Like Ensembles of Relatively Shallow Networks.
Lastly, slower norm growth probably causes the effective layer horizon to be lower. In that case, simply measuring residual stream norm growth would tell you a lot about the depth of circuits in the model, which could be useful if you want to regularize against circuit depth or otherwise decrease it (e.g., to decrease the amount of effective serial computation).
Do models have an effective layer horizon? If so, what does it tend to be as a function of model depth and other factors—are there scaling laws?
[1] For notational ease, I’m glossing over the fact that we’d be patching in different residual streams for each sublayer of layer $n$. That is, we wouldn’t patch in the same activations for both the attention and MLP sublayers of layer $n$.
[2] For example, if a model has an effective layer horizon of 5, then a circuit could run through the whole model, because a layer $n$ head could read out features output by a layer $n-5$ circuit, and then layer $n+5$ could read from layer $n$…
Computing the exact layer-truncated residual streams on GPT-2 Small, it seems that the effective layer horizon is quite large:
I’m mean-ablating every edge with a source node more than $n$ layers back and calculating the loss on 100 samples from The Pile.
Source code: https://gist.github.com/UFO-101/7b5e27291424029d092d8798ee1a1161
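For reference, here is a rough, hedged sketch of the same kind of sweep. It is not the gist’s code (which mean-ablates individual edges); instead it directly patches each layer’s residual input with the layer-truncated stream from the definition above, reusing `model`, `tokens`, `cache`, and `truncated_resid` from the earlier sketch, and glossing over the resid_pre/resid_mid distinction noted in the footnote:

```python
# Rough sketch, not the gist's implementation: patch resid_pre at every layer
# n >= k with the clean-run truncated stream h_{(n-k):(n-1)}(x), then measure loss.
from transformer_lens import utils

def loss_with_horizon(k: int) -> float:
    def patch(resid_pre, hook):
        n = hook.layer()
        # One indexing convention: layer n only "sees" the previous k layers.
        # Note: the truncated stream never includes the token/positional
        # embeddings, since the definition only sums sublayer outputs.
        return truncated_resid(n - k, n - 1)

    hooks = [(utils.get_act_name("resid_pre", n), patch)
             for n in range(k, model.cfg.n_layers)]
    return model.run_with_hooks(tokens, return_type="loss", fwd_hooks=hooks).item()

baseline = model(tokens, return_type="loss").item()
for k in range(1, model.cfg.n_layers + 1):
    print(f"k={k:2d}  loss={loss_with_horizon(k):.3f}  (clean loss: {baseline:.3f})")
```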
I believe the horizon may be large because, even if the approximation is fairly good at any particular layer, the errors compound as you go through the layers. If we apply the horizon only at the final output, the measured horizon is smaller.

However, if we apply it at just the middle layer (layer 6), the horizon is surprisingly small, so we would expect relatively little error to propagate.

But this appears to be an outlier; compare to layers 5 and 7.
Source: https://gist.github.com/UFO-101/5ba35d88428beb1dab0a254dec07c33b
I realized the previous experiment might be importantly misleading because it’s on a small 12-layer model. In larger models it would still be a big deal if the effective layer horizon were, say, 20 layers.

Previously the code was too slow to run on larger models. But I made a faster version and ran the same experiment on GPT-2 large (48 layers):

We clearly see the same pattern again. As TurnTrout predicted, there seems to be something like an exponential decay in the importance of previous layers as you go further back. I expect that on large models the effective layer horizon is an important consideration.
Updated source: https://gist.github.com/UFO-101/41b7ff0b250babe69bf16071e76658a6
[edit: stefan made the same point below earlier than me]
Nice idea! I’m not sure why this would be evidence for residual networks being an ensemble of shallow circuits — it seems more like the opposite to me? If anything, low effective layer horizon implies that later layers are building more on the outputs of intermediate layers. In one extreme, a network with an effective layer horizon of 1 would only consist of circuits that route through every single layer. Likewise, for there to be any extremely shallow circuits that route directly from the inputs to the final layer, the effective layer horizon must be the number of layers in the network.
I do agree that low layer horizons would substantially simplify (in terms of compute) searching for circuits.
I like this idea! I’d love to see checks of this on SOTA models, which tend to have lots of layers (thanks @Joseph Miller for running the GPT-2 experiment already!).
I notice this line of argument would also imply that the embedding information can only be accessed up to a certain layer, after which it will be washed out by the high-norm outputs of later layers. (And the same for early MLP layers, which are rumoured to act as extended embeddings in some models.) This seems unexpected.
I have the opposite expectation: Effective layer horizons enforce a lower bound on the number of modules involved in a path. Consider the shallow path
Input (layer 0) → MLP 10 → MLP 50 → Output (layer 100)
If the effective layer horizon is 25, then this path cannot work because the output of MLP 10 gets lost. In fact, no path with fewer than 3 intermediate modules is possible, because there would always be a gap of more than 25 layers.
Only a less-shallow path would manage to influence the output of the model, for example:
Input (layer 0) → MLP 10 → MLP 30 → MLP 50 → MLP 70 → MLP 90 → Output (layer 100)
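To spell out the counting: with model depth $L$ and horizon $k$, a path with $m$ intermediate modules has $m+1$ gaps of at most $k$ layers each, so it needs $(m+1)k \geq L$, i.e.

$$m \;\geq\; \left\lceil \frac{L}{k} \right\rceil - 1 \;=\; \left\lceil \frac{100}{25} \right\rceil - 1 \;=\; 3.$$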
This too seems counterintuitive; I’m not sure what to make of it.