Really excited to see this come out! I’m in general very excited to see work trying to make mechanistic interpretability more rigorous/coherent/paradigmatic, and think causal scrubbing is a pretty cool idea, though I have some concerns that it sets the bar too high for something being a legit circuit. The part that feels most conceptually elegant to me is the idea that an interpretability hypothesis allows certain inputs to be equivalent for getting a certain answer (and the null hypothesis says that no inputs are equivalent), and then the recursive algorithm to zoom in and ask which inputs should be equivalent on a particular component.
I’m excited to see how this plays out at REMIX, in particular how much causal scrubbing can be turned into an exploratory tool to find circuits rather than just to verify them (and also how often well-meaning people can find false positives).
This sequence is pretty long, so if it helps people, here’s a summary of causal scrubbing I wrote for a mechanistic interpretability glossary that I’m writing (please let me know if anything in here is inaccurate)
Redwood Research have suggested that the right way to think about circuits is actually to think of the model as a computational graph. In a transformer, nodes are components of the model, ie attention heads and neurons (in MLP layers), and edges between nodes are the part of the input to the later node that comes from the output of the earlier node. Within this framework, a circuit is a computational subgraph—a subset of nodes and a subset of the edges between them that is sufficient for doing the relevant computation.
The key facts about transformers that make this framework work are that the output of each layer is the sum of the outputs of its components, and that the input to each layer (the residual stream) is the sum of the outputs of every previous layer, and thus the sum of the outputs of every previous component.
Note: This means that there is an edge into a component from every component in earlier layers
And because the input to each component is a sum of the outputs of earlier components, we can often cleanly consider subsets of nodes and edges—the sum is linear, so it’s easy to see the effect of adding and removing terms.
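To make the “everything is a sum” structure concrete, here’s a minimal toy sketch (made-up stand-ins for heads and MLPs; real attention mixes information across positions, and I’m ignoring layer norm entirely):

```python
import torch

d_model, seq_len, vocab = 16, 8, 50
embed = torch.nn.Embedding(vocab, d_model)
unembed = torch.nn.Linear(d_model, vocab)
# Stand-ins for the model's components (attention heads / MLP neurons), in
# layer order. Real heads mix information across positions; these don't.
components = [torch.nn.Linear(d_model, d_model) for _ in range(4)]

def forward(tokens):
    # The residual stream starts as the embedding, and every component reads
    # the current stream and adds its own output back into it.
    resid = embed(tokens)                 # [batch, seq, d_model]
    for component in components:
        resid = resid + component(resid)
    return unembed(resid)                 # logits

logits = forward(torch.randint(0, vocab, (1, seq_len)))
# Because the stream is a plain sum, the input to any component is the
# embedding plus the sum of all earlier components' outputs, so there is an
# implicit edge from every earlier component, and deleting an edge just means
# dropping (or replacing) one term of that sum.
```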
The differences from the features framing are somewhat subtle:
In the features framing, we don’t necessarily assume that features are aligned with circuit components (eg, they could be arbitrary directions in neuron space), while in the subgraph framing we focus on components and don’t need to show that the components correspond to features
It’s less obvious how to think about an attention head as “representing a feature”—in some intuitive sense heads are “larger” than neurons—eg their output space lies in a rank d_head subspace, rather than just being a direction. The subgraph framing side-steps this.
Causal scrubbing: An algorithm being developed by Redwood Research that tries to create an automated metric for deciding whether a computational subgraph corresponds to a circuit.
(The following is my attempt at a summary—if you get confused, go check out their 100 page doc…)
The exact algorithm is pretty involved and convoluted, but the key idea is to think of an interpretability hypothesis as saying which parts of a model don’t matter for a computation.
The null hypothesis is that everything matters (ie, the state of knowing nothing about a model).
Let’s take the running example of an induction circuit, which predicts repeated subsequences. We take a sequence … A B … A (A, B arbitrary tokens), and the model should output B as the next token. Our hypothesis is that this is done by a previous token head, which notices that A1 comes immediately before B, and then an induction head, which looks from the destination token A2 to source tokens whose previous token is A (ie B), and predicts that the value of whatever token it’s looking at (ie B) will come next.
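To be concrete about what the hypothesised algorithm computes (this is just the behavioural claim, written as toy code, nothing to do with the model’s actual weights):

```python
def induction_prediction(tokens):
    """Toy version of the hypothesised algorithm: if the current token A has
    occurred before, predict the token that followed its earlier occurrence."""
    current = tokens[-1]                      # the second A
    for i in range(len(tokens) - 2, -1, -1):  # scan earlier positions
        if tokens[i] == current:              # found the first A ...
            return tokens[i + 1]              # ... predict the B that followed it
    return None                               # no repeat: no prediction

assert induction_prediction(["x", "A", "B", "y", "A"]) == "B"
```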
If a part of a model doesn’t matter, we should be able to change it without changing the model output. Their favoured tool for doing this is a random ablation, ie replacing the output of that model component with its output on a different, randomly chosen input. (See later for motivation).
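As a hedged sketch of what that can look like mechanically, assuming a PyTorch model where the component is an nn.Module we can hook (illustrative names only, not Redwood’s actual code):

```python
import torch

def resample_ablate(model, component, orig_input, random_input):
    """Run `model` on `orig_input`, but replace `component`'s output with
    its output on a different, randomly chosen input (`random_input`)."""
    cached = {}

    def cache_hook(module, inputs, output):
        cached["out"] = output.detach()      # just record, don't modify

    def patch_hook(module, inputs, output):
        return cached["out"]                 # returning a value overrides the output

    # 1) Record the component's output on the random input.
    handle = component.register_forward_hook(cache_hook)
    with torch.no_grad():
        model(random_input)
    handle.remove()

    # 2) Run on the original input with the recorded output patched in.
    handle = component.register_forward_hook(patch_hook)
    with torch.no_grad():
        scrubbed_logits = model(orig_input)
    handle.remove()
    return scrubbed_logits
```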
The next step is that we can be specific about which parts of the input matter for each relevant component.
So, eg, we should be able to replace the output of the previous token head with its output on any sequence with an A in that position, if we think that that’s all it depends on. And this sequence can be different from the input sequence that the induction head sees, so long as the first A token agrees.
There are various ways to make this even more specific that they discuss, eg separately editing the key, value and query inputs to a head.
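One way to picture this, as a rough sketch: the hypothesis gives each component an equivalence relation over inputs, and scrubbing that component means resampling its input from within that equivalence class (everything here, including the same_first_token predicate, is a made-up illustration):

```python
import random

def scrub_input(orig_seq, dataset, agrees):
    """Sample a replacement input for a component: any other sequence in the
    dataset that the hypothesis says should be equivalent to `orig_seq`
    (i.e. agrees on the features the component supposedly depends on)."""
    candidates = [seq for seq in dataset if agrees(seq, orig_seq) and seq != orig_seq]
    return random.choice(candidates) if candidates else orig_seq

# Hypothesised equivalence for the previous token head in the induction
# example: all that matters is which token sits in the first A position.
first_a_pos = 1
same_first_token = lambda s1, s2: s1[first_a_pos] == s2[first_a_pos]

dataset = [["x", "A", "B", "y", "A"], ["q", "A", "C", "z", "A"], ["q", "D", "C", "z", "D"]]
print(scrub_input(dataset[0], dataset, same_first_token))  # -> ['q', 'A', 'C', 'z', 'A']
```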
The final step is to take a metric for circuit quality—they use the expected loss recovered, ie “what fraction of the expected loss on the subproblem we’re studying does our scrubbed circuit recover, compared to the original model with no edits”
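Roughly, and glossing over their exact convention (so treat this as a hedged sketch rather than their definition), this normalises the scrubbed model’s loss between the unablated model and some fully-scrubbed baseline:

```python
def fraction_of_loss_recovered(loss_scrubbed, loss_model, loss_baseline):
    """Fraction of performance the scrubbed model retains, normalised so the
    unablated model scores 1 and the fully-scrubbed baseline scores 0.
    (One common convention; see the causal scrubbing writeup for theirs.)"""
    return (loss_baseline - loss_scrubbed) / (loss_baseline - loss_model)

# eg original model loss 2.0, scrub-everything baseline 5.0, and our scrubbed
# circuit gets 2.3: it recovers (5.0 - 2.3) / (5.0 - 2.0) ~= 0.9 of the loss.
print(fraction_of_loss_recovered(2.3, 2.0, 5.0))  # ~0.9
```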
> in particular how much causal scrubbing can be turned into an exploratory tool to find circuits rather than just to verify them
I’d like to flag that this has been pretty easy to do—for instance, this process can look like resample ablating different nodes of the computational graph (eg each attention head/MLP), finding the nodes that, when ablated, most impact the model’s performance and are hence important, and then recursively searching for nodes that are relevant to the current set of important nodes by ablating nodes upstream of each important node.
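As a rough sketch of the kind of loop I mean, with made-up stand-ins for the graph and for the “performance drop when resample-ablated” measure:

```python
def find_circuit(start_nodes, upstream_of, importance, threshold=0.1):
    """Greedy, recursive circuit search: keep any node whose resample-ablation
    hurts performance by more than `threshold`, then repeat on the nodes
    upstream of everything kept so far. `importance(node)` is assumed to be
    the performance drop when `node` is resample-ablated; `upstream_of(node)`
    lists its parents in the computational graph. (Illustrative stand-ins.)"""
    circuit, frontier = set(), set(start_nodes)
    while frontier:
        important = {n for n in frontier if importance(n) > threshold}
        circuit |= important
        # Recurse: only look upstream of nodes we already care about.
        frontier = {up for n in important for up in upstream_of(n)} - circuit
    return circuit

# Toy usage: head L1.H0 feeds L2.H3, which feeds the output; L2.H7 is irrelevant.
parents = {"out": ["L2.H3", "L2.H7"], "L2.H3": ["L1.H0"], "L2.H7": [], "L1.H0": []}
drop = {"out": 1.0, "L2.H3": 0.6, "L2.H7": 0.02, "L1.H0": 0.4}
print(find_circuit(["out"], lambda n: parents[n], lambda n: drop[n]))
# -> {'out', 'L2.H3', 'L1.H0'} (set order may vary)
```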
Exciting! I look forward to the first “interesting circuit entirely derived by causal scrubbing” paper
Nice summary! One small nitpick:
> In the features framing, we don’t necessarily assume that features are aligned with circuit components (eg, they could be arbitrary directions in neuron space), while in the subgraph framing we focus on components and don’t need to show that the components correspond to features
This feels slightly misleading. In practice, we often do claim that sub-components correspond to features. We can “rewrite” our model into an equivalent form that better reflects the computation it’s performing. For example, if we claim that a certain direction in an MLP’s output is important, we could rewrite the single MLP node as the sum of two terms: the MLP output projected onto that direction, plus a residual term. Then, we could make claims about the direction we pointed out and also claim that the residual term is unimportant.
The important point is that we are allowed to rewrite our model however we want as long as the rewrite is equivalent.
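Concretely, a minimal sketch of that rewrite (made-up names; the point is just that the two new nodes sum back to the original MLP output exactly, so the rewritten graph is equivalent):

```python
import torch

def rewrite_mlp_output(mlp_out, direction):
    """Rewrite one node (the MLP's output) as two nodes that sum to it:
    the projection onto `direction`, and a residual term. The rewrite is
    equivalent because along_direction + residual == mlp_out exactly."""
    d = direction / direction.norm()
    along_direction = (mlp_out @ d)[..., None] * d   # claimed-important part
    residual = mlp_out - along_direction             # claimed-unimportant part
    return along_direction, residual

mlp_out = torch.randn(4, 16)           # [batch, d_model]
direction = torch.randn(16)
a, r = rewrite_mlp_output(mlp_out, direction)
assert torch.allclose(a + r, mlp_out, atol=1e-6)     # equivalent rewrite
```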
Thanks for the clarification! If I’m understanding correctly, you’re saying that the important part is decomposing activations (linearly?) and that there’s nothing really baked in about what a component can and cannot be. You normally focus on components, but this can also fully encompass the features as directions frame, by just saying that “the activation component in that direction” is a feature?
Yes! The important part is decomposing activations (not necessarily linearly). I can rewrite my MLP as:
MLP(x) = f(x) + (MLP(x) - f(x))
and then claim that the MLP(x) - f(x) term is unimportant. There is an example of this in the parentheses balancer example.
Thanks! Can you give a non-linear decomposition example?
I would typically call
MLP(x) = f(x) + (MLP(x) - f(x))
a non-linear decomposition, since f(x) is an arbitrary function.
Regardless, any decomposition into a computational graph (that we can prove is extensionally equal) is fine. For instance, if it’s the case that MLP(x) = combine(h(x), g(x)) (via extensional equality), then I can scrub h(x) and g(x) individually.
One example of this could be a product, eg, suppose that MLP(x) = h(x) * g(x) (maybe like SwiGLU or something).
We haven’t had to use a non-linear decomposition in our interp work so far at Redwood. Just wanted to point out that it’s possible. I’m not sure when you would want to use one, but I haven’t thought about it that much.
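For concreteness, here’s a hedged sketch of what that product decomposition could look like (a made-up SwiGLU-style MLP, purely illustrative), where h and g become separate nodes that can be scrubbed on different inputs:

```python
import torch

# Made-up SwiGLU-style MLP. The point is just that MLP(x) == combine(h(x), g(x))
# holds exactly, so h and g are legitimate nodes that can be scrubbed separately.
d_model, d_hidden = 16, 64
W_gate = torch.randn(d_model, d_hidden)
W_up = torch.randn(d_model, d_hidden)
W_down = torch.randn(d_hidden, d_model)

def h(x):                       # "gate" branch
    return torch.nn.functional.silu(x @ W_gate)

def g(x):                       # "value" branch
    return x @ W_up

def combine(h_out, g_out):      # elementwise product, then project down
    return (h_out * g_out) @ W_down

def mlp(x):                     # the original, un-decomposed node
    return (torch.nn.functional.silu(x @ W_gate) * (x @ W_up)) @ W_down

x = torch.randn(4, d_model)
assert torch.allclose(mlp(x), combine(h(x), g(x)))   # extensional equality
# Causal scrubbing can now resample-ablate h(x) and g(x) individually:
x_alt = torch.randn(4, d_model)
scrubbed = combine(h(x_alt), g(x))    # h's input scrubbed, g's kept
```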