All current SAEs I’m aware of seem to score very badly on reconstructing the original model’s activations.
If you insert a current SOTA SAE into a language model’s residual stream, model performance on next token prediction will usually degrade down to what a model trained with less than a tenth or a hundredth of the original model’s compute would get. (This is based on extrapolating with Chinchilla scaling curves at optimal compute). And that’s for inserting one SAE at one layer. If you want to study circuits of SAE features, you’ll have to insert SAEs in multiple layers at the same time, potentially further degrading performance.
I think many people outside of interp don’t realize this. Part of the reason they don’t realize it might be that almost all SAE papers report loss reconstruction scores on a linear scale, rather than on a log scale or an LM scaling curve. Going from 1.5 CE loss to 2.0 CE loss is a lot worse than going from 4.5 CE to 5.0 CE. Under the hypothesis that the SAE is capturing some of the model’s ‘features’ and failing to capture others, capturing only 50% or 10% of the features might still only drop the CE loss by a small fraction of a unit.
So, if someone is just glancing at the graphs without looking up what the metrics actually mean, they can be left with the impression that performance is much better than it actually is. The two most common metrics I see are raw CE scores of the model with the SAE inserted, and ‘loss recovered’. I think both of these metrics give a wrong sense of scale. ‘Loss recovered’ is the worse offender, because it makes it outright impossible to tell how good the reconstruction really is without additional information. You need to know what the original model’s loss was and what zero baseline they used to do the conversion. Papers don’t always report this, and the numbers can be cumbersome to find even when they do.
I don’t know what an actually good way to measure model performance drop from SAE insertion is. The best I’ve got is to use scaling curves to guess how much compute you’d need to train a model that gets comparable loss, as suggested here. Or maybe alternatively, training with the same number of tokens as the original model, how many parameters you’d need to get comparable loss. Using this measure, the best reported reconstruction score I’m aware of is 0.1 of the original model’s performance, reached by OpenAI’s GPT-4 SAE with 16 million dictionary elements in this paper.
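A minimal sketch of what that conversion could look like, assuming the Chinchilla parametric loss fit from Hoffmann et al. (2022) with its published constants, the C ≈ 6·N·D FLOPs estimate, and the ~20-tokens-per-parameter rule of thumb. The constants, the rule of thumb, and the example CE numbers are all assumptions here, so treat the output as an order-of-magnitude guess rather than a real measurement:

```python
# Rough sketch (not from any paper's code) of converting a CE-loss increase into
# an "effective compute" ratio. Assumes the Chinchilla parametric fit from
# Hoffmann et al. (2022), L(N, D) = E + A / N**alpha + B / D**beta, plus
# C ~= 6*N*D FLOPs and the ~20-tokens-per-parameter rule of thumb. Real models
# are rarely exactly compute-optimal, so this is only an order-of-magnitude estimate.

E, A, B, ALPHA, BETA = 1.69, 406.4, 410.7, 0.34, 0.28


def optimal_loss(compute_flops: float) -> float:
    """CE loss of a roughly compute-optimal model trained with this many FLOPs."""
    n_params = (compute_flops / (6 * 20)) ** 0.5  # from C = 6*N*D and D = 20*N
    n_tokens = 20 * n_params
    return E + A / n_params**ALPHA + B / n_tokens**BETA


def effective_compute(target_loss: float) -> float:
    """Binary-search the compute budget whose compute-optimal loss matches target_loss."""
    lo, hi = 1e12, 1e28  # bracket in FLOPs; optimal_loss is monotonically decreasing in compute
    for _ in range(200):
        mid = (lo * hi) ** 0.5  # geometric midpoint
        if optimal_loss(mid) > target_loss:
            lo = mid  # mid is too little compute
        else:
            hi = mid
    return (lo * hi) ** 0.5


# Hypothetical numbers: original model at 2.2 CE, model with the SAE inserted at 2.5 CE.
ratio = effective_compute(2.5) / effective_compute(2.2)
print(f"effective compute ratio ~ {ratio:.2f}")  # well below 1: looks like a much smaller training run
```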
For most papers, I found it hard to convert their SAE reconstruction scores into this format. So I can’t completely exclude the possibility that some other SAE scores much better. But at this point, I’d be quite surprised if anyone had managed so much as 0.5 performance recovered on any model that isn’t so tiny and bad it barely has any performance to destroy in the first place. I’d guess most SAEs get something in the range 0.01-0.1 performance recovered or worse.
Note also that getting a good reconstruction score still doesn’t necessarily mean the SAE is actually showing something real and useful. If you want perfect reconstruction, you can just use the standard basis of the network. The SAE would probably also need to be much sparser than the original model activations to provide meaningful insights.
Basically agree—I’m generally a strong supporter of looking at the loss drop in terms of effective compute. Loss recovered using a zero-ablation baseline is really quite wonky and gives misleadingly big numbers.
I also agree that reconstruction is not the only axis of SAE quality we care about. I propose explainability as the other axis—whether we can make necessary and sufficient explanations for when individual latents activate. Progress then looks like pushing this Pareto frontier.
This seems true to me, though finding the right scaling curve for models is typically quite hard, so the conversion to effective compute is difficult. I typically use CE loss change, not loss recovered. I think we just don’t know how to evaluate SAE quality.
My personal guess is that SAEs can be a useful interpretability tool despite making a big difference in effective compute, and we should think more in terms of how useful they are for downstream tasks. But I agree this is a real phenomenon that is easy to overlook, and that it is bad.
As a complete noob in all things mechinterp, can somebody explain how this is not in conflict with SAE enjoyers saying they get reconstruction scores in the high 90s or even 100%?
I understand the log-scale argument that Lucius is making, but it still seems surprising? Is this really what’s going on, or are they talking about different things here?
The key question is: 90% recovered relative to what? If you recover 90% of the loss relative to a zero-ablation baseline (which ablates the entire residual stream midway through the model!), that isn’t clearly very much.
E.g., if full zero ablation gives 13 CE loss (seems plausible) and the SAE gets you to 3 CE while the original model was at 2 CE, that is about 90% loss recovered, but you have also massively degraded performance in terms of effective training compute.
IDK about literal 100%.
The metric you mention here is probably ‘loss recovered’. For a residual stream insertion, it goes
loss recovered = 1 − (CE loss with SAE − CE loss of original model) / (CE loss with the entire residual stream ablated − CE loss of original model)
See e.g. equation 5 here.
So, it’s a linear scale, and they’re comparing the CE loss increase from inserting the SAE to the CE loss increase from just destroying the model and outputting a ≈ uniform distribution over tokens. The latter is a very large CE loss increase, so the denominator is really big. Thus, scoring over 90% is pretty easy.
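A minimal sketch of that formula, using the hypothetical example numbers from the comment above (original 2 CE, SAE-patched 3 CE, zero-ablation baseline 13 CE):

```python
# Minimal sketch of the 'loss recovered' metric as defined above. Names are
# illustrative, not taken from any particular paper's code.

def loss_recovered(ce_with_sae: float, ce_original: float, ce_zero_ablation: float) -> float:
    """1 - (CE increase from inserting the SAE) / (CE increase from zero-ablating the residual stream)."""
    return 1 - (ce_with_sae - ce_original) / (ce_zero_ablation - ce_original)


# A full 1-nat CE increase still reads as ~91% loss recovered against the big denominator.
print(loss_recovered(ce_with_sae=3.0, ce_original=2.0, ce_zero_ablation=13.0))  # ~0.909
```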
Have people done evals for a model with/without an SAE inserted? Seems like even just looking at drops in MMLU performance by category could be non-trivially informative.
I’ve seen a little bit of this, but nowhere near as much as I think the topic merits. I agree that systematic studies on where and how the reconstruction errors make their effects known might be quite informative.
Basically, whenever people train SAEs, or use some other approximate model decomposition that degrades performance, I think they should ideally spend some time after just playing with the degraded model and talking to it. Figure out in what ways it is worse.
Hmmm ok maybe I’ll take a look at this :)
What are your thoughts on KL-div after the unembed softmax as a metric?
On its own, this’d be another metric that doesn’t track the right scale as models become more powerful.
The same KL-div in GPT-2 and GPT-4 probably corresponds to the destruction of far more of the internal structure in the latter than the former.
Destroy 95% of GPT-2’s circuits, and the resulting output distribution may look quite different. Destroy 95% of GPT-4’s circuits, and the resulting output distribution may not be all that different, since 5% of the circuits in GPT-4 might still be enough to get a lot of the most common token prediction cases roughly right.
I don’t see important differences between that and the CE loss delta in the context Lucius is describing.
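For reference, a minimal sketch of the metric being discussed: KL divergence between the original model’s next-token distribution and the SAE-patched model’s, taken after the softmax and averaged over positions. It assumes you already have the two sets of logits; how the SAE reconstruction gets patched into the forward pass is left out.

```python
# Minimal sketch of the KL metric discussed above. Assumes both logit tensors are
# already in hand (e.g. from a normal forward pass and from one where the SAE's
# reconstruction is substituted into the residual stream).

import torch
import torch.nn.functional as F


def mean_kl_after_softmax(logits_original: torch.Tensor,
                          logits_with_sae: torch.Tensor) -> torch.Tensor:
    """Mean KL(p_original || p_with_sae) over all token positions.

    Both tensors are expected to have shape (batch, seq_len, vocab_size).
    """
    log_p = F.log_softmax(logits_original, dim=-1)
    log_q = F.log_softmax(logits_with_sae, dim=-1)
    kl_per_position = (log_p.exp() * (log_p - log_q)).sum(dim=-1)  # (batch, seq_len)
    return kl_per_position.mean()
```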