Thanks for doing this work—I’m really happy people are doing the basic stress testing of SAEs, and I agree that this is important and urgent given the sheer amount of resources being invested into SAE research.
For me, this was actually a positive update that SAEs are pretty good on distribution—you trained the SAE on length-128 sequences from OpenWebText, and the log loss was quite low up to ~200 tokens! This is despite its poor downstream use-case performance.
I expected to see more negative results along the lines of your LAMBADA and Children’s Book Test results (that is, substantial degradation of loss as soon as you go even a tiny bit off distribution):
I do think these results add to the growing pile of evidence that SAEs are not good “off distribution” (even a small amount off distribution, as in Sam Marks’s results you link). This means they’re somewhat problematic for OOD use cases like treacherous turn detection or detecting misgeneralization. That doesn’t mean they’re useless—e.g. it’s plausible that SAEs could be useful for steering, mechanistic anomaly detection, or helping us do case analysis for heuristic arguments or proofs.
As an aside, am I reading this plot incorrectly, or does the figure on the right suggest that SAE reconstructed representations have lower log loss than the original unmodified model?
For me, this was actually a positive update that SAEs are pretty good on distribution—you trained the SAE on length-128 sequences from OpenWebText, and the log loss was quite low up to ~200 tokens! This is despite its poor downstream use-case performance.
Yes, this was nice to see. I originally just looked at context positions at powers of 2 (...64, 128, 256,...) and there everything looked terrible above 128, but Logan recommended looking at all context positions and this was a cool result!
But note that there’s a layer effect here. I think layer 12 is good up to ~200 tokens while layer 0 is only really good up to the training context size. I think this is clearest in the MSE/L1 plots (and it’s consistent with later layers performing ok-ish on the long-context CBT while early layers are poor).
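In case it helps anyone reproduce these curves, here’s roughly the measurement: splice the SAE reconstruction back into one layer and read off the next-token loss at every context position. The sketch below uses TransformerLens with GPT-2 small and a toy stand-in SAE purely for illustration; the model name, layer, and SAE architecture are placeholders, not the setup actually used in the post.

```python
# Minimal sketch (placeholder model/layer/SAE, not the post's actual code):
# replace one layer's residual stream with its SAE reconstruction and measure
# the next-token loss at each context position.
import torch
import torch.nn.functional as F
from transformer_lens import HookedTransformer


class ToySAE(torch.nn.Module):
    # Stand-in SAE with arbitrary sizes; in practice you'd load the trained one.
    def __init__(self, d_model=768, d_hidden=768 * 8):
        super().__init__()
        self.enc = torch.nn.Linear(d_model, d_hidden)
        self.dec = torch.nn.Linear(d_hidden, d_model)

    def forward(self, x):
        return self.dec(torch.relu(self.enc(x)))


model = HookedTransformer.from_pretrained("gpt2")   # placeholder model
sae = ToySAE(d_model=model.cfg.d_model)
layer = 6                                           # placeholder layer
hook_name = f"blocks.{layer}.hook_resid_pre"


def splice_in_sae(resid, hook):
    # Swap the residual stream for its SAE reconstruction at every position.
    return sae(resid)


@torch.no_grad()
def per_position_log_loss(tokens):
    logits = model.run_with_hooks(tokens, fwd_hooks=[(hook_name, splice_in_sae)])
    # Next-token cross-entropy at each position, shape [batch, seq_len - 1].
    return F.cross_entropy(
        logits[:, :-1].flatten(0, 1),
        tokens[:, 1:].flatten(),
        reduction="none",
    ).reshape(tokens.shape[0], -1)


# Example usage; in practice you'd average over many long OpenWebText / CBT
# sequences and compare against the unhooked model to get loss-vs-position curves.
tokens = model.to_tokens("A short placeholder passage; use long sequences in practice.")
losses = per_position_log_loss(tokens)  # shape [1, n_positions - 1]
```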
This means they’re somewhat problematic for OOD use cases like treacherous turn detection or detecting misgeneralization.
Yeah, agreed, which is a bummer because that’s one thing I’d really like to see SAEs enable! I wonder if there’s a way to change the training of SAEs so that these cases end up on-distribution, where they perform well.
As an aside, am I reading this plot incorrectly, or does the figure on the right suggest that SAE reconstructed representations have lower log loss than the original unmodified model?
Oof, that figure does indeed suggest that, but it’s because of a bug in the plot script. Thank you for pointing that out; here’s a fixed version:
I’ve fixed the repo and I’ll edit the original post shortly.
This means they’re somewhat problematic for OOD use cases like treacherous turn detection or detecting misgeneralization.
I kinda want to push back on this, since OOD in behavior is not obviously OOD in the activations. Misgeneralization especially might be better thought of as an OOD environment with on-distribution activations?
I think we should come back to this question once SAEs have tackled something like variable binding. Right now it’s hard to say how SAEs are going to help us understand more abstract thinking, and therefore it’s hard to say how problematic they’re going to be for detecting things like a treacherous turn. I think this will depend on how representations factor. In the ideal world, they generalize along with the model’s ability to generalize (apologies for how high-level/vague that idea is).
Some experiments I’d be excited to look at:
If the SAE is trained on a subset of the training distribution, can we tell when it’s being used to decompose activations from data points outside that subset?
How does that compare to an SAE trained on the model’s whole training distribution, where we then look at what happens when the model is pushed off distribution?
I think I’m trying to get at—can we distinguish (rough sketch below):
Anomalous activations.
Anomalous data points.
Anomalous mechanisms.
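To make that a bit more concrete, here’s the kind of per-token score I have in mind for these comparisons: SAE reconstruction error plus the number of active features, computed for activations gathered under each condition. Everything in the sketch is a placeholder (random stand-in SAE, random activations); it’s only meant to show the shape of the experiment, not a claim about which score would actually work.

```python
# Rough sketch of an SAE-based anomaly score (placeholders throughout).
import torch

torch.manual_seed(0)
d_model, d_hidden = 768, 768 * 8

# Stand-in SAE weights; in practice, load the trained SAE.
enc = torch.nn.Linear(d_model, d_hidden)
dec = torch.nn.Linear(d_hidden, d_model)


@torch.no_grad()
def sae_scores(acts):
    # acts: [n_tokens, d_model] activations collected from the model.
    feats = torch.relu(enc(acts))
    recon = dec(feats)
    mse = (recon - acts).pow(2).mean(-1)   # per-token reconstruction error
    l0 = (feats > 0).float().sum(-1)       # per-token number of active features
    return mse, l0


# Placeholder batches standing in for (a) activations from the SAE's training
# subset and (b) activations from data points outside it, or from the model
# being pushed off distribution.
on_mse, on_l0 = sae_scores(torch.randn(1000, d_model))
off_mse, off_l0 = sae_scores(2 * torch.randn(1000, d_model))

# The experiment is then: how separable are these score distributions
# (e.g. an ROC curve over per-token scores) in each of the three cases above?
```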
Lots of great work to look forward to!