As a complete noob in all things mechinterp, can somebody explain how this is not in conflict with SAE enjoyers saying they get reconstruction loss in the high 90s or even 100%?
I understand the logscale argument that Lucius is making, but it still seems surprising. Is this really what’s going on, or are they talking about different things here?
The key question is: 90% recovered relative to what? If you recover 90% of the loss relative to a zero-ablation baseline (one that ablates the entire residual stream midway through the model!), that isn’t clearly that much.
E.g., if full zero ablation gives 13 CE loss (seems plausible) and the SAE gets you to 3 CE while the original model was at 2 CE, this is roughly 90% recovered (1 - 1/11 ≈ 0.91), but you have also massively degraded performance in terms of effective training compute.
IDK about literal 100%.
The metric you mention here is probably ‘loss recovered’. For a residual stream insertion, it goes
1 - (CE loss with SAE - CE loss of original model) / (CE loss if the entire residual stream is ablated - CE loss of original model)
See e.g. equation 5 here.
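To make the formula concrete, here is a minimal Python sketch that plugs in the hypothetical numbers from the example above (13, 3 and 2 CE). The function name `loss_recovered` and its parameter names are made up for illustration and don't come from any particular library.

```python
def loss_recovered(ce_sae: float, ce_clean: float, ce_ablated: float) -> float:
    """Fraction of the ablation-induced CE loss increase that the SAE recovers.

    ce_sae:     CE loss with the SAE spliced into the residual stream
    ce_clean:   CE loss of the original, unmodified model
    ce_ablated: CE loss with the entire residual stream zero-ablated
    """
    denom = ce_ablated - ce_clean
    if denom <= 0:
        raise ValueError("ablated loss must exceed the clean loss")
    return 1.0 - (ce_sae - ce_clean) / denom


# Hypothetical numbers from the example above, not measurements.
score = loss_recovered(ce_sae=3.0, ce_clean=2.0, ce_ablated=13.0)
print(f"{score:.3f}")  # 0.909
```

So a full extra nat of CE loss from inserting the SAE still scores about 91% under these assumed numbers.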
So, it’s a linear scale, and they’re comparing the CE loss increase from inserting the SAE to the CE loss increase from just destroying the model and outputting an approximately uniform distribution over tokens. The latter is a very large CE loss increase, so the denominator is really big. Thus, scoring over 90% is pretty easy.
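One way to get a feel for how big that denominator is: the sketch below inverts the formula to ask what CE loss a given score actually implies. Everything numeric here is an assumption for illustration: a vocabulary of 50257 tokens, a clean loss of 2.0 CE, and a baseline that is exactly uniform (whose CE is ln of the vocabulary size), which only approximates what zero ablation really does. The helper name `implied_sae_ce` is hypothetical.

```python
import math


def implied_sae_ce(score: float, ce_clean: float, ce_baseline: float) -> float:
    """Invert the loss-recovered formula: CE loss with the SAE implied by a score."""
    return ce_clean + (1.0 - score) * (ce_baseline - ce_clean)


VOCAB_SIZE = 50257   # assumed vocabulary size, for illustration only
CE_CLEAN = 2.0       # assumed clean-model CE loss, not a measurement

# CE loss of an exactly uniform prediction over the vocabulary (in nats).
ce_uniform = math.log(VOCAB_SIZE)

for score in (0.90, 0.95, 0.99):
    ce_sae = implied_sae_ce(score, CE_CLEAN, ce_uniform)
    print(f"score={score:.2f} -> CE with SAE = {ce_sae:.2f} "
          f"(increase of {ce_sae - CE_CLEAN:.2f} nats)")
```

Under these assumptions, each percentage point of loss recovered corresponds to a bit under 0.1 nats of CE, which is why a score in the 90s can still hide a sizeable degradation.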