I was having trouble reproducing your results on Pythia, and was only able to get 60% variance explained. I may have tracked it down: I think you may be computing FVU incorrectly.
https://gist.github.com/Stefan-Heimersheim/ff1d3b92add92a29602b411b9cd76cec#file-clustering_pythia-py-L309
I think FVU should be computed by subtracting the per-dimension mean when computing the denominator, so the denominator is the variance of the activations rather than their raw second moment. See the SAEBench implementation here:
https://github.com/adamkarvonen/SAEBench/blob/5204b4822c66a838d9c9221640308e7c23eda00a/sae_bench/evals/core/main.py#L566
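For concreteness, here is a minimal sketch of the mean-subtracted FVU (illustrative only, not the SAEBench code itself):

```python
import numpy as np

def fvu(x, x_hat):
    """Fraction of variance unexplained for activations x and reconstructions x_hat.

    Both have shape (n_samples, d_model). The denominator is the variance of x
    around its per-dimension mean; skipping the mean subtraction inflates the
    denominator and makes variance explained look better than it is.
    """
    residual = ((x - x_hat) ** 2).sum()
    total = ((x - x.mean(axis=0, keepdims=True)) ** 2).sum()
    return residual / total
```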
When I used your FVU implementation, I got 72% variance explained; this is still less than yours, but much closer, so I think this may be what's causing the improvement over the SAEBench numbers.
In general I think SAEs with low k should be at least as good as k-means clustering, and if they're not I'm a little suspicious (when I first tried this on GPT-2, a TopK SAE trained with k = 4 did about as well as k-means clustering with the nonlinear argmax encoder).
Here’s my clustering code: https://github.com/JoshEngels/CheckClustering/blob/main/clustering.py
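For reference, the baseline I mean is k-means with the nonlinear argmax encoder: each activation's code is a one-hot vector selecting the nearest centroid, and the decoder just reads that centroid back out. A minimal sketch of the idea (not the code in the repo above):

```python
import numpy as np

def kmeans_reconstruct(x, centroids):
    """Reconstruct each row of x by its nearest centroid.

    x: (n, d) activations; centroids: (k_clusters, d).
    The "encoder" is the argmax over negative distances, i.e. a hard
    one-hot assignment, which is nonlinear in x.
    """
    # squared distance from every point to every centroid
    d2 = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    assignment = d2.argmin(axis=1)  # hard one-hot code per activation
    return centroids[assignment]
```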
You're right, I forgot to subtract the mean. Thanks a lot!
I'm computing new numbers now, but indeed I expect this to explain my result! (Edit: it seems not to change much.)

After adding the mean subtraction, the numbers actually haven't changed much, but let me make sure I'm using the correct calculation. I'm going to follow your and @Adam Karvonen's suggestion of using the SAEBench code and loading my clustering solution as an SAE (this code).
These logs show numbers with the original / corrected explained variance computation; the difference is in the 3-8% range.
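To see why the denominator matters: explained variance is 1 - FVU, and dividing by the raw second moment instead of the centered variance shrinks FVU whenever the activations have a nonzero mean. A toy illustration with made-up numbers (not taken from the logs):

```python
import numpy as np

rng = np.random.default_rng(0)
# made-up activations with a nonzero mean (std 1, mean 0.5 per dimension)
x = 0.5 + rng.standard_normal((2000, 16))
# a mediocre "reconstruction" of them
x_hat = x + 0.5 * rng.standard_normal(x.shape)

residual = ((x - x_hat) ** 2).sum()
centered = ((x - x.mean(axis=0)) ** 2).sum()  # correct denominator
uncentered = (x ** 2).sum()                   # denominator without mean subtraction

ev_corrected = 1 - residual / centered   # ~0.75
ev_original = 1 - residual / uncentered  # ~0.80: a few points too optimistic
```

With a larger activation mean relative to the variance, the gap grows accordingly.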