Good work, I’m glad that people are exploring this empirically.
That being said, I’m not sure that these results tell us very much about whether or not the MCIS theory is correct. In fact, something like your results should hold as long as the following facts are true (even without superposition):
Correct behavior: The model’s behavior is correct on distribution, and the correct behavior isn’t super sensitive to many small variations in the input.
Linear feature representations: The model encodes information along particular directions, and “reads-off” the information along these directions when deciding what to do.
If these are true, then I think the results you get follow:
Activation plateaus: If the model’s behavior changes a lot around actual on-distribution examples, then it’s probably wrong, because there are lots of similar-seeming examples (which won’t lead to exactly the same activation, but will lead to similar ones) where the model should behave similarly. For example, given a fixed MMLU problem and a few different sets of 5-shot examples, the activations will likely be close but won’t be identical (as the inputs are similar and the relevant information for locating the task should be the same). But if the model uses the 5-shot examples to get the correct answer, its logits can’t change too much as a function of these inputs.
In general, we’d expect to see plateaus around any real examples, because the correct behavior doesn’t change much under small variations in the input, and the model performs well. In contrast, for activations that are very off-distribution for the model, there is no real reason for the model to remain consistent across small perturbations.
Sensitive directions: Most directions in high-dimensional space are near-orthogonal, so by default random small perturbations don’t change the read-off along any particular direction by very much. But if you perturb the activation along some of the read-off directions, then this will indeed change the magnitude along each of those directions a lot! (See the small numeric sketch after these points.)
Local optima in sensitivity: Same explanation as with sensitive directions.
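To make the near-orthogonality point concrete, here is a minimal numeric sketch (my own illustration; the dimensionality and perturbation scale are arbitrary): a random unit perturbation moves the projection onto a fixed read-off direction by only about 1/√d of its norm, while an equally-sized perturbation along the read-off direction moves it by the full amount.

```python
# A random perturbation barely moves the read-off along a fixed direction,
# while a perturbation of the same size along that direction moves it fully.
import numpy as np

rng = np.random.default_rng(0)
d = 4096                                   # residual-stream-ish dimensionality
read_off = rng.standard_normal(d)
read_off /= np.linalg.norm(read_off)       # fixed unit "read-off" direction

eps = 1.0                                  # perturbation norm
rand = rng.standard_normal(d)
rand_perturb = eps * rand / np.linalg.norm(rand)   # random direction
aligned_perturb = eps * read_off                   # along the read-off direction

print(abs(rand_perturb @ read_off))        # ~ eps / sqrt(d) ≈ 0.016
print(abs(aligned_perturb @ read_off))     # = eps = 1.0
```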
Note that we don’t need superposition to explain any of these results. So I don’t think these results really support one model of superposition over the other, given that they seem to follow from a combination of the model behaving correctly and the linear representation hypothesis.
Instead, I see your results as primarily a sanity-check of your techniques for measuring activation plateaus and for measuring sensitivity to directions, as opposed to weighing in on particular theories of superposition. I’d be interested in seeing the techniques applied to other tasks, such as validating the correctness of SAE features.
Thanks for the comment, Lawrence, I appreciate it!
I agree this doesn’t distinguish superposition vs. no superposition at all; I was more thinking about the “error correction” aspect of MCIS (and just assuming superposition to be true). But I’m excited for the SAE application too; we’ve got some experiments in the pipeline!
Your “Correct behavior” point sounds reasonable, but I feel like it’s not an explanation? I would have the same intuitive expectation, but that doesn’t explain how the model manages not to be sensitive. Explanations I can think of, in increasing order of probability:
Story 0: Perturbations change the activations and logprobs, but the top answer doesn’t change because the logprob margin was already large. I don’t think the KL divergence metric would behave like that, though.
Story 1: Perturbations do change the activations but the difference in the logprobs is small due to layer norm, unembed, or softmax shenanigans.
We ran a test experiment perturbing the 12th layer rather than the 2nd layer, and the difference between real-other and random perturbations disappeared. So I don’t think it’s a weird effect of how activations get converted to outputs. (A rough sketch of this kind of perturb-and-measure setup is at the end of this comment.)
Story 2: Perturbations in a lower layer cause less perturbation in later layers if the model is on-distribution (+ similar story for sensitivity).
This is what the L2-metric plots (right panel) suggest, and also what I understand your story to be.
But this doesn’t explain how the model does this, right? Are there simple stories for how this happens?
I guess there are lots of stories not limited to MCIS, anything along the lines of “ReLUs require thresholds to be passed”?
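For what it’s worth, here is a toy sketch of that ReLU-threshold story (my own illustration, not something from our experiments, with arbitrary sizes and scales): a neuron with a negative bias absorbs small perturbations that don’t push its pre-activation past zero, which already looks like a weak form of error correction.

```python
# A ReLU neuron with a negative bias (i.e. a threshold) absorbs small
# perturbations when it is off, and only shifts proportionally when it is on.
import numpy as np

rng = np.random.default_rng(0)
d = 512
w = rng.standard_normal(d) / np.sqrt(d)    # neuron input weights, ||w|| ~ 1
b = -0.5                                   # negative bias acts as a threshold

def neuron(x):
    return max(0.0, w @ x + b)             # ReLU(w·x + b)

x_below = 0.1 * rng.standard_normal(d)     # little of the feature: below threshold
x_above = 2.0 * w / np.linalg.norm(w)      # aligned with w: clears the threshold

for name, x in [("below threshold", x_below), ("above threshold", x_above)]:
    base = neuron(x)
    # apply many small random perturbations and see how much the output moves
    changes = [abs(neuron(x + 0.05 * rng.standard_normal(d)) - base) for _ in range(100)]
    print(f"{name}: output={base:.3f}, mean |output change| = {np.mean(changes):.4f}")
```

The sub-threshold neuron’s output doesn’t move at all under these small perturbations, while the active one only shifts proportionally.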
Based on that, I think the results still require some “error-correction” explanation, though you’re right that this doesn’t have to be MCIS (it’s just that there’s no other theory that doesn’t also conflict with superposition?).
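For reference, here is roughly the kind of perturb-and-measure setup I have in mind above, as a simplified sketch rather than our actual code: the prompts, layer, hook point, and perturbation norm are placeholder assumptions, and I’m using TransformerLens for the activation patching.

```python
# Simplified sketch (not our actual code): patch a perturbed residual-stream
# activation back into the model and measure the KL divergence of the
# next-token distribution from the clean run, comparing a step toward another
# real activation ("real-other") with a step in a random direction.
import torch
from transformer_lens import HookedTransformer

torch.set_grad_enabled(False)
model = HookedTransformer.from_pretrained("gpt2")
LAYER = 2
HOOK = f"blocks.{LAYER}.hook_resid_pre"

base  = model.to_tokens("The capital of France is")
other = model.to_tokens("The capital of Spain is")   # a different "real" prompt

clean_logits, base_cache = model.run_with_cache(base)
_, other_cache = model.run_with_cache(other)
a_base  = base_cache[HOOK][0, -1]                    # activation at the last token
a_other = other_cache[HOOK][0, -1]

def kl_after_patch(perturbed_act):
    """KL(clean || perturbed) over next-token predictions after patching."""
    def patch(resid, hook):
        resid[0, -1] = perturbed_act
        return resid
    pert_logits = model.run_with_hooks(base, fwd_hooks=[(HOOK, patch)])
    logp = torch.log_softmax(clean_logits[0, -1], dim=-1)
    logq = torch.log_softmax(pert_logits[0, -1], dim=-1)
    return (logp.exp() * (logp - logq)).sum().item()

eps = 3.0                                            # illustrative perturbation norm
to_other = a_base + eps * (a_other - a_base) / (a_other - a_base).norm()
rand_dir = torch.randn_like(a_base)
to_rand  = a_base + eps * rand_dir / rand_dir.norm()

print("KL, step toward another real activation:", kl_after_patch(to_other))
print("KL, step in a random direction:         ", kl_after_patch(to_rand))
```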