The Residual Expansion: A Framework for thinking about Transformer Circuits
Edit: The math here has turned out to be wrong. See Joseph Miller’s reply here. I will revise the main content of this post at some point to reflect this.
This is an informal note describing my current approach for thinking about transformer circuits. I’ve not spent a lot of time thinking deeply about this, but I believe the overall claims here are correct.
Note: A lot of the high-level ideas here are not really original, but I haven’t seen this specific framing applied to transformers, and I would like more people to think about this / tell me whether it’s obviously flawed in some way.
AMFOTC and its Limitations
A Mathematical Framework for Transformer Circuits (AMFOTC) is one of my favourite mech interp papers ever, and has spawned a very successful subfield of circuit analysis. I really like it because it provides a general framework for how to think about transformers and circuits.
However, there are some notable limitations to this framework:
It doesn’t consider MLP blocks
It doesn’t consider layer-norm
The analysis stops at 2-layer transformers
possibly because it becomes too onerous to write the equations for larger models
which in turn suggests that better abstraction / notation is needed
Also, this framework is centred on the ‘model basis’ (attention heads, residual stream) and fails to incorporate other ideas (superposition, SAEs).
So I spent some time thinking about how we might extend this framework, and here’s what I came up with.
The Residual Expansion
A very old idea in machine learning, dating all the way back to ResNets, is that a sequence of residual operations can be ‘expanded out’ into a set of feedforward operations
These feedforward components act like an ensemble, in that they sum together to exactly reconstruct the model’s output
To make this concrete, let’s consider a 1-layer transformer with attention and MLP blocks.
Let MLP denote the MLP block and Att denote the attention block.
We can write this as: $T = (\mathrm{MLP} + \mathrm{Id}) \circ (\mathrm{Att} + \mathrm{Id})$
We can expand this product of sums into a sum of products: $T = (\mathrm{MLP} \circ \mathrm{Att}) + \mathrm{MLP} + \mathrm{Att} + \mathrm{Id}$
Each of the four terms represents a single (nonlinear) computational path through the model
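To make the bookkeeping concrete, here is a minimal numerical sketch with linear stand-ins for the two blocks (just random matrices; all names and shapes here are made up for illustration). Note that the distribution step only holds exactly when the blocks are treated as linear maps, which is the caveat flagged in the edit note at the top.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8  # toy residual stream width

# Linear stand-ins for the attention and MLP blocks (hypothetical, illustration only)
Att = rng.normal(size=(d, d)) * 0.1
MLP = rng.normal(size=(d, d)) * 0.1
Id = np.eye(d)

x = rng.normal(size=d)  # a toy residual-stream vector

# Residual form: T(x) = (MLP + Id)(Att + Id) x
residual_out = (MLP + Id) @ (Att + Id) @ x

# Expanded form: four feedforward paths summed together
paths = [
    MLP @ Att @ x,  # degree 2: through both blocks
    MLP @ x,        # degree 1: through the MLP only
    Att @ x,        # degree 1: through attention only
    x,              # degree 0: the identity path
]
expanded_out = sum(paths)

assert np.allclose(residual_out, expanded_out)
```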
More generally, for an N-layer transformer, we can write a big summation of terms:
$T = (\mathrm{Id}) + \left(\sum_{i=0}^{N-1} \mathrm{MLP}_i + \sum_{i=0}^{N-1} \mathrm{Att}_i\right) + \cdots$
The first bracket contains degree 0 terms. This is simply the identity, i.e. a no-op on the token embeddings
The second bracket contains degree 1 terms. This captures the direct effect of each block in isolation
And so on…
Generally there will be terms of up to degree 2N, since there are 2 blocks per layer
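To see how the terms pile up for general $N$, here is a purely combinatorial sketch (nothing model-specific) that enumerates the $2^{2N}$ paths and groups them by degree; there are $\binom{2N}{k}$ terms of degree $k$:

```python
from itertools import combinations
from collections import Counter

N = 3  # number of layers (arbitrary, just for the printout)

# The 2N blocks in the order they are applied: Att_0, MLP_0, Att_1, MLP_1, ...
blocks = [f"{kind}_{i}" for i in range(N) for kind in ("Att", "MLP")]

# Each path corresponds to an (ordered) subset of blocks the signal passes through;
# the empty subset is the identity path.
paths = [subset for k in range(len(blocks) + 1)
         for subset in combinations(blocks, k)]

print("total paths:", len(paths))              # 2**(2N) = 64
print(dict(Counter(len(p) for p in paths)))    # {0: 1, 1: 6, 2: 15, 3: 20, 4: 15, 5: 6, 6: 1}
print(paths[0], paths[1], paths[-1])           # identity, a degree-1 term, the full-depth term
```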
What does this get us?
The residual decomposition gives us a sum of feedforward paths through the model, each of which is nonlinear.
Maybe these paths are interpretable!
Circuits. A circuit could possibly be represented as a sum of a small number of these feedforward paths.
An appropriate summation over paths can be “re-factorized” into a graph
Exercise 1 (for me, and for others): Write the IOI circuit in the residual decomposition
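To illustrate the “re-factorized into a graph” point above (this is not an answer to Exercise 1; the paths below are made up purely for illustration), here is a minimal sketch that turns a hand-picked set of paths into a node/edge description:

```python
# A hypothetical "circuit": a small set of paths, each written as the ordered
# sequence of blocks it passes through (earliest block first).
circuit_paths = [
    (),                    # the identity / direct path
    ("Att_0",),            # a single attention block acting alone
    ("Att_0", "MLP_1"),    # attention feeding into a later MLP
]

# Re-factorize the sum of paths into a graph: nodes are blocks (plus embed/unembed
# endpoints), and an edge records that one block feeds the next along some path.
nodes, edges = {"embed", "unembed"}, set()
for path in circuit_paths:
    hops = ["embed", *path, "unembed"]
    nodes.update(path)
    edges.update(zip(hops, hops[1:]))

print(sorted(nodes))
print(sorted(edges))
```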
AMFOTC. The residual decomposition is fully compatible with AMFOTC
AMFOTC simply explains how to decompose a single attention operation into linear (or almost-linear) terms
The residual decomposition tells you how to build lots of attention operations up into a bigger picture
Exercise 2 (for me, and for others): Write “bigrams” and “skip-trigrams” using the residual expansion
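As a hedged starting point for Exercise 2 (not a full answer): take a one-layer, attention-only model with embedding $W_E$ and unembedding $W_U$, and treat each head’s attention pattern as fixed so that $\mathrm{Att}_h$ is (almost) linear, as AMFOTC does. Then the expansion of the logits is

$\mathrm{logits} \;=\; \underbrace{W_U\, W_E}_{\text{bigram statistics (degree-0 path)}} \;+\; \sum_h \underbrace{W_U\, \mathrm{Att}_h\, W_E}_{\text{skip-trigrams (degree-1 paths)}}$

i.e. AMFOTC’s bigram and skip-trigram terms are exactly the degree-0 and degree-1 paths of the residual expansion.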
Individual terms. Generally, many other ideas in interpretability can be thought of as attempts to understand individual terms in the residual decomposition
E.g. MLP transcoders decompose the MLP block into a sparsely-activating set of paths through SAE features.
This decomposition can be ‘substituted back’ into the residual decomposition to yield finer-grained circuits.
Ditto for attention-out SAEs, bilinear SAEs, etc.
If we take “SAE-fication” to its logical extreme by replacing everything with SAEs, we get sparse feature circuits
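Returning to the transcoder point above: a transcoder approximates the MLP block as a sparse combination of feature directions, so substituting it into the degree-1 MLP term splits one path into many feature-level paths. A minimal sketch, with made-up shapes and weight names (not any particular library’s API):

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_features = 8, 32

# Hypothetical transcoder weights approximating MLP(x) ~ W_dec @ relu(W_enc @ x + b)
# (real transcoders are trained to match the MLP; these are random placeholders)
W_enc = rng.normal(size=(n_features, d)) * 0.3
W_dec = rng.normal(size=(d, n_features)) * 0.3
b = rng.normal(size=n_features) * 0.1

def transcoder_paths(x):
    """Split the (approximate) degree-1 MLP path into one path per active feature."""
    acts = np.maximum(W_enc @ x + b, 0.0)  # feature activations (sparse for a trained transcoder)
    active = np.nonzero(acts)[0]
    # Each active feature i contributes acts[i] * W_dec[:, i] to the MLP term's output
    return {int(i): acts[i] * W_dec[:, i] for i in active}

x = rng.normal(size=d)
per_feature = transcoder_paths(x)
mlp_term_approx = sum(per_feature.values())  # ~ the output of the degree-1 MLP path
print(len(per_feature), "active feature-level paths")
```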
Other remarks
Layernorm. We’ve handwaved layernorm here, but it’s simple to modify our equations to account for it. Notably, because layernorm is not a residual operation, we don’t end up with any more terms than we did originally.
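Concretely, for a pre-LN layer (and with the same caveats as the rest of this note), each layernorm is simply absorbed into the block it precedes: $T = (\mathrm{MLP} \circ \mathrm{LN}_2 + \mathrm{Id}) \circ (\mathrm{Att} \circ \mathrm{LN}_1 + \mathrm{Id})$, which expands into the same four paths as before, just with $\mathrm{MLP} \circ \mathrm{LN}_2$ and $\mathrm{Att} \circ \mathrm{LN}_1$ in place of $\mathrm{MLP}$ and $\mathrm{Att}$.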
Tl;dr I think this is a nice unifying way to think about lots of circuit analysis work.
Open Questions / Ideas
Here are some ideas motivated by this line of thinking.
Hybrid paths / circuits. As discussed above, if we take ‘paths’ through the model to be the unit of analysis, we can easily decompose paths (in the model space) further using SAE features. I’m generally very excited to work on more extensive efforts to do circuit analysis using such paths.
Path attributions. How much does each path contribute to the output? Do we notice patterns?
Pick a reasonable notion of ‘how much does path X contribute to the output?’ (e.g. attributions, Shapley values, etc.) — a toy sketch of one such notion follows below.
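As a toy illustration (reusing the linear stand-ins from the earlier sketch, with a deliberately crude, made-up contribution measure: the projection of each path’s output onto a chosen readout direction):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8

# Linear stand-ins for the blocks again (hypothetical, illustration only)
Att = rng.normal(size=(d, d)) * 0.1
MLP = rng.normal(size=(d, d)) * 0.1

x = rng.normal(size=d)        # toy input
readout = rng.normal(size=d)  # a made-up output direction to attribute against

paths = {
    "Id": x,
    "Att": Att @ x,
    "MLP": MLP @ x,
    "MLP.Att": MLP @ Att @ x,
}

# One crude notion of "how much does this path contribute": its projection onto
# the readout direction. (Attribution methods or Shapley values are alternatives.)
attributions = {name: float(readout @ out) for name, out in paths.items()}
for name, score in sorted(attributions.items(), key=lambda kv: -abs(kv[1])):
    print(f"{name:8s} {score:+.3f}")
```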
Task complexity. The degree of a path (i.e. the number of blocks it uses) could be used as a rough notion of ‘task complexity’. Does this align with our intuition about what tasks are more or less complex? E.g. I would expect IOI and Docstrings to be more complex (as they’re somewhat algorithmic) than other tasks like gender bias and hypernymy (as they’re simply variants of lookup)
Parallel processing. Lucius Bushnaq has argued to me that different blocks of the model could be implementing different parts of a higher-level operation. E.g. to calculate AND(a, b), two different MLP blocks may be responsible for computing a and b respectively. If so, we’d expect them not to interact in any way. Then we could formalise this as $M_1 + M_2 = M_1 \circ M_2$
Conclusion
AMFOTC is great, but doesn’t go far enough IMO, and is outdated
‘Residual expansion’ explains how to extend AMFOTC to full transformers and fit it in with SAE-based circuit analysis
I’m very interested to hear takes on this!