In other experiments we’ve run (not presented here), the MSP is not well represented in the final layer but is instead spread out amongst earlier layers. We think this occurs because, in general, there are groups of belief states that are degenerate in the sense that they have the same next-token distribution. In that case, the formalism presented in this post says that even though the distinction between those states must be represented in the transformer’s internals, the transformer is able to lose those distinctions for the purpose of predicting the next token (in the local sense), which occurs most directly right before the unembedding.
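As a toy illustration of that degeneracy (made-up numbers and a simple state-emitting HMM convention, not one of the processes from the post): two distinct belief states can induce exactly the same next-token distribution while still predicting different futures, so a transformer that only cared about the immediate next token could safely collapse them.

```python
import numpy as np

# Toy numbers (made up for illustration): a 3-state HMM where
# states 1 and 2 emit tokens identically but transition differently.
T = np.array([[0.5, 0.5, 0.0],   # transition probabilities from state 0
              [1.0, 0.0, 0.0],   # from state 1
              [0.0, 0.0, 1.0]])  # from state 2
E = np.array([[0.9, 0.1],        # emission distribution of state 0 over tokens {a, b}
              [0.2, 0.8],        # state 1
              [0.2, 0.8]])       # state 2: same emissions as state 1

b1 = np.array([0.0, 1.0, 0.0])   # belief: definitely in state 1
b2 = np.array([0.0, 0.0, 1.0])   # belief: definitely in state 2

print(b1 @ E, b2 @ E)            # identical next-token distributions: [0.2 0.8] both times
print(b1 @ T @ E, b2 @ T @ E)    # different distributions one step further out
```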
This is really cool work!!

Would be interested to see analyses where you show how an MSP is spread out amongst earlier layers.
Presumably, if the model does not discard intermediate results, something like concatenating residual stream vectors from different layers and then linearly regressing against the ground-truth belief-state-over-HMM-states vector would extract the same kind of structure you see when looking at the final layer. Maybe even with the same model you analyze, the structure would be crisper if you projected the full concatenated-over-layers residual stream, in case there is noise in the final layer and the same features are represented more cleanly in earlier layers?
In cases where redundant information is discarded at some point, this is a harder problem of course.
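A minimal sketch of the kind of concatenated-layer probe this suggests (function and variable names are hypothetical, and sklearn’s LinearRegression stands in for whichever regression is actually used):

```python
import numpy as np
from sklearn.linear_model import LinearRegression

def concat_layer_probe(resid_by_layer, beliefs):
    """Fit a linear map from layer-concatenated residual streams to belief states.

    resid_by_layer: list of arrays, one per layer, each of shape (n_contexts, d_model)
    beliefs:        array of shape (n_contexts, n_states) of ground-truth belief states
    """
    X = np.concatenate(resid_by_layer, axis=-1)   # (n_contexts, n_layers * d_model)
    probe = LinearRegression().fit(X, beliefs)
    return probe.predict(X), probe                # projected beliefs + the fitted linear map
```

Plotting the predicted beliefs the same way as the final-layer projection would then show whether the concatenated structure is in fact crisper.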
Thanks! I’ll have more thorough results to share about layer-wise representations of the MSP soon. I’ve already run some of the analysis concatenating over all layers’ residual streams with the RRXOR process, and it is quite interesting. It seems there’s a lot more to explore in the relationship between the number of states in the generative model, the number of layers in the transformer, the residual stream dimension, and the token vocab size. All of these (I think) play some role in how the MSP is represented in the transformer. For RRXOR, things do look crisper when concatenating.
Even for cases where redundant info is discarded, we should be able to see the distinctions somewhere in the transformer. One thing I’m really keen on exploring is such a case, where we can very concretely follow the path/circuit through which redundant info is first distinguished and then collapsed.