We do this by performing standard linear regression from the residual stream activations (64-dimensional vectors) to the belief distributions (3-dimensional vectors) associated with them in the MSP.
I don’t understand how we go from this to the fractal. The linear probe gives us a single 2D point for every forward pass of the transformer, correct? How do we get the picture with many points in it? Is it by sampling from the transformer while reading the probe after every token, and then plotting all of those points on one graph?
Is this result equivalent to saying “a transformer trained on an HMM’s output learns a linear representation of the probability distribution over the HMM’s states”?
I should have explained this better in my post.
For every input into the transformer (of every length up to the context window length), we know the ground-truth belief state that comp mech says an observer should have over the HMM states. In this case, that is 3 numbers, so for each input we have a 3D ground-truth vector. We also have, for each input, the residual stream activation (in this case a 64D vector). To find the projection we just use standard linear regression (as implemented in sklearn) between the 64D residual stream vectors and the 3D ground-truth vectors (really 2D, since the three probabilities sum to 1). Does that make sense?
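For concreteness, a minimal sketch of that regression (hypothetical array names; the random arrays below just stand in for the activations and ground-truth beliefs you’d actually collect, one row per input prefix):

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Stand-ins for the real collected data: one row per input prefix.
n_inputs = 100_000
acts = np.random.randn(n_inputs, 64)                      # residual stream activations (64D)
beliefs = np.random.dirichlet(np.ones(3), size=n_inputs)  # ground-truth belief states (3D, rows sum to 1)

# Standard linear regression from the 64D activations to the 3D belief vectors.
probe = LinearRegression()
probe.fit(acts, beliefs)

# The fitted affine map sends any 64D activation to a predicted belief state.
predicted = probe.predict(acts)
print(probe.score(acts, beliefs))  # R^2 of the fit
```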
Given that the model eventually outputs the next token, shouldn’t the final embedding matrix be exactly your linear-fit matrix multiplied by each state’s probability of emitting a given token? Could you use that?
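If I’m reading this suggestion right, in symbols (my notation, not the post’s, and only approximate, since the unembedding produces logits that then pass through a softmax): with $a$ the 64D residual stream activation, $P \in \mathbb{R}^{64 \times 3}$ the fitted probe matrix, $E \in \mathbb{R}^{3 \times |V|}$ the matrix of per-state token-emission probabilities, and $W_U \in \mathbb{R}^{64 \times |V|}$ the final unembedding,

$$\mathrm{softmax}(a\,W_U) \;\approx\; \underbrace{(a\,P)}_{\text{belief state}} E \;=\; a\,(P\,E),$$

so one might hope $W_U \approx P\,E$, up to the softmax nonlinearity and any bias terms.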
Yep, that’s what I was trying to describe as well. Thanks!
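Finally, the plotting procedure asked about earlier (reading the probe after every token and putting all the points on one graph) might look roughly like this (a sketch with hypothetical helper names: sample_sequence and get_residual_activations are assumptions, and probe is the fitted regression from above):

```python
import numpy as np
import matplotlib.pyplot as plt

points = []
for _ in range(5_000):
    seq = sample_sequence()                # hypothetical: draw a token sequence from the transformer
    acts = get_residual_activations(seq)   # hypothetical: (seq_len, 64) residual stream activations
    points.append(probe.predict(acts))     # predicted belief state at every token position

points = np.concatenate(points)            # one 3D point per token position per sequence

# The belief coordinates sum to 1, so the points lie on a 2D simplex;
# an affine 2D projection of the simplex (here, the first two coordinates)
# reveals the fractal structure when all points are plotted together.
plt.scatter(points[:, 0], points[:, 1], s=0.1)
plt.gca().set_aspect("equal")
plt.show()
```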