Edited to fix errors pointed out by @JoshEngels and @Adam Karvonen (mainly: different definition for explained variance, details here).
Summary: K-means explains 72-87% of the variance in the activations, comparable to vanilla SAEs but less than better SAEs. I think this (bug-fixed) result is neither evidence in favour of SAEs nor against them; the Clustering & SAE numbers make a straight-ish line on a log plot.
Epistemic status: This is a weekend experiment I ran a while ago, and I figured I should write it up to share. I have taken decent care to check my code for silly mistakes and “shooting myself in the foot”, but these results are not vetted to the standard of a top-level post / paper.
SAEs explain most of the variance in activations. Is this alone a sign that activations are structured in an SAE-friendly way, i.e. that activations are indeed a composition of sparse features like the superposition hypothesis suggests?
I’m asking myself this question because I initially considered this pretty solid evidence: SAEs do a pretty impressive job compressing 512 dimensions into ~100 latents; this ought to mean something, right?
But maybe all SAEs are doing is “dataset clustering” (the data is cluster-y and SAEs exploit this); then a different sensible clustering method should also be able to perform similarly well!
I took this[1] SAE graph from Neuronpedia, and added a K-means clustering baseline. Think of this as pretty equivalent to a top-k SAE (with k=1; in fact I added a point where I use the K-means centroids as features of a top-1 SAE which does slightly better than vanilla K-means with binary latents).
K-means clustering (which uses a single latent, L_0=1) explains 72-87% of the variance. This is a good number to keep in mind when comparing to SAEs, but it is significantly lower than what SAEs achieve (often 90%+). To have a comparison using more latents, I’m adding a PCA + Clustering baseline where I apply PCA before doing the clustering; it does roughly as well as vanilla SAEs. The upcoming SAEBench paper also includes a PCA baseline, so I won’t discuss PCA in detail here.
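To make the setup concrete, here is a minimal sketch of this kind of baseline (the actual code is linked below; the random array is only a stand-in for real activations, and MiniBatchKMeans is just one reasonable choice):

```python
import numpy as np
from sklearn.cluster import MiniBatchKMeans

# Placeholder for real residual-stream activations, e.g. [n_tokens, d_model=512] for Pythia-70M.
acts = np.random.randn(100_000, 512).astype(np.float32)

n_clusters = 4096  # match the SAE dictionary size being compared against
kmeans = MiniBatchKMeans(n_clusters=n_clusters, batch_size=4096).fit(acts)

# K-means reconstruction: replace each activation by its nearest centroid (one binary latent, L_0 = 1).
recon = kmeans.cluster_centers_[kmeans.predict(acts)]
# The PCA + Clustering variant simply runs a PCA on the activations before clustering.
```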
Here’s the result for layers 3 and 4, and 4k and 16k latents. (These were the 4 SAEBench suites available on Neuronpedia.) There are two points each for the clustering results, corresponding to 100k and 1M training samples. Code here.
What about interpretability? Clusters seem “monosemantic” on a skim. In an informal investigation I looked at max-activating dataset examples, and they seem to correspond to related contexts / words like monosemantic SAE features tend to do. I haven’t spent much time looking into this though.
Both my code and SAEBench/Neuronpedia use OpenWebText with 128 tokens context length. After the edit I’ve made sure to use the same Variance Explained definition for all points.
A final caveat I want to mention is that I think the SAEs I’m comparing here (SAEBench suite for Pythia-70M) are maybe weak. They’re only using 4k and 16k latents, for 512 embedding dimensions, using expansion ratios of 8 and 32, respectively (the best SAEs I could find for a ~100M model). But I also limit the number of clusters to the same numbers, so I don’t necessarily expect the balance to change qualitatively at higher expansion ratios.
I want to thank @Adam Karvonen, @Lucius Bushnaq, @jake_mendel, and @Patrick Leask for feedback on early results, and @Johnny Lin for implementing an export feature on Neuronpedia for me! I also learned that @scasper proposed something similar here (though I didn’t know about it), I’m excited for follow-ups implementing some of Stephen’s advanced ideas (HAC, a probabilistic alg, …).
[1] I’m using the conventional definition of variance explained, rather than the one used by Neuronpedia, so the numbers are slightly different. I’ll include the alternative graph in a comment.
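Concretely, a sketch of that conventional definition (array names are illustrative; the per-dimension mean is subtracted in the denominator):

```python
import numpy as np

def variance_explained(acts: np.ndarray, recon: np.ndarray) -> float:
    """1 - FVU, where the denominator is the total variance of the activations
    (mean subtracted per dimension)."""
    fvu = np.sum((acts - recon) ** 2) / np.sum((acts - acts.mean(axis=0)) ** 2)
    return 1.0 - fvu
```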
I was having trouble reproducing your results on Pythia, and was only able to get 60% variance explained. I may have tracked it down: I think you may be computing FVU incorrectly.
https://gist.github.com/Stefan-Heimersheim/ff1d3b92add92a29602b411b9cd76cec#file-clustering_pythia-py-L309
I think FVU is correctly computed by subtracting the mean from each dimension when computing the denominator. See the SAEBench impl here:
https://github.com/adamkarvonen/SAEBench/blob/5204b4822c66a838d9c9221640308e7c23eda00a/sae_bench/evals/core/main.py#L566
When I used your FVU implementation, I got 72% variance explained; this is still less than you, but much closer, so I think this might be causing the improvement over the SAEBench numbers.
In general I think SAEs with low k should be at least as good as k means clustering, and if it’s not I’m a little bit suspicious (when I tried this first on GPT-2 it seemed that a TopK SAE trained with k = 4 did about as well as k means clustering with the nonlinear argmax encoder).
Here’s my clustering code: https://github.com/JoshEngels/CheckClustering/blob/main/clustering.py
You’re right. I forgot to subtract the mean. Thanks a lot!
I’m computing new numbers now, but indeed I expect this to explain my result! (Edit: it seems to not change too much.) After adding the mean subtraction, the numbers haven’t changed much actually, but let me make sure I’m using the correct calculation. I’m going to follow your and @Adam Karvonen’s suggestion of using the SAEBench code and loading my clustering solution as an SAE (this code).
These logs show numbers with the original / corrected explained variance computation; the difference is in the 3-8% range.
This seems concerning. Can somebody ELI5 what’s going on here?
I feel like my post appears overly dramatic; I’m not very surprised and don’t consider this the strongest evidence against SAEs. It’s an experiment I ran a while ago and it hasn’t changed my (somewhat SAE-sceptic) stance much.
But this is me having seen a bunch of other weird SAE behaviours (pre-activation distributions are not the way you’d expect from the superposition hypothesis, h/t @jake_mendel; if you feed SAE-reconstructed activations back into the encoder, the SAE goes nuts; stuff mentioned in recent Apollo papers; …).
Reasons this could be less concerning than it looks:
Activation reconstruction isn’t that important: clustering is a strong optimiser, and if you fill a space with 16k clusters, maybe 90% reconstruction isn’t that surprising. I should really run a random Gaussian data baseline for this (a sketch of such a baseline follows this list).
End-to-end loss is more important, and maybe SAEs perform much better when you consider end-to-end reconstruction loss.
This isn’t the only evidence in favour of SAEs, they also kinda work for steering/probing (though pretty badly).
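A sketch of what such a random-data baseline could look like (shapes and cluster count are illustrative, mirroring the K-means sketch in the post above):

```python
import numpy as np
from sklearn.cluster import MiniBatchKMeans

# Isotropic Gaussian "activations" with the same shape as the real ones; if K-means explains
# a similar fraction of variance here, high reconstruction alone says little about structure.
rng = np.random.default_rng(0)
fake_acts = rng.standard_normal((100_000, 512)).astype(np.float32)

kmeans = MiniBatchKMeans(n_clusters=16_384, batch_size=4096).fit(fake_acts)
recon = kmeans.cluster_centers_[kmeans.predict(fake_acts)]
fvu = ((fake_acts - recon) ** 2).sum() / ((fake_acts - fake_acts.mean(0)) ** 2).sum()
print(f"Variance explained on random Gaussian data: {1 - fvu:.1%}")
```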
Tentatively I get similar results (70-85% variance explained) for random data. I haven’t checked that code at all though, so don’t trust this; I will double-check it tomorrow. (In that case SAEs’ performance would also be unsurprising, I suppose.)
Is there a benchmark in which SAEs clearly, definitely outperform standard techniques?
I’m not sure what you mean by “K-means clustering baseline (with K=1)”. I would think the K in K-means stands for the number of means you use, so with K=1, you’re just taking the mean direction of the weights. I would expect this to explain maybe 50% of the variance (or less), not 90% of the variance.
But anyway, under my current model (roughly Why I’m bearish on mechanistic interpretability: the shards are not in the network + Binary encoding as a simple explicit construction for superposition) it seems about as natural to use K-means as it does to use SAEs, and not necessarily an issue if K-means outperforms SAEs. If we imagine that the meaning is given not by the dimensions of the space but rather by regions/points/volumes of the space, then K-means seems like a perfectly cromulent quantization for identifying these volumes. The major issue is where we go from here.
I think this is what I care about finding out. If you’re right this is indeed not surprising nor an issue, but you being right would be a major departure from the current mainstream interpretability paradigm(?).
The question of regions vs compositionality is what I’ve been investigating with my mentees recently, and I’m pretty keen on it. I want to write up my current thoughts on this topic sometime soon.
Thanks for pointing this out! I confused nomenclature, will fix!
Edit: Fixed now. I confused the number of clusters (“K”, i.e. the dictionary size) with the number of latents (“L_0”, or k in top-k SAEs). Some clustering methods allow you to assign multiple clusters to one point, so effectively you get an “L_0 > 1”, but normal K-means assigns only one cluster per point.
I think he messed up the lingo a bit, but looking at the code he seems to have done k-means with a number of clusters similar to the number of SAE latents, which seems fine.
Same plot but using SAEBench’s FVU definition. Matches this Neuronpedia page.
I’m going to update the results in the top-level comment with the corrected data; I’m pasting the original figures here for posterity / understanding the past discussion. Summary of changes:
[Minor] I didn’t subtract the mean in the variance calculation. This barely had an effect on the results.
[Major] I used a different definition of “Explained Variance”, which caused a pretty large difference.
Old (no longer true) text:
It turns out that even clustering (essentially L_0=1) explains up to 90% of the variance in activations, being matched only by SAEs with L_0>100. This isn’t an entirely fair comparison, since SAEs are optimised for the large-L_0 regime, while I haven’t found a L_0>1 operationalisation of clustering that meaningfully improves over L_0=1. To have some comparison I’m adding a PCA + Clustering baseline where I apply a PCA before doing the clustering. It does roughly as well as expected, exceeding the SAE reconstruction for most L0 values. The SAEBench upcoming paper also does a PCA baseline so I won’t discuss PCA in detail here.
[...]
Here’s the code used to get the clustering & PCA below; the SAE numbers are taken straight from Neuronpedia. Both my code and SAEBench/Neuronpedia use OpenWebText with 128 tokens context length so I hope the numbers are comparable, but there’s a risk I missed something and we’re comparing apples to oranges.
I think the relation between K-means and sparse dictionary learning (essentially K-means is equivalent to an L_0=1 constraint) is already well-known in the sparse coding literature? For example see this wiki article on K-SVD (a sparse dictionary learning algorithm) which first reviews this connection before getting into the nuances of k-SVD.
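Concretely, the connection reviewed there is that K-means is dictionary learning with one-hot codes: writing the data as columns of $X$ and the $K$ centroids as columns of a dictionary $D$, K-means solves (a standard formulation, as in the K-SVD literature)

```latex
\min_{D, A} \; \lVert X - D A \rVert_F^2
\quad \text{s.t.} \quad a_j \in \{ e_1, \dots, e_K \} \;\; \text{for every column } a_j \text{ of } A,
```

i.e. each code column is a standard basis vector picking one centroid (the unit-coefficient, $L_0 = 1$ case); allowing sparse real-valued columns instead gives the sparse-coding objective that K-SVD optimises.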
Were the SAEs for this comparison trained on multiple passes through the data, or just one pass/epoch? Because if for K-means you did multiple passes through the data but for SAEs just one then this feels like an unfair comparison.
What do you mean you’re encoding/decoding like normal but using the k means vectors? Shouldn’t the SAE training process for a top k SAE with k = 1 find these vectors then?
In general I’m a bit skeptical that clustering will work as well on larger models, my impression is that most small models have pretty token level features which might be pretty clusterable with k=1, but for larger models many activations may belong to multiple “clusters”, which you need dictionary learning for.
So I do something like
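(What follows is a sketch of this kind of encoder, not the verbatim snippet; it assumes `centroids` is stored decoder-style with shape [d_model, n_latents] and `acts` is a single activation vector, so the shapes line up with the reconstruction line below.)

```python
import torch

# Illustrative placeholders for the real tensors:
centroids = torch.randn(512, 4096)   # [d_model, n_latents], decoder-style
acts = torch.randn(512)              # one activation vector

embedding = centroids.T @ acts       # "first line": an SAE-style encoder, centroids as features
top = torch.topk(embedding, k=1)     # "second/third line": a top-k over the latents
latents = torch.zeros_like(embedding).scatter_(0, top.indices, top.values)
```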
where the first line is essentially an SAE embedding (and centroids are the features), and the second/third line is a top-k. And for reconstruction do something like

recon = centroids @ latents
which should also be equivalent.
Yes, I would expect an optimal k=1 top-k SAE to find exactly that solution. Confused why k=20 top-k SAEs do so badly then.
If this is a crux then a quick way to prove this would be for me to write down encoder/decoder weights and throw them into a standard SAE code. I haven’t done this yet.
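For what it’s worth, here is a sketch of what that construction could look like (generic top-k-SAE parameter names like `W_enc`/`W_dec`, not any specific library’s API). Since the nearest centroid under Euclidean distance is argmax_i (c_i · x − ||c_i||² / 2), a linear encoder with the centroids as rows plus that bias, followed by top-1, reproduces the K-means assignment:

```python
import torch

centroids = torch.randn(4096, 512)            # [n_latents, d_model], e.g. k-means centroids

W_enc = centroids.clone()                     # encoder rows point at the centroids
b_enc = -0.5 * (centroids ** 2).sum(dim=-1)   # argmax of W_enc @ x + b_enc == nearest centroid
W_dec = centroids.clone()                     # decoder rows are the centroids themselves
b_dec = torch.zeros(centroids.shape[-1])

# Caveat: a standard top-k SAE scales the decoder row by the winning pre-activation value,
# which matches the "centroids as top-1 SAE features" variant rather than K-means' binary latent.
```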
I just tried to replicate this on GPT-2 with expansion factor 4 (so total number of centroids = 768 * 4). I get that clustering recovers ~87% fraction of variance explained, while a k = 32 SAE gets more like 95% variance explained. I did the nonlinear version of finding nearest neighbors when using k means to give k means the biggest advantage possible, and did k-means clustering on points using the FAISS clustering library.
Definitely take this with a grain of salt; I’m going to look through my code and see if I can reproduce your results on Pythia too, and if so try on a larger model too. Code: https://github.com/JoshEngels/CheckClustering/tree/main