Stefan Heimersheim. Research Scientist at Apollo Research, Mechanistic Interpretability. The opinions expressed here are my own and do not necessarily reflect the views of my employer.
Yep, that's the generalisation that would make the most sense.
The previous lines calculate the ratio (or 1 - ratio) stored in the "explained variance" key for every sample/batch. Then, in that later quoted line, the list is averaged, i.e. we're taking the sample average over the ratio. That's the FVU_B formula.
Let me know if this clears it up or if we’re misunderstanding each other!
Oops, fixed!
I think this is the sum over the vector dimension, but not over the samples. The sum (mean) over samples is taken later, in this line, which happens after the division:
metrics[f"{metric_name}"] = torch.cat(metric_values).mean().item()
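To spell out what that amounts to, here is a schematic sketch (not the actual SAEBench code; the helper name and variables are mine, purely for illustration):

```python
import torch

def explained_variance_sample_mean(batches, mu):
    # Each batch contributes a tensor of per-sample ratios; the final line then takes the
    # mean over all samples, i.e. the sample average of per-sample ratios (the FVU_B form).
    metric_values = []
    for x, x_hat in batches:  # x, x_hat: [batch, d]; mu: [1, d] mean activation
        per_sample_fvu = ((x - x_hat) ** 2).sum(-1) / ((x - mu) ** 2).sum(-1)
        metric_values.append(1 - per_sample_fvu)  # per-sample "explained variance"
    return torch.cat(metric_values).mean().item()  # = 1 - FVU_B
```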
Edit: And to clarify, my impression is that people think of these as alternative definitions of FVU, and you get to pick one, rather than one being right and one being a bug.
Edit2: And I’m in touch with the SAEBench authors about making a PR to change this / add both options (and by extension probably doing the same in SAELens); though I won’t mind if anyone else does it!
PSA: People use different definitions of “explained variance” / “fraction of variance unexplained” (FVU)
$$\mathrm{FVU}_A = \frac{\mathbb{E}_x\left[\lVert x - \hat{x}(x)\rVert^2\right]}{\mathbb{E}_x\left[\lVert x - \mu\rVert^2\right]}, \qquad \mu = \mathbb{E}_x[x],$$
is the formula I think is sensible; the bottom is simply the variance of the data, and the top is the variance of the residuals. The $\lVert\cdot\rVert$ indicates the norm over the dimension of the vector $x$. I believe it matches Wikipedia's definition of FVU and R squared.
$$\mathrm{FVU}_B = \mathbb{E}_x\left[\frac{\lVert x - \hat{x}(x)\rVert^2}{\lVert x - \mu\rVert^2}\right]$$
is the formula used by SAELens and SAEBench. It seems less principled; @Lucius Bushnaq and I couldn't think of a nice quantity it corresponds to. I think of it as giving more weight to samples that are close to the mean, kind of averaging the relative reduction in difference rather than the absolute one.
A third version (h/t @JoshEngels) computes the FVU for each dimension independently and then averages; that version is not used in the context we're discussing here.
In my recent comment I had computed my own $\mathrm{FVU}_A$, compared it to FVUs from SAEBench (which used $\mathrm{FVU}_B$), and obtained nonsense results.
Curiously, the two definitions seem to be approximately proportional (below I show the performance of a bunch of SAEs), though for different distributions (here: activations in layers 3 and 4) the ratio differs.[1] Still, this means that using $\mathrm{FVU}_B$ instead of $\mathrm{FVU}_A$ to compare e.g. different SAEs doesn't make a big difference, as long as one is consistent.
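For concreteness, here is a minimal sketch computing both definitions side by side (function names and the toy data are mine, just for illustration):

```python
import torch

def fvu_a(x, x_hat):
    # Ratio of means: mean squared residual norm over mean squared deviation from the mean.
    mu = x.mean(dim=0, keepdim=True)
    return ((x - x_hat) ** 2).sum() / ((x - mu) ** 2).sum()

def fvu_b(x, x_hat):
    # Mean of ratios: average the per-sample ratio instead.
    mu = x.mean(dim=0, keepdim=True)
    per_sample = ((x - x_hat) ** 2).sum(-1) / ((x - mu) ** 2).sum(-1)
    return per_sample.mean()

x = torch.randn(10_000, 512)           # toy "activations"
x_hat = x + 0.3 * torch.randn_like(x)  # toy "reconstructions"
print(fvu_a(x, x_hat).item(), fvu_b(x, x_hat).item())  # similar here, but not identical in general
```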
Thanks to @JoshEngels for pointing out the difference, and to @Lucius Bushnaq for helpful discussions.
- ^
If a predictor doesn’t perform systematically better or worse at points closer to the mean then this makes sense. The denominator changes the relative weight of different samples but this doesn’t have any effect beyond noise and a global scale, as long as there is no systematic performance difference.
- ^
Same plot but using SAEBench’s FVU definition. Matches this Neuronpedia page.
I’m going to update the results in the top-level comment with the corrected data; I’m pasting the original figures here for posterity / understanding the past discussion. Summary of changes:
- [Minor] I didn't subtract the mean in the variance calculation. This barely had an effect on the results.
- [Major] I used a different definition of "Explained Variance", which caused a pretty large difference.
Old (no longer true) text:
It turns out that even clustering (essentially L_0=1) explains up to 90% of the variance in activations, being matched only by SAEs with L_0>100. This isn’t an entirely fair comparison, since SAEs are optimised for the large-L_0 regime, while I haven’t found a L_0>1 operationalisation of clustering that meaningfully improves over L_0=1. To have some comparison I’m adding a PCA + Clustering baseline where I apply a PCA before doing the clustering. It does roughly as well as expected, exceeding the SAE reconstruction for most L0 values. The SAEBench upcoming paper also does a PCA baseline so I won’t discuss PCA in detail here.
[...]Here’s the code used to get the clustering & PCA below; the SAE numbers are taken straight from Neuronpedia. Both my code and SAEBench/Neuronpedia use OpenWebText with 128 tokens context length so I hope the numbers are comparable, but there’s a risk I missed something and we’re comparing apples to oranges.
After adding the mean subtraction, the numbers actually haven't changed much, but let me make sure I'm using the correct calculation. I'm going to follow your and @Adam Karvonen's suggestion of using the SAEBench code and loading my clustering solution as an SAE (this code).
These logs show numbers with the original / corrected explained variance computation; the difference is in the 3-8% range.
v3 (KMeans): Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=4096, variance explained = 0.8887 / 0.8568
v3 (KMeans): Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16384, variance explained = 0.9020 / 0.8740
v3 (KMeans): Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=4096, variance explained = 0.8044 / 0.7197
v3 (KMeans): Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16384, variance explained = 0.8261 / 0.7509
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=4095, n_pca=1, variance explained = 0.8910 / 0.8599
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16383, n_pca=1, variance explained = 0.9041 / 0.8766
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=4094, n_pca=2, variance explained = 0.8948 / 0.8647
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16382, n_pca=2, variance explained = 0.9076 / 0.8812
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=4091, n_pca=5, variance explained = 0.9044 / 0.8770
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16379, n_pca=5, variance explained = 0.9159 / 0.8919
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=4086, n_pca=10, variance explained = 0.9121 / 0.8870
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16374, n_pca=10, variance explained = 0.9232 / 0.9012
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=4076, n_pca=20, variance explained = 0.9209 / 0.8983
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16364, n_pca=20, variance explained = 0.9314 / 0.9118
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=4046, n_pca=50, variance explained = 0.9379 / 0.9202
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16334, n_pca=50, variance explained = 0.9468 / 0.9315
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=3996, n_pca=100, variance explained = 0.9539 / 0.9407
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16284, n_pca=100, variance explained = 0.9611 / 0.9499
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=3896, n_pca=200, variance explained = 0.9721 / 0.9641
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16184, n_pca=200, variance explained = 0.9768 / 0.9702
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=3596, n_pca=500, variance explained = 0.9999 / 0.9998
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=15884, n_pca=500, variance explained = 0.9999 / 0.9999
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=4095, n_pca=1, variance explained = 0.8077 / 0.7245
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16383, n_pca=1, variance explained = 0.8292 / 0.7554
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=4094, n_pca=2, variance explained = 0.8145 / 0.7342
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16382, n_pca=2, variance explained = 0.8350 / 0.7636
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=4091, n_pca=5, variance explained = 0.8244 / 0.7484
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16379, n_pca=5, variance explained = 0.8441 / 0.7767
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=4086, n_pca=10, variance explained = 0.8326 / 0.7602
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16374, n_pca=10, variance explained = 0.8516 / 0.7875
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=4076, n_pca=20, variance explained = 0.8460 / 0.7794
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16364, n_pca=20, variance explained = 0.8637 / 0.8048
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=4046, n_pca=50, variance explained = 0.8735 / 0.8188
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16334, n_pca=50, variance explained = 0.8884 / 0.8401
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=3996, n_pca=100, variance explained = 0.9021 / 0.8598
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16284, n_pca=100, variance explained = 0.9138 / 0.8765
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=3896, n_pca=200, variance explained = 0.9399 / 0.9139
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16184, n_pca=200, variance explained = 0.9473 / 0.9246
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=3596, n_pca=500, variance explained = 0.9997 / 0.9996
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=15884, n_pca=500, variance explained = 0.9998 / 0.9997
You're right, I forgot to subtract the mean. Thanks a lot!
I'm computing new numbers now, but indeed I expect this to explain my result! (Edit: Seems to not change too much.)
I should really run a random Gaussian data baseline for this.
Tentatively, I get similar results (70-85% variance explained) for random data. I haven't checked that code at all though, so don't trust this; I will double-check it tomorrow. (In that case SAEs' performance would also be unsurprising, I suppose.)
If we imagine that the meaning is given not by the dimensions of the space but rather by regions/points/volumes of the space
I think this is what I care about finding out. If you're right, this is indeed neither surprising nor an issue, but you being right would be a major departure from the current mainstream interpretability paradigm(?).
The question of regions vs. compositionality is what I've been investigating with my mentees recently, and one I'm pretty keen on. I want to write up my current thoughts on this topic sometime soon.
What do you mean you’re encoding/decoding like normal but using the k means vectors?
So I do something like
```python
import torch.nn.functional as F

latents_tmp = torch.einsum("bd,nd->bn", data, centroids)  # dot product of each sample with each centroid
max_latent = latents_tmp.argmax(dim=-1)  # shape: [batch]
latents = F.one_hot(max_latent, num_classes=centroids.shape[0]).float()  # shape: [batch, n_clusters]
```
where the first line is essentially an SAE embedding (the centroids are the features), and the second/third lines are a top-k selection (with k=1). And for the reconstruction I do something like

```python
recon = latents @ centroids  # shape: [batch, d]
```

which should also be equivalent (the one-hot latents just select the corresponding centroid as the reconstruction).
Shouldn’t the SAE training process for a top k SAE with k = 1 find these vectors then?
Yes, I would expect an optimal k=1 top-k SAE to find exactly that solution. I'm confused why k=20 top-k SAEs do so badly, then.
If this is a crux then a quick way to prove this would be for me to write down encoder/decoder weights and throw them into a standard SAE code. I haven’t done this yet.
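For reference, here is a rough sketch of what writing the K-means solution as SAE weights could look like (illustrative only; the encoder bias is an addition so that the argmax matches nearest-centroid assignment, whereas the code above uses a plain dot product):

```python
import torch

# centroids: [n_clusters, d_model], as above.
W_enc = centroids                               # encoder weights: one row per centroid
b_enc = -0.5 * (centroids ** 2).sum(-1)         # makes argmax = nearest centroid, since
                                                # argmin ||x - c||^2 = argmax (x·c - ||c||^2 / 2)
W_dec = centroids                               # decoder weights: centroid directions

def encode(x):                                  # x: [batch, d_model]
    pre_acts = x @ W_enc.T + b_enc              # [batch, n_clusters]
    top1 = pre_acts.argmax(dim=-1)
    acts = torch.zeros_like(pre_acts)
    acts[torch.arange(x.shape[0]), top1] = 1.0  # binary top-1 latents
    return acts

def decode(acts):
    return acts @ W_dec                         # [batch, d_model]
```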
I’m not sure what you mean by “K-means clustering baseline (with K=1)”. I would think the K in K-means stands for the number of means you use, so with K=1, you’re just taking the mean direction of the weights. I would expect this to explain maybe 50% of the variance (or less), not 90% of the variance.
Thanks for pointing this out! I confused nomenclature, will fix!
Edit: Fixed now. I confused
- the number of clusters ("K") / dictionary size, and
- the number of latents ("L_0", or k in top-k SAEs).
Some clustering methods allow you to assign multiple clusters to one point, so effectively you get an "L_0>1", but normal KMeans assigns only 1 cluster per point. I confused the K of KMeans and the k (aka L_0) of top-k SAEs.
this seems concerning.
I feel like my post appears overly dramatic; I’m not very surprised and don’t consider this the strongest evidence against SAEs. It’s an experiment I ran a while ago and it hasn’t changed my (somewhat SAE-sceptic) stance much.
But this comes after having seen a bunch of other weird SAE behaviours (pre-activation distributions are not the way you'd expect from the superposition hypothesis, h/t @jake_mendel; if you feed SAE-reconstructed activations back into the encoder, the SAE goes nuts; stuff mentioned in recent Apollo papers; …).
Reasons this could be less concerning than it looks:
- Activation reconstruction isn't that important: clustering is a strong optimiser, and if you fill a space with 16k clusters, maybe 90% reconstruction isn't that surprising. I should really run a random Gaussian data baseline for this.
- End-to-end loss is more important, and maybe SAEs perform much better when you consider end-to-end reconstruction loss.
- This isn't the only evidence in favour of SAEs; they also kinda work for steering/probing (though pretty badly).
Edited to fix errors pointed out by @JoshEngels and @Adam Karvonen (mainly: different definition for explained variance, details here).
Summary: K-means explains 72-87% of the variance in the activations, comparable to vanilla SAEs but less than better SAEs. I think this (bug-fixed) result is neither evidence in favour of SAEs nor against them; the clustering & SAE numbers make a straight-ish line on a log plot.
Epistemic status: This is a weekend experiment I ran a while ago, and I figured I should write it up to share. I have taken decent care to check my code for silly mistakes and "shooting myself in the foot", but these results are not vetted to the standard of a top-level post / paper.
SAEs explain most of the variance in activations. Is this alone a sign that activations are structured in an SAE-friendly way, i.e. that activations are indeed a composition of sparse features like the superposition hypothesis suggests?
I'm asking myself this question since I initially considered it pretty solid evidence: SAEs do a pretty impressive job compressing 512 dimensions into ~100 latents; this ought to mean something, right?
But maybe all SAEs are doing is "dataset clustering" (the data is cluster-y and SAEs exploit this); then a different sensible clustering method should also be able to perform similarly well!
I took this[1] SAE graph from Neuronpedia, and added a K-means clustering baseline. Think of this as pretty equivalent to a top-k SAE (with k=1; in fact I added a point where I use the K-means centroids as features of a top-1 SAE which does slightly better than vanilla K-means with binary latents).
K-means clustering (which uses a single latent, L0=1) explains 72-87% of the variance. This is a good number to keep in mind when comparing to SAEs, though it is significantly lower than what SAEs achieve (often 90%+). To have a comparison using more latents, I'm adding a PCA + Clustering baseline where I apply a PCA before doing the clustering. It does roughly as well as vanilla SAEs. The upcoming SAEBench paper also does a PCA baseline, so I won't discuss PCA in detail here.
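For illustration, here is a minimal sketch of this kind of baseline (using sklearn's KMeans; variable names are placeholders, and this is not the exact code linked below):

```python
from sklearn.cluster import KMeans

# acts: [n_samples, d_model] array of residual-stream activations (assumed given).
kmeans = KMeans(n_clusters=4096).fit(acts)
recon = kmeans.cluster_centers_[kmeans.predict(acts)]  # nearest-centroid reconstruction

mu = acts.mean(axis=0, keepdims=True)
fvu = ((acts - recon) ** 2).sum() / ((acts - mu) ** 2).sum()
print("variance explained:", 1 - fvu)
```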
Here’s the result for layers 3 and 4, and 4k and 16k latents. (These were the 4 SAEBench suites available on Neuronpedia.) There’s two points each for the clustering results corresponding to 100k and 1M training samples. Code here.
What about interpretability? Clusters seem “monosemantic” on a skim. In an informal investigation I looked at max-activating dataset examples, and they seem to correspond to related contexts / words like monosemantic SAE features tend to do. I haven’t spent much time looking into this though.
Both my code and SAEBench/Neuronpedia use OpenWebText with 128 tokens context length. After the edit I’ve made sure to use the same Variance Explained definition for all points.
A final caveat I want to mention is that I think the SAEs I'm comparing here (the SAEBench suite for Pythia-70M) are maybe weak. They only use 4k and 16k latents for 512 embedding dimensions, i.e. expansion ratios of 8 and 32, respectively (the best SAEs I could find for a ~100M-parameter model). But I also limit the number of clusters to the same numbers, so I don't necessarily expect the balance to change qualitatively at higher expansion ratios.
I want to thank @Adam Karvonen, @Lucius Bushnaq, @jake_mendel, and @Patrick Leask for feedback on early results, and @Johnny Lin for implementing an export feature on Neuronpedia for me! I also learned that @scasper proposed something similar here (though I didn't know about it at the time); I'm excited for follow-ups implementing some of Stephen's more advanced ideas (HAC, a probabilistic algorithm, …).
- ^
I'm using the conventional definition of variance explained, rather than the one used by Neuronpedia; thus the numbers are slightly different. I'll include the alternative graph in a comment.
I've just read the article and found it indeed very thought-provoking; I will be thinking more about it in the days to come.
One thing I kept thinking, though: Why doesn't the article mention AI Safety research much?
In the passage
The only policy that AI Doomers mostly agree on is that AI development should be slowed down somehow, in order to “buy time.”
I was thinking: surely most people would agree on policies like “Do more research into AI alignment” / “Spend more money on AI Notkilleveryoneism research”?
In general, the article frames the policy of "buying time" as waiting for more competent governments or humans, while I find it plausible that progress in AI alignment research could outweigh that effect.
—
I suppose the article is primarily concerned with AGI and ASI, and in that matter I see much less research progress than in more prosaic fields.
That being said, I believe that research into questions like “When do Chatbots scheme?”, “Do models have internal goals?”, “How can we understand the computation inside a neural network?” will make us less likely to die in the next decades.
Then, current rationalist / EA policy goals (including but not limited to pauses and slowdowns of capabilities research) could have a positive impact via the "do more (selective) research" path as well.
Yeah, you probably shouldn't concatenate the spaces, due to things like "they might have very different norms & baseline variances". Maybe calculate each layer separately; then, if they're all similar, average them together, otherwise keep them separate and quote them as separate numbers in your results.
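Something like this minimal sketch (names are placeholders):

```python
import torch

def fvu(x, x_hat):
    # x, x_hat: [n_samples, d] activations and reconstructions for one hook point.
    mu = x.mean(dim=0, keepdim=True)
    return (((x - x_hat) ** 2).sum() / ((x - mu) ** 2).sum()).item()

# acts / recons: dicts of per-layer tensors, keyed by hook-point name (placeholders).
per_layer = {name: fvu(acts[name], recons[name]) for name in acts}
print(per_layer)  # if these are all similar, average them; otherwise report them separately
```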