@Oliver Daniels-Koch’s reply to my comment made me read this post again more carefully, and now I think that your formulation of the residual expansion is incorrect.
Given T=(MLP+Id)∘(Att+Id), it does not follow that T=(MLP∘Att)+(MLP)+(Att)+Id, because MLP is a non-linear operation. T cannot be decomposed like this.
My understanding of your big summation (with C representing any MLP or attention head):
$$T(x) = x + \sum_{i=0}^{n-1} C_i(x) + \sum_{i=1}^{n-1}\sum_{j=0}^{i-1} C_i \circ C_j(x) + \sum_{i=2}^{n-1}\sum_{j=1}^{i-1}\sum_{k=0}^{j-1} C_i \circ C_j \circ C_k(x) + \cdots$$
again does not hold, because the $C_i$ are non-linear.
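To make the non-linearity point concrete, here is a minimal numerical sketch for the single-block case T=(MLP+Id)∘(Att+Id). This is my own toy example, not code from the post: "Att" is stood in for by a random linear map and "MLP" by a linear map followed by a ReLU.

```python
import numpy as np

rng = np.random.default_rng(0)
W_att = rng.normal(size=(4, 4))   # toy "Att": a random linear map
W_mlp = rng.normal(size=(4, 4))   # toy "MLP": a linear map followed by a ReLU

def att(x):
    return W_att @ x

def mlp(x):
    return np.maximum(W_mlp @ x, 0.0)  # the ReLU is the non-linearity

x = rng.normal(size=4)

# The actual block: T = (MLP + Id) ∘ (Att + Id)
x_mid = att(x) + x               # residual stream after the attention sub-block
actual = mlp(x_mid) + x_mid      # residual stream after the MLP sub-block

# The proposed expansion: (MLP∘Att) + MLP + Att + Id, each term applied to x
expansion = mlp(att(x)) + mlp(x) + att(x) + x

print(np.allclose(actual, expansion))  # False: MLP(a + b) != MLP(a) + MLP(b)
```

The two sides differ exactly because mlp(att(x) + x) is not mlp(att(x)) + mlp(x) for a non-linear mlp, which is the whole objection.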
There are two similar ideas which do hold, namely (1) the treeified / unraveled view and (2) the factorized view (both of which are illustrated in figure 1 here), but your residual expansion / big summation is not equivalent to either.
The treeified / unraveled view is the most similar: it separates out every path from input to output, but, unlike your expansion, it does not claim that the output is the sum of all the separate paths.
The factorized view follows from the treeified view and is just the observation that any point in the residual stream can be decomposed into the sum of the outputs of all previous components:
$$T(x) = x + \sum_{i=0}^{n-1} \mathrm{Output}(C_i)$$
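For contrast, here is a small self-contained sketch (again my own toy stand-ins, not code from the post) of why the factorized identity does hold: each component's output is evaluated on the residual stream it actually received, and those outputs really are just added together.

```python
import numpy as np

rng = np.random.default_rng(0)
W_att = rng.normal(size=(4, 4))
W_mlp = rng.normal(size=(4, 4))
att = lambda x: W_att @ x                    # toy "Att"
mlp = lambda x: np.maximum(W_mlp @ x, 0.0)   # toy non-linear "MLP"

x = rng.normal(size=4)
x_mid = att(x) + x            # residual stream after the attention sub-block
actual = mlp(x_mid) + x_mid   # T(x)

# Factorized view: T(x) = x + sum of component *outputs*, where each output is
# computed on the residual stream that component actually received.
out_att = att(x)
out_mlp = mlp(x + out_att)
print(np.allclose(actual, x + out_att + out_mlp))  # True
```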
If I understand correctly, you’re saying that my expansion is wrong, because MLP∘(Att+Id)≠MLP∘Att+MLP∘Id, which I agree with.
Then isn’t it also true that Att∘(MLP+Id)≠Att∘MLP+Att∘Id?
Also, if the output is not a sum of all separate paths, then what’s the point of the unraveled view?
Yes, MLP∘(Att+Id)≠MLP∘Att+MLP∘Id is what I’m saying.
Yes, I agree that Att∘(MLP+Id)≠Att∘MLP+Att∘Id.
(Firstly, note that it can be true without being useful.) In the Residual Networks Behave Like Ensembles of Relatively Shallow Networks paper, they find that long paths are mostly not needed by the model. In Causal Scrubbing, they intervene on the treeified view to understand which paths are causally relevant for particular behaviors.
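As an illustration of what such an intervention can look like, here is a toy sketch (my own code, not from either paper) of ablating a single path in the treeified view of the one-block example: the copy of Att's output that feeds the MLP is removed, while the direct Att path to the output is kept. Causal Scrubbing would resample a path rather than zero it, but the shape of the intervention, acting on one copy of a component's output without touching the others, is the same.

```python
import numpy as np

rng = np.random.default_rng(0)
W_att = rng.normal(size=(4, 4))
W_mlp = rng.normal(size=(4, 4))
att = lambda x: W_att @ x
mlp = lambda x: np.maximum(W_mlp @ x, 0.0)

x = rng.normal(size=4)
out_att = att(x)
actual = x + out_att + mlp(x + out_att)       # full output, treeified form

# Ablate only the Att→MLP path: the MLP no longer sees Att's output, while the
# direct Att→output path is left intact.
ablated = x + out_att + mlp(x + 0.0 * out_att)
print(np.linalg.norm(actual - ablated))       # generally non-zero
```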
That makes sense to me. I guess I’m dissatisfied here because the idea of an ensemble seems to be that individual components in the ensemble are independent; whereas in the unraveled view of a residual network, different paths still interact with each other (e.g. if two paths overlap, then ablating one of them could also (in principle) change the value computed by the other path). This seems to be the mechanism that explains redundancy.