What do you mean you’re encoding/decoding like normal but using the k means vectors? Shouldn’t the SAE training process for a top k SAE with k = 1 find these vectors then?
In general I’m a bit skeptical that clustering will work as well on larger models. My impression is that most small models have fairly token-level features, which might be pretty clusterable with k=1, but in larger models many activations may belong to multiple “clusters”, which is what you need dictionary learning for.
In the encoding step, the first line is essentially an SAE embedding (the centroids are the features), and the second and third lines are a top-k. For reconstruction I do something like
recon = centroids @ latents
which should also be equivalent.
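To make the encode/reconstruct step concrete, here is a minimal NumPy sketch of one way it could look. The function names, the shape convention (`centroids` as `d_model x n_centroids`, chosen so that `centroids @ latents` works) and the exact score used in the first line are my assumptions, not necessarily what the original code does.

```python
import numpy as np

def encode_top1(centroids: np.ndarray, x: np.ndarray) -> np.ndarray:
    """One-hot latents that select the nearest centroid for one activation x.

    centroids: (d_model, n_centroids), x: (d_model,)
    """
    # "SAE embedding" (assumed form): an affine score whose argmax is the
    # Euclidean nearest centroid, because
    # ||x - c||^2 = ||x||^2 - 2 * (x.c - ||c||^2 / 2).
    scores = centroids.T @ x - 0.5 * (centroids ** 2).sum(axis=0)
    # Top-k with k = 1, binarized to a one-hot vector.
    idx = int(np.argmax(scores))
    latents = np.zeros(centroids.shape[1])
    latents[idx] = 1.0
    return latents

def reconstruct(centroids: np.ndarray, latents: np.ndarray) -> np.ndarray:
    # With one-hot latents this just returns the selected centroid.
    return centroids @ latents

# Toy data: an activation that sits very close to centroid 3.
rng = np.random.default_rng(0)
centroids = rng.normal(size=(8, 5))
x = centroids[:, 3] + 0.01 * rng.normal(size=8)
latents = encode_top1(centroids, x)
recon = reconstruct(centroids, latents)
```

With one-hot latents the reconstruction is exactly the chosen centroid, which is the sense in which this matches a k=1 decode.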
Shouldn’t the SAE training process for a top k SAE with k = 1 find these vectors then?
Yes, I would expect an optimal k=1 top-k SAE to find exactly that solution. I’m confused why k=20 top-k SAEs do so badly then.
If this is a crux, then a quick way to check it would be for me to write down the encoder/decoder weights and plug them into standard SAE code. I haven’t done this yet.
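A rough sketch of what that check could look like, under my own assumptions: `topk_sae_forward` is a hypothetical stand-in for "standard SAE code" (no decoder bias, no ReLU), and `kmeans_as_sae_weights` is one possible way to turn centroids into weights. It only verifies that the k=1 selection agrees with a brute-force nearest-centroid search.

```python
import numpy as np

def topk_sae_forward(x, W_enc, b_enc, W_dec, k):
    """Hypothetical generic top-k SAE forward pass for a single activation."""
    pre = W_enc @ x + b_enc
    keep = np.argsort(pre)[-k:]      # indices of the k largest pre-activations
    latents = np.zeros_like(pre)
    latents[keep] = pre[keep]        # keeps pre-activation values, not 1.0
    return W_dec @ latents, latents, keep

def kmeans_as_sae_weights(centroids):
    """Assumed mapping from centroids (d_model, n_centroids) to SAE weights."""
    W_enc = centroids.T
    b_enc = -0.5 * (centroids ** 2).sum(axis=0)
    W_dec = centroids
    return W_enc, b_enc, W_dec

rng = np.random.default_rng(0)
centroids = rng.normal(size=(16, 32))
acts = rng.normal(size=(100, 16))
W_enc, b_enc, W_dec = kmeans_as_sae_weights(centroids)

# Index chosen by the k = 1 SAE for each activation.
sae_choice = np.array(
    [topk_sae_forward(x, W_enc, b_enc, W_dec, k=1)[2][0] for x in acts]
)
# Brute-force Euclidean nearest centroid for comparison.
dists = ((acts[:, None, :] - centroids.T[None, :, :]) ** 2).sum(axis=-1)
nn_choice = dists.argmin(axis=1)
```

Note that in this sketch the active latent carries the pre-activation value, so the reconstruction only equals the centroid itself if the active latent is set to 1 instead.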
I just tried to replicate this on GPT-2 with expansion factor 4 (so the total number of centroids is 768 * 4 = 3072). Clustering gets a fraction of variance explained of ~87%, while a k = 32 SAE gets more like 95%. To give k-means the biggest possible advantage, I used the nonlinear version (finding the nearest neighbor among the centroids), and I ran the k-means clustering on the points with the FAISS clustering library.
Definitely take this with a grain of salt. I’m going to look through my code and see if I can reproduce your results on Pythia too, and if so, try a larger model too. Code: https://github.com/JoshEngels/CheckClustering/tree/main
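For anyone comparing numbers, here is a minimal sketch of how a fraction of variance explained for nearest-centroid reconstruction could be computed. The normalization (variance around the per-dimension mean) is an assumption and may differ from the linked repository, and the brute-force search on toy data stands in for FAISS.

```python
import numpy as np

def fraction_variance_explained(acts: np.ndarray, recon: np.ndarray) -> float:
    """1 - residual sum of squares / total sum of squares around the mean."""
    resid = ((acts - recon) ** 2).sum()
    total = ((acts - acts.mean(axis=0)) ** 2).sum()
    return float(1.0 - resid / total)

def nearest_centroid_recon(acts: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    """Replace each activation with its Euclidean nearest centroid.

    acts: (n, d), centroids: (n_centroids, d)
    """
    d2 = ((acts[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return centroids[d2.argmin(axis=1)]

# Toy data: 4 well-separated clusters with small noise (not real activations).
rng = np.random.default_rng(0)
centers = rng.normal(size=(4, 8))
labels = rng.integers(0, 4, size=200)
acts = centers[labels] + 0.05 * rng.normal(size=(200, 8))

recon = nearest_centroid_recon(acts, centers)
fve = fraction_variance_explained(acts, recon)
```

On this definition a perfect reconstruction scores 1 and always predicting the mean activation scores 0.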