You might enjoy this paper about a related idea.
If I understand correctly, this residual decomposition is equivalent to the edge / factorized view of a transformer described here.
Update: actually the residual decomposition is incorrect—see my other comment.
I agree, this seems like exactly the same thing, which is great! In hindsight it’s not surprising that you / other people have already thought about this.
Do you think the ‘tree-ified view’ (to use your name for it) is a good abstraction for thinking about how a model works? Are individual terms in the expansion the right unit of analysis?
The treeified view is different from the factorized view! See figure 1 here.
I think the factorized view is pretty useful. But on the other hand I think MLP + Attention Head circuits are too coarse-grained to be that interpretable.
Just to make it explicit and check my understanding—the residual decomposition is equivalent to the edge / factorized view of the transformer in that we can express any term in the residual decomposition as a set of edges that form a path from input to output, e.g.
Id = input → output
(Attn4.3 ∘ MLP2 ∘ Attn1.0) = input → Attn 1.0 → MLP 2 → Attn 4.3 → output
And it follows that the (pre final layernorm) output of a transformer is the sum of all the “paths” from input to output constructed from the factorized DAG.
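Here is a minimal sketch of that correspondence (the component labels and their causal ordering are made up for illustration, not taken from the post): each term picks an ordered subset of components, which is exactly one path through the factorized DAG.

```python
from itertools import combinations

# Toy illustration; component labels and causal order are made up.
components = ["Attn 1.0", "MLP 2", "Attn 4.3"]

# Each term in the residual decomposition picks an ordered subset of
# components, i.e. one path from input to output in the factorized DAG.
for r in range(len(components) + 1):
    for subset in combinations(components, r):
        print(" -> ".join(("input",) + subset + ("output",)))

# Prints 2^3 = 8 paths: "input -> output" is the Id term, and
# "input -> Attn 1.0 -> MLP 2 -> Attn 4.3 -> output" is the
# Attn4.3 ∘ MLP2 ∘ Attn1.0 term.
```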
Actually I think the residual decomposition is incorrect—see my other comment.
@Oliver Daniels-Koch’s reply to my comment made me read this post again more carefully and now I think that your formulation of the residual expansion is incorrect.
Given T=(MLP+Id)∘(Att+Id), it does not follow that T=(MLP∘Att)+(MLP)+(Att)+Id, because MLP is a non-linear operation. It cannot be decomposed like this.
My understanding of your big summation (with C representing any MLP or attention head):
$$T(x) = x + \sum_{i=0}^{n-1} C_i(x) + \sum_{i=1}^{n-1}\sum_{j=0}^{i-1} C_i \circ C_j(x) + \sum_{i=2}^{n-1}\sum_{j=1}^{i-1}\sum_{k=0}^{j-1} C_i \circ C_j \circ C_k(x) + \cdots$$

again does not hold because the Cs are non-linear.
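A quick numerical sanity check of this, using toy stand-ins for Att and MLP (the shapes, random weights and ReLU are my own choices, purely for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8

# Toy stand-ins: a linear "Att" and a non-linear (ReLU) "MLP", both R^d -> R^d.
W_att = rng.normal(size=(d, d)) / np.sqrt(d)
W_mlp = rng.normal(size=(d, d)) / np.sqrt(d)
att = lambda x: W_att @ x
mlp = lambda x: np.maximum(W_mlp @ x, 0.0)

x = rng.normal(size=d)

# Actual composition: T(x) = (MLP + Id)((Att + Id)(x))
h = att(x) + x
T = mlp(h) + h

# Claimed expansion: (MLP∘Att)(x) + MLP(x) + Att(x) + x
expansion = mlp(att(x)) + mlp(x) + att(x) + x

# False: MLP(a + b) != MLP(a) + MLP(b) for a non-linear MLP.
# (With a linear "MLP" the two would agree.)
print(np.allclose(T, expansion))
```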
There are two similar ideas which do hold, namely (1) the treeified / unraveled view and (2) the factorized view (both of which are illustrated in figure 1 here), but your residual expansion / big summation is not equivalent to either.
The treeified / unraveled view is the most similar. It separates out each path from input to output, but, unlike your expansion, it does not claim that the output is the sum of all the separate paths.
The factorized view follows from the treeified view and is just the observation that any point in the residual stream can be decomposed into the outputs of all previous components.
$$T(x) = x + \sum_{i=0}^{n-1} \mathrm{Output}(C_i)$$
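i.e. the identity is pure bookkeeping: each Output(C_i) is the component’s output on its actual input, which already includes everything written to the stream before it. A toy sketch (the components here are made-up ReLU maps, just for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8

# Made-up components C_0..C_3: each reads the current residual stream and
# writes its output back into it.
Ws = [rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(4)]
components = [lambda x, W=W: np.maximum(W @ x, 0.0) for W in Ws]

x = rng.normal(size=d)

stream = x.copy()
outputs = []
for C in components:
    out = C(stream)        # Output(C_i), computed on the actual stream
    outputs.append(out)
    stream = stream + out  # component writes into the residual stream

# Factorized-view identity: final stream = input + sum of component outputs.
print(np.allclose(stream, x + sum(outputs)))  # True
```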
If I understand correctly, you’re saying that my expansion is wrong, because MLP∘(Att+Id)≠MLP∘Att+MLP∘Id, which I agree with.
Then isn’t it also true that Att∘(MLP+Id)≠Att∘MLP+Att∘Id?
Also, if the output is not a sum of all separate paths, then what’s the point of the unraveled view?
Yes, MLP∘(Att+Id)≠MLP∘Att+MLP∘Id is what I’m saying.
Yes, I agree that Att∘(MLP+Id)≠Att∘MLP+Att∘Id.
(First, note that it can be true without being useful.) In the Residual Networks Behave Like Ensembles of Relatively Shallow Networks paper, they find that long paths are mostly not needed by the model. In Causal Scrubbing, they intervene on the treeified view to understand which paths are causally relevant for particular behaviors.
That makes sense to me. I guess I’m dissatisfied here because the idea of an ensemble seems to be that individual components in the ensemble are independent; whereas in the unraveled view of a residual network, different paths still interact with each other (e.g. if two paths overlap, then ablating one of them could also (in principle) change the value computed by the other path). This seems to be the mechanism that explains redundancy.
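A toy illustration of that interaction, with the same kind of toy Att/MLP stand-ins as in the sketch above (made-up weights and a ReLU, purely for illustration): ablating the Att→MLP path changes what the overlapping Id→MLP path contributes, because the two paths meet at a non-linearity.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
W_att = rng.normal(size=(d, d)) / np.sqrt(d)
W_mlp = rng.normal(size=(d, d)) / np.sqrt(d)
att = lambda x: W_att @ x                    # toy linear "Att"
mlp = lambda x: np.maximum(W_mlp @ x, 0.0)   # toy non-linear "MLP"

x = rng.normal(size=d)

# In the unraveled view, the MLP node receives the sum of two upstream paths:
# the Id path (x) and the Att path (att(x)).
mlp_both_paths = mlp(x + att(x))  # both upstream paths present
mlp_id_only = mlp(x)              # Att -> MLP path ablated

# If paths were independent ensemble members, ablating the Att -> MLP path
# would just subtract that path's own contribution and leave the Id -> MLP
# path unchanged. Because MLP is non-linear, it does not:
print(np.allclose(mlp_both_paths - mlp(att(x)), mlp_id_only))  # False
```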
I think this makes sense.
I am not sure how new this approach is (for simplified Transformers, the original AMFOTC paper has several sections called “* Path Expansion *”, which seem to do something very similar for a reduced set of transformations, and their formalism of “virtual attention heads” seems also to be in that spirit).
Fair point, and I should amend the post to point out that AMFOTC also does ‘path expansion’. However, I think this is still conceptually distinct from AMFOTC because:
In my reading of AMFOTC, the focus seems to be on understanding attention by separating the QK and OV circuits, writing these as linear (or almost linear) terms, and fleshing this out for 1-2 layer attention-only transformers. This is cool, but also very hard to use at the level of a full model.
Beyond understanding individual attention heads, I am more interested in how the whole model works; IMO this is very unlikely to be simply understood as a sum of linear components. OTOH residual expansion gives a sum of nonlinear components and maybe each of those things is more interpretable.
I think the notion of path ‘degrees’ hasn’t been explicitly stated before and I found this to be a useful abstraction to think about circuit complexity.
Maybe this post is better framed as ‘reconciling AMFOTC with SAE circuit analysis’.
Yes, I think this makes sense.
Here is one aspect which might be useful to keep in mind.
If we think about all this as some kind of “generalized Taylor expansion”, there are some indications that the deviations from linearity might be small.
E.g. there is this rather famous post, https://www.lesswrong.com/posts/JK9nxcBhQfzEgjjqe/deep-learning-models-might-be-secretly-almost-linear.
Another indication pointing to “almost linearity” is that “model merge” works pretty well. Although, interestingly enough, people often prefer to approach “model merge” in a more subtle fashion than just linear interpolation, so, presumably, non-linearity does matter quite a bit as well, e.g. https://huggingface.co/blog/mlabonne/merge-models.