> Our reconstruction scores were pretty good. We found GPT2 small achieves a cross entropy loss of about 3.3, and with reconstructed activations in place of the original activation, the CE Log Loss stays below 3.6.
Unless my memory is screwing up the scale here, a 0.3 CE loss increase seems quite substantial? A 0.3 CE loss increase on the Pile is roughly the difference between Pythia 410M and Pythia 2.8B. And do I see it right that this is the maximum CE increase from adding in one SAE, rather than all of them at the same time? So unless there is some very kind correlation in these errors, where every SAE is failing to reconstruct roughly the same variance and that variance at early layers is not used to compute the variance SAEs at later layers are capturing, the errors would add up? Possibly even worse than linearly? What CE loss do you get then?
Have you tried talking to the patched models a bit and compared to what the original model sounds like? Any discernible systematic differences in where that CE increase is changing the answers?
> 0.3 CE Loss increase seems quite substantial? A 0.3 CE loss increase on the pile is roughly the difference between Pythia 410M and Pythia 2.8B
My personal guess is that something like this is probably true. However, since we’re comparing OpenWebText and the Pile and different tokenizers, we can’t really compare the two loss numbers. Further, there is no GPT-2 extra-small model, so currently we can’t compare these SAEs to smaller models. But yeah, in future we will probably compare GPT-2 Medium and GPT-2 Large with SAEs attached to the smaller models in the same family, and there will probably be similar degradation, at least until we have more SAE advances.
One experiment here is to see if the specific datapoints that have worse CE-diff correlate across layers. Last time I did a similar experiment, I saw a very long tail of datapoints that were worse off (for just one layer of gpt2-small), but the majority of datapoints had similar CE. Joseph has suggested before running UMAP on these datapoints and coloring them by their CE-diff (or using other methods to see if you could separate out these datapoints).
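A minimal sketch of the cross-layer check, assuming you already have per-datapoint CE-diffs for two layers as equal-length lists. The variable names and the toy numbers are made up for illustration, not taken from any real run. Since the distribution is described as long-tailed, the sketch uses rank (Spearman) correlation, which is less dominated by a few extreme datapoints than Pearson correlation.

```python
def ranks(values):
    """Return 1-based ranks, giving tied values their mean rank."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    out = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values tied with position i.
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2 + 1
        for t in range(i, j + 1):
            out[order[t]] = mean_rank
        i = j + 1
    return out


def pearson(xs, ys):
    """Plain Pearson correlation of two equal-length lists."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    return cov / (vx * vy) ** 0.5


def spearman(xs, ys):
    """Spearman correlation: Pearson correlation of the ranks."""
    return pearson(ranks(xs), ranks(ys))


# Hypothetical CE-diffs for the same five datapoints at two layers.
ce_diff_layer_a = [0.02, 0.05, 0.01, 1.30, 0.04]
ce_diff_layer_b = [0.03, 0.06, 0.02, 0.90, 0.05]
rho = spearman(ce_diff_layer_a, ce_diff_layer_b)
print(f"rank correlation of CE-diff across layers: {rho:.3f}")
```

A high rank correlation would suggest the same datapoints are hard for the SAEs at both layers; a value near zero would suggest the errors are closer to independent.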
If someone were to run this experiment, I’d also be interested in what happens if you removed the k lowest-activating features per datapoint and checked the new CE & MSE. In the SAE work, the lowest-activating features usually don’t make sense for the datapoint. This would test two hypotheses:
1. Low-activating features are noise, or some acceptable false alarm rate true to the LLM. (i.e. SAEs capture what we care about)
2. Actually they’re important for CE in ways we don’t understand. (i.e. SAEs let in uninterpretable feature activations which are important, but?)
For example, if you saw better CE-diff when removing low-activating features, up to a specific k, then SAEs are looking good!
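Here is a rough sketch of the pruning step, assuming a plain linear SAE decoder and a single datapoint. The helper names (`drop_k_lowest`, `decode`, `mse`) and the toy decoder are hypothetical. It only measures MSE; getting the new CE would additionally require patching the pruned reconstruction back into the model's forward pass, which is not shown.

```python
def drop_k_lowest(acts, k):
    """Zero the k smallest nonzero activations in one datapoint's feature vector."""
    active = sorted((i for i, a in enumerate(acts) if a > 0), key=lambda i: acts[i])
    dropped = set(active[:k])
    return [0.0 if i in dropped else a for i, a in enumerate(acts)]


def decode(acts, decoder, bias):
    """Linear decoder: x_hat = bias + sum_i acts[i] * decoder[i]."""
    out = list(bias)
    for a, row in zip(acts, decoder):
        if a != 0:
            for d, w in enumerate(row):
                out[d] += a * w
    return out


def mse(x, x_hat):
    """Mean squared error between the original and reconstructed activation."""
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)


# Toy setup (made up): 3 features, 2-dimensional activations.
decoder = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
bias = [0.0, 0.0]
x = [2.0, 3.0]
acts = [2.0, 3.0, 0.05]  # the third feature fires weakly

for k in range(3):
    pruned = drop_k_lowest(acts, k)
    print(k, pruned, mse(x, decode(pruned, decoder, bias)))
```

In this toy case, dropping the single weakest feature improves the reconstruction and dropping a second one hurts it, which is the shape of result the first hypothesis would predict. Whether real SAEs behave this way is exactly what the experiment would test.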
Neel and I recently tried to interpret a language model circuit by attaching SAEs to the model. We found that using an L0=50 SAE while only keeping the top 10 features by activation value per prompt (and zero ablating the others) was better than an L0=10 SAE, both by our task-specific metric and by subjective interpretability. I can check how far this generalizes.
I’d be excited about reading about or doing these kinds of experiments. My weak prediction is that low-activating features are important in specific examples where nuance matters, and that what we want is something like an “adversarially robust SAE”, which might only be feasible with current SAE methods on a very narrow distribution.
A mini experiment I did which motivates this: with an SAE at the residual stream, I looked at the attention pattern of an attention head immediately following the SAE as a function of k, where we keep only the top-k SAE features in the reconstruction. I found that if the head was attending to “Mary” in the original forward pass (and not “John”), then a k of 3 was good enough to have it attend to Mary and not John. But if I replaced John with Martha, the minimum k such that the head attended to Mary increased.
There’s a few things to note. Later layers have:
1. worse CE-diff & variance explained (e.g. the layer 0 CE-diff seems great!)
2. larger L2 norms in the original LLM activations
3. worse ratio of reconstruction-L2/original-L2 (meaning it’s under-normed)*
4. fewer dead features (maybe they need more features?)
For (3), we might expect under-normed reconstructions because there’s a trade-off between L1 & MSE. After training, however, we can freeze the encoder, locking in the L0, and train the decoder or scalar multiples of the hidden layer (h/t to Ben Wright for first figuring this out).
For (4), it seems like a pretty easy experiment to just vary the number of features and see if this explains part of the gap.
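To make the under-norming point concrete, here is a sketch of the simplest possible version of the rescaling idea: a single global scalar fitted by least squares on top of a frozen SAE. This is an assumption for illustration and is not necessarily the procedure referred to above, which trains the decoder or per-feature scalars. For reconstructions `x_hat` of activations `x`, the scalar minimizing the summed squared error of `alpha * x_hat` is `alpha = sum(<x, x_hat>) / sum(<x_hat, x_hat>)`.

```python
def best_scale(xs, x_hats):
    """Least-squares scalar alpha minimizing sum ||x - alpha * x_hat||^2."""
    num = sum(a * b for x, xh in zip(xs, x_hats) for a, b in zip(x, xh))
    den = sum(b * b for xh in x_hats for b in xh)
    return num / den


def l2(v):
    """Euclidean norm of a vector."""
    return sum(a * a for a in v) ** 0.5


def mean_norm_ratio(xs, x_hats):
    """Average of ||x_hat|| / ||x||; values below 1 mean under-normed."""
    return sum(l2(xh) / l2(x) for x, xh in zip(xs, x_hats)) / len(xs)


# Toy data (made up): reconstructions point the right way but are 20% too short.
xs = [[1.0, 2.0], [3.0, 4.0]]
x_hats = [[0.8 * a for a in x] for x in xs]

alpha = best_scale(xs, x_hats)
rescaled = [[alpha * b for b in xh] for xh in x_hats]
print("norm ratio before:", mean_norm_ratio(xs, x_hats))
print("fitted alpha:", alpha)
print("norm ratio after:", mean_norm_ratio(xs, rescaled))
```

Because the encoder is untouched, the L0 of the SAE is unchanged; only the scale of the output moves.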
*
Thanks for raising this! I had wanted to find a comparison in terms of different model performances to help me quantify this so I’m glad to have this as a reference.
While I have explored model performance with SAEs at different layers, I haven’t done so with more than one SAE or explored sampling from the model with the SAE. I’ve been curious about systematic errors induced by the SAE but a few brief experiments with earlier SAEs/smaller models didn’t reveal any obvious patterns. I have once or twice looked at the divergence in the activations after an SAE has been added and found that errors in earlier layers propagated.
One thought I have on this is that if we take the analogy to DNA sequencing seriously, relatively minor errors in DNA sequencing make the resulting sequences useless. If you get one or two base pairs wrong and then try to make bacteria express the printed gene (based on your sequencing), you’ll kill the bacteria. This gives me the intuition that we could absolutely have fairly accurate measurements with some error and still find that the resulting downstream error is large.
To bring it back to what I suspect is the main point here: We should amend the statement to say “Our reconstruction scores were pretty good as compared to our previous results”.
It bothers me quite a bit that SAEs don’t recover performance better, but I think this is a fairly well-defined problem that the community can iterate on, both via improvements to SAEs and by looking for nearby alternatives. For example, I’m quite excited to experiment with any alternative architectures/training procedures that come out of the theory of computation in superposition line of work.
One productive direction inspired by thinking of this as sequencing is that we should have lots of SAEs trained on the same model and show that they get very similar results (to give us more confidence we have a better estimate of the true underlying features). It’s standard in DNA/RNA/Protein sequencing to run methods many times over. I think once we see evidence that we get good results along those lines, we should be more interested in / raise our standards for model performance with reconstructed SAEs.
Is the drop of eval loss when attaching SAEs a crux for the SAE research direction to you? I agree it’s not ideal, but to me the comparison of eval loss to smaller models only makes sense if the goal of the SAE direction is making a full-distribution competitive model. Explaining narrow tasks, or just using SAEs for monitoring/steering/lie detection/etc. doesn’t require competitive eval loss. (Note that I have varying excitement about all these goals, e.g. some pessimism about steering)
If the SAEs are not full-distribution competitive, I don’t really trust that the features they’re seeing are actually the variables being computed on in the sense of reflecting the true mechanistic structure of the learned network algorithm and that the explanations they offer are correct[1]. If I pick a small enough sub-distribution, I can pretty much always get perfect reconstruction no matter what kind of probe I use, because e.g. measured over a single token the network layers will have representation rank 1, and the entire network can be written as a rank-1 linear transform. So I can declare the activation vector at layer l to be the active “feature”, use the single entry linear maps between SAEs to “explain” how features between layers map to each other, and be done. Those explanations will of course be nonsense and not at all extrapolate out of distribution. I can’t use them to make a causal model that accurately reproduces the network’s behavior or some aspect of it when dealing with a new prompt.
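A small illustration of the single-datapoint argument, using a made-up nonlinear function as a stand-in for a network layer (`toy_layer` is hypothetical, not anything from a real model). For one input `x` with output `y`, the rank-1 matrix `W = y x^T / (x . x)` satisfies `W x = y` exactly, so it "explains" the layer perfectly on that datapoint while telling you nothing about other inputs.

```python
def rank1_map(x, y):
    """Return W = y x^T / (x . x), a rank-1 matrix with W x = y."""
    xx = sum(a * a for a in x)
    return [[yi * xj / xx for xj in x] for yi in y]


def matvec(W, v):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(w * a for w, a in zip(row, v)) for row in W]


def toy_layer(x):
    """Stand-in nonlinear 'layer' (an assumption for illustration only)."""
    return [max(0.0, x[0] - x[1]), x[0] * x[1]]


def max_abs_error(u, v):
    """Largest absolute coordinate-wise difference between two vectors."""
    return max(abs(a - b) for a, b in zip(u, v))


x_train = [1.0, 2.0]
W = rank1_map(x_train, toy_layer(x_train))

# Perfect "reconstruction" on the single datapoint it was fitted to...
err_train = max_abs_error(matvec(W, x_train), toy_layer(x_train))

# ...and a poor prediction on a new input.
x_new = [3.0, 1.0]
err_new = max_abs_error(matvec(W, x_new), toy_layer(x_new))

print("error on fitted datapoint:", err_train)
print("error on new datapoint:", err_new)
```

The point carries over only qualitatively: real SAEs are trained on far more than one token, so the open question is how much of this failure mode persists on narrow sub-distributions.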
We don’t train SAEs on literally single tokens, but I would be worried about the qualitative problem persisting. The network itself doesn’t have a million different algorithms to perform a million different narrow subtasks. It has a finite description length. It’s got to be using a smaller set of general algorithms that handle all of these different subtasks, at least to some extent. Likely more so for more powerful and general networks. If our “explanations” of the network then model it in terms of different sets of features and circuits for different narrow subtasks that don’t fit together coherently to give a single good reconstruction loss over the whole distribution, that seems like a sign that our SAE layer activations didn’t actually capture the general algorithms in the network. Thus, predictions about network behaviour made on the basis of inspecting causal relationships between these SAE activations might not be at all reliable, especially predictions about behaviours like instrumental deception which might be very mechanistically related to how the network does well on cross-domain generalisation.
As in, that seems like a minimum requirement for the SAEs to fulfil. Not that this would be enough to make me trust predictions about generalisation based on stories about SAE activations.
(This reply is less important than my other)
> The network itself doesn’t have a million different algorithms to perform a million different narrow subtasks
For what it’s worth, this sort of thinking is really not obvious to me at all. It seems very plausible that frontier models only have their amazing capabilities through the aggregation of a huge number of dumb heuristics (as an aside, I think if true this is net positive for alignment). This is consistent with findings that e.g. grokking and phase changes are much less common in LLMs than toy models.
(Two objections to these claims are that plausibly current frontier models are importantly limited, and also that it’s really hard to prove either me or you correct on this point since it’s all hand-wavy)
Thanks for the first sentence; I appreciate you clearly stating a position.
> measured over a single token the network layers will have representation rank 1
I don’t follow this. Are you saying that the residual stream at position 0 in a transformer is a function of the first token only, or something like this?
If so, I agree—but I don’t see how this applies to much SAE[1] or mech interp[2] work. Where do we disagree?
E.g. in this post here we show in detail how an “inside a question beginning with which” SAE feature is computed from “which” and predicts question marks (I helped with this project but didn’t personally find this feature)
More generally, in narrow distribution mech interp work such as the IOI paper, I don’t think it makes sense to reduce the explanation to single-token perfect accuracy probes, since our explanation generalises fairly well (e.g. the “Adversarial examples” in Section 4.4 that Alexandre found)