Is there a reason you used 300/400 randomly sampled indices, instead of evenly spaced indices (e.g. every 1/300 of the total training steps)?
Did you subtract the mean of the weights before doing the SVD? Otherwise, the first component is probably the mean of the 300/400 weight vectors.
Btw, Neel Nanda has done experiments similar to your SVD experiment on his grokking models. For example, if we take 400 checkpoints (flattened weight vectors) from his mainline model and concatenate them into a [400, 226825] matrix, it turns out that the singular values are even more extreme than in your case: (apologies for the sloppily made figure)
(By your 10% threshold, it turns out only 5 singular values make the cut!)
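For concreteness, the whole recipe is only a few lines. This is just a sketch: load_flat_weights is a hypothetical placeholder for however you load and flatten a checkpoint, and I’m reading the “10% threshold” as counting singular values that are at least 10% of the largest one.

```python
import numpy as np

def load_flat_weights(step, n_params=10_000):
    # Hypothetical placeholder: load the checkpoint saved at `step` and concatenate its
    # flattened parameter tensors (Neel's grokking model would give n_params = 226825).
    rng = np.random.default_rng(step)
    return rng.normal(size=n_params)

n_ckpts, total_steps = 400, 40_000
# Evenly spaced checkpoints rather than randomly sampled ones.
steps = np.linspace(0, total_steps, n_ckpts, endpoint=False, dtype=int)

W = np.stack([load_flat_weights(int(s)) for s in steps])    # [n_ckpts, n_params]
W_centered = W - W.mean(axis=0, keepdims=True)              # subtract the mean weight vector

U, S, Vt = np.linalg.svd(W_centered, full_matrices=False)   # thin SVD

# Count the singular values that are at least 10% of the largest one.
print(int((S >= 0.10 * S[0]).sum()))
```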
We’re pretty sure that component 0 is the “memorization” component: it is dense in the Fourier basis while the other 4 big components are sparse, and subtracting it from the model leads to better generalization.
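I don’t remember exactly how the subtraction was implemented, so treat this as a sketch of one natural version, continuing the snippet above: remove the final checkpoint’s projection (about the mean) onto the first right singular vector.

```python
# Continuing the snippet above: subtract component 0's contribution from the final checkpoint.
w_final = W[-1]                 # final weights, flattened
v0 = Vt[0]                      # unit-norm candidate "memorization" direction in parameter space
w_ablated = w_final - np.dot(w_final - W.mean(axis=0), v0) * v0
# w_ablated would then be written back into the model to check whether test loss improves.
```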
Unfortunately, interpreting the other big ones turned out to be pretty non-trivial. This is despite the fact that many parts of the final network have low-rank approximations that capture >99% of the variance, that we know the network is getting more sparse in the Fourier basis, and that the network’s function is well enough understood that you can literally read off the trig identities being used at the MLP layer. So I’m not super confident that purely unsupervised linear methods will actually help much with interpretability here.
(Also, having worked with this particular environment, I’ve seen a lot of further evidence that techniques like SVD and PCA are pretty bad for finding interpretable components.)
Another interesting experiment you might want to do is to look at the principal components of loss over the course of training, as they do in the Anthropic Induction Heads paper.
You can also plot the first two principal components of the logits, which in Neel’s case gives a pretty diagram with two inflection points, at checkpoints 14 and 103. These correspond to the transition from the memorization phase (checkpoints 0->14, steps 0->1.4k) to the circuit formation phase (checkpoints 14->103, steps 1.4k->10.3k), and from the circuit formation phase to the cleanup phase (checkpoints 103->400, steps 10.3k->40k), which marks the start of grokking. Again, I’m not sure how illuminating figures like this actually are; they basically just say “something interesting happens around ~1.4k and ~10.3k”, which we already know from inspecting any of the other metrics (e.g. train/test loss).
(In this case, there’s a very good reason why looking at the top 2 principal components of the logits isn’t super illuminating: there’s 1 “memorizing” direction and 5 “generalizing” directions, corresponding to each of the 5 key frequencies, on top of the normal interpretability problems.)
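If it helps, a minimal sketch of that plot (eval_logits is a hypothetical placeholder; in practice it would run a given checkpoint on a fixed eval batch and flatten the logits):

```python
import numpy as np
import matplotlib.pyplot as plt

def eval_logits(ckpt_idx, n_outputs=1_000):
    # Hypothetical placeholder: run checkpoint `ckpt_idx` on a fixed eval batch, flatten the logits.
    rng = np.random.default_rng(ckpt_idx)
    return rng.normal(size=n_outputs)

n_ckpts = 400
L = np.stack([eval_logits(i) for i in range(n_ckpts)])      # [n_ckpts, n_outputs]
L_centered = L - L.mean(axis=0, keepdims=True)

U, S, Vt = np.linalg.svd(L_centered, full_matrices=False)
coords = U[:, :2] * S[:2]                                   # each checkpoint's position in PC1/PC2

plt.plot(coords[:, 0], coords[:, 1], marker=".")
plt.xlabel("PC 1"); plt.ylabel("PC 2")
plt.title("Logits over training, projected onto the top 2 PCs")
plt.show()
```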
The main thing I want to do now is replicate the results from a particular paper whose name I can’t remember right now, where an RL agent was trained to navigate to a cheese in the top right corner of a maze. I then want to apply this method to the training gradients and see whether we can locate which parameters are responsible for the if bottom_left(), then navigate_to_top_right() cognition, and which are responsible for the if top_right(), then navigate_to_cheese() cognition, which should be determinable from their time-step distributions.
That is, if bottom_left(), then navigate_to_top_right() should be associated with reinforcement events earlier in training rather than later, so the left singular vectors locating the parameters responsible for that computation should have corresponding right singular vectors whose entries are high in magnitude early in training and low in magnitude towards the end. Similarly, if top_right(), then navigate_to_cheese() should be associated with reinforcement events later in training, so the opposite should hold.
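Concretely, the check I have in mind looks something like this (load_flat_grad is a hypothetical placeholder for however the per-step gradients get stored; all sizes are made up):

```python
import numpy as np

def load_flat_grad(step, n_params=10_000):
    # Hypothetical placeholder: the flattened gradient (or parameter update) applied at `step`.
    rng = np.random.default_rng(step)
    return rng.normal(size=n_params)

steps = np.arange(0, 10_000, 100)
G = np.stack([load_flat_grad(int(s)) for s in steps], axis=1)   # [n_params, n_steps]

# Columns of U are directions in parameter space (left singular vectors);
# rows of Vt are each component's profile over training (right singular vectors).
U, S, Vt = np.linalg.svd(G, full_matrices=False)

# A component implementing "if bottom_left, go to top right" should have most of its
# |Vt[k]| mass early in training; a cheese-seeking component should have it late.
half = len(steps) // 2
for k in range(3):
    early = np.abs(Vt[k, :half]).mean()
    late = np.abs(Vt[k, half:]).mean()
    print(f"component {k}: mean |right singular vector entry| early={early:.3f}, late={late:.3f}")
```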
Then I want to verify that we have indeed found the right parameters by ablating them and checking that this removes the model’s tendency to go to the cheese after it’s reached the top right corner.
It would also be interesting to see whether we can ablate its ability to go to the top right corner while keeping its ability to go to the cheese when the cheese is sufficiently close or it is already in the top right corner. However, this seems harder, and not as clearly possible even given that we’ve found the correct parameters.
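A minimal sketch of the ablation step, with placeholder values standing in for the trained policy’s flattened parameters and for a parameter direction found via the gradient SVD above:

```python
import numpy as np

def ablate_direction(flat_params, direction):
    # Remove the projection of the flattened parameter vector onto a single direction.
    d = direction / np.linalg.norm(direction)
    return flat_params - np.dot(flat_params, d) * d

# Placeholders: in the real experiment these would be the trained policy's flattened
# parameters and a left singular vector identified from the gradient SVD.
rng = np.random.default_rng(0)
theta = rng.normal(size=10_000)
cheese_direction = rng.normal(size=10_000)

theta_ablated = ablate_direction(theta, cheese_direction)
# theta_ablated would be written back into the policy, and we'd compare how often the
# ablated vs. unablated policy still walks to the cheese after reaching the top right corner.
```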
Is there a reason you used 300/400 randomly sampled indices, instead of evenly spaced indices (e.g. every 1/300 of the total training steps)?
No!
Did you subtract the mean of the weights before doing the SVD? Otherwise, the first component is probably the mean of the 300/400 weight vectors.
Ah, this is a good idea! I’ll make sure to incorporate it, thanks!
Unfortunately, interpreting the other big ones turned out to be pretty non-trivial. This is despite the fact that many parts of the final network have low-rank approximations that capture >99% of the variance, that we know the network is getting more sparse in the Fourier basis, and that the network’s function is well enough understood that you can literally read off the trig identities being used at the MLP layer. So I’m not super confident that purely unsupervised linear methods will actually help much with interpretability here.
Interesting. I’ll be sure to read what he’s written to see if it’s what I’d do.
Probably the easiest environments to run this on are the examples from Lauro Langosco’s Goal Misgeneralization paper.
Another thought:
I might be missing something, but is there a reason you’re doing this via SVD on gradients, instead of SVD on weights?
Is there a reason to do this with SVD at all, instead of mechanistic interp methods like causal scrubbing/causal tracing/path patching or manual inspection of circuits?
Thanks for the pointer, and thanks for the overall very helpful comment!