We do this by performing standard linear regression from the residual stream activations (64-dimensional vectors) to the belief distributions (3-dimensional vectors) that are associated with them in the MSP.
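For illustration, here is a minimal NumPy sketch of that kind of regression, written as an affine least-squares fit. The data is a synthetic stand-in, not output from a real transformer or MSP: random 3-state beliefs are linearly embedded into 64 dimensions and a little noise is added. `fit_linear_probe` and `apply_probe` are hypothetical helper names.

```python
import numpy as np

def fit_linear_probe(activations: np.ndarray, beliefs: np.ndarray) -> np.ndarray:
    """Least-squares affine map from activations (N, d_model) to beliefs (N, n_states)."""
    # Append a constant column so the fit includes a bias term.
    X = np.hstack([activations, np.ones((activations.shape[0], 1))])
    W, *_ = np.linalg.lstsq(X, beliefs, rcond=None)
    return W  # shape (d_model + 1, n_states)

def apply_probe(W: np.ndarray, activations: np.ndarray) -> np.ndarray:
    """Predict belief vectors from activations using the fitted affine map."""
    X = np.hstack([activations, np.ones((activations.shape[0], 1))])
    return X @ W

rng = np.random.default_rng(0)
n_tokens, d_model, n_states = 2000, 64, 3

# Synthetic stand-ins (assumption): Dirichlet beliefs embedded linearly plus noise.
beliefs = rng.dirichlet(np.ones(n_states), size=n_tokens)
embed = rng.normal(size=(n_states, d_model))
activations = beliefs @ embed + 0.01 * rng.normal(size=(n_tokens, d_model))

W = fit_linear_probe(activations, beliefs)
pred = apply_probe(W, activations)

# Fraction of belief variance explained by the linear probe.
r2 = 1 - ((pred - beliefs) ** 2).sum() / ((beliefs - beliefs.mean(axis=0)) ** 2).sum()
print(f"R^2 = {r2:.4f}")
```

Because the synthetic activations are built to contain the beliefs linearly, the fit here is near-perfect. With real activations, how much variance the probe explains is exactly the empirical question.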
I don’t understand how we go from this to the fractal. The linear probe gives us a single 2D point for every forward pass of the transformer, correct? How do we get the picture with many points in it? Is it by sampling from the transformer while reading the probe after every token and then putting all the points from that on one graph?
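A 3-state belief is a point in a triangle (the 2-simplex), so each probe readout can be drawn as a 2D point. The procedure proposed in the question would read the probe at every token position of many sequences and plot all of the readouts together. Below is a minimal sketch of only that collection-and-projection step. It assumes the probe outputs are already available as `(seq_len, 3)` arrays. The helper names, the triangle layout, and the example numbers are hypothetical and are not taken from the original work.

```python
import numpy as np

# One 2D corner per hidden state; this equilateral layout is an arbitrary choice.
CORNERS = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, np.sqrt(3) / 2]])

def simplex_to_2d(beliefs: np.ndarray) -> np.ndarray:
    """Map (N, 3) probability vectors to (N, 2) points inside the triangle.

    Each point is the belief-weighted average of the corners, so a belief that is
    certain of one state lands exactly on that state's corner.
    """
    return beliefs @ CORNERS

def collect_points(per_sequence_outputs: list[np.ndarray]) -> np.ndarray:
    """Stack probe outputs from every token position of every sequence, then project."""
    all_beliefs = np.vstack(per_sequence_outputs)  # (total_positions, 3)
    return simplex_to_2d(all_beliefs)

# Hypothetical probe outputs for two sequences of lengths 4 and 2.
seqs = [
    np.array([[1/3, 1/3, 1/3], [0.6, 0.2, 0.2], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6]]),
    np.array([[1/3, 1/3, 1/3], [1.0, 0.0, 0.0]]),
]
points = collect_points(seqs)
print(points.shape)  # (6, 2)
```

The resulting `points` array could be passed to any scatter-plot routine. Each row is one token position, not one whole sequence.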
Is this result equivalent to saying “a transformer trained on an HMM’s output learns a linear representation of the probability distribution over the HMM’s states”?
Yep, that’s what I was trying to describe as well. Thanks!