How-to Transformer Mechanistic Interpretability—in 50 lines of code or less!
Produced as part of the SERI ML Alignment Theory Scholars Program—Winter 2022 Cohort.
What if I told you that in just one weekend you can get up to speed doing practical Mechanistic Interpretability research on Transformers? Surprised? Then this is your tutorial!
I’ll give you a view into how I research Transformer circuits in practice, show you the tools you need, and explain my thought process along the way. I focus on the practical side of getting started with interventions; for more background see point 2 below.
Prerequisites:
Understanding the Transformer architecture: Know what the residual stream is, how attention layers and MLPs work, and how logits & predictions work. For future sections familiarity with multi-head attention is useful. Here’s a link to Neel’s glossary which provides excellent explanations for most terms I might use!
If you’re not familiar with Transformers you can check out Step 2 (6) of Neel’s guide or any of the other explanations online; I recommend Jay Alammar’s The Illustrated Transformer and/or Milan Straka’s lecture series.
Some overview of Mechanistic Interpretability is helpful: See e.g. any of Neel’s talks, or look at the results in the IOI paper / walkthrough.
Basic Python: Familiarity with arrays (indexing as in NumPy or PyTorch) is useful, but explicitly no PyTorch knowledge is required!
No hardware required; a free Google Colab account works fine for this. Here’s a notebook with all the code from this tutorial! PS: Here’s a little web page where you can run some of these methods online! No trivial inconveniences!
Step 0: Setup
Open a notebook (e.g. Colab) and install Neel Nanda’s TransformerLens (formerly known as EasyTransformer).
!pip install transformer_lens
Step 1: Getting a model to play with
from transformer_lens import HookedTransformer, utils
model = HookedTransformer.from_pretrained("gpt2-small")
That’s it, now you’ve got a GPT2 model to play with! TransformerLens supports most relevant open source transformers. Here’s how to run the language model:
logits = model("Famous computer scientist Alan")
# The logit dimensions are: [batch, position, vocab]
next_token_logits = logits[0, -1]
next_token_prediction = next_token_logits.argmax()
next_word_prediction = model.tokenizer.decode(next_token_prediction)
print(next_word_prediction)
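If you want to see more than the single top prediction, you can also peek at the runners-up. A quick sketch using the same logits (nothing here beyond standard PyTorch):
# Look at the model's top 5 candidates for the next token
top5 = next_token_logits.topk(5)
for logit, token_id in zip(top5.values, top5.indices):
    print(f"{model.tokenizer.decode(token_id)!r}: {logit.item():.2f}")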
Let’s have a look at the internal activations: TransformerLens can give you a dictionary with almost all internal activations you’ll ever care about (referred to as the “cache”):
logits, cache = model.run_with_cache("Famous computer scientist Alan")
for key, value in cache.items():
    print(key, value.shape)
Here you will find things like the attention pattern blocks.0.attn.hook_pattern, the residual stream before and after each layer (e.g. blocks.1.hook_resid_pre), and more!
You can also access all the weights & parameters of the model in model.named_parameters(). Here you will find the weight matrices & biases of every MLP and attention layer, as well as the embedding & unembedding matrices. I won’t focus on these in this guide but they’re great to look at! (Exercise: What can the unembedding biases unembed.b_U tell you about common tokens?)
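(If you want a head start on that exercise, here is a minimal sketch: look up which tokens get the largest unembedding bias; they tend to be very common ones.)
# Tokens with the 10 largest unembedding biases
top_bias = model.unembed.b_U.topk(10)
print([model.tokenizer.decode(int(i)) for i in top_bias.indices])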
Step 2: Let’s start analyzing a behavior!
Let’s go and find some induction heads! I’ll make up an example: Her name was Alex Hart. When Alex, with likely completion Hart. TransformerLens has a little tool to plot a tokenized prompt, model predictions, and associated logits: [1]
utils.test_prompt("Her name was Alex Hart. Tomorrow at lunch time Alex",
                  " Hart", model)
I find it useful to spend a few minutes thinking about which information is needed to solve the task. The model needs to:
Realize the last token, Alex, is a repetition of a previous occurrence
Copy the last name from after the previous Alex occurrence to the last token as its prediction
Method 1: Residual stream patching
The number 1 thing I try when I want to reverse engineer a new behavior is to find where in the network the information is “traveling”.
In transformers, the model keeps track of all information in the residual stream. Attention heads & MLPs read from the residual stream, perform some computation or information moving, and write their outputs back into the residual stream. I think of this stream as having a “lane” for each token position. Over the course of the model (i.e. from the start at layer 0 to the end at the final layer) the information has to travel from “Alex Hart” to the rightmost lane (the final token, which predicts the next token).
We can view this movement by comparing the residual stream between subtly different prompts that contain slightly different information. [2] Based on the three pieces of information mentioned above (last name, and 1st and 2nd occurrence of first name) we can come up with 3 variations (“corruptions”) of the baseline (“clean”) prompt: [3]
[Baseline] Her name was Alex Hart. When Alex
Her name was Alex Carroll. When Alex
Her name was Sarah Hart. When Alex
Her name was Alex Hart. When Sarah
Okay, so let’s say we want to figure out where the transformer stores information related to the last name Hart, and where this information travels. We can compare prompt 1 (clean) and prompt 2 (corrupt): any difference in the residual stream between these two examples must be caused by the last name!
My introductory explanation
I promise, it’s really reasonably easy to do this! Most things in TransformerLens are based on “hooks”: Hooks allow us to open up almost any part of a transformer and manipulate the model activations in any way we want!
So let’s compare the residual stream between these two runs. Staring at 768-dimensional vectors isn’t everyone’s cup of tea, but there’s a simpler way:
We can “patch in” the residual stream from the corrupted run (where the last name is “Carroll”) into a run with the original names at the position in question and see if the result changes (or vice versa[4]).
Start by saving the activations of the corrupted run into a cache (basically a dictionary):
_, corrupt_cache = model.run_with_cache("Her name was Alex Carroll. Tomorrow at lunch time Alex")
Then we write a patch that overwrites the residual stream with the one we saved. Let’s say we patch the residual stream at position 5 – this is the token Hart – and between layers 6 and 7 (“post 6”):
def patch_residual_stream(activations, hook, layer="blocks.6.hook_resid_post", pos=5):
    # The residual stream dimensions are [batch, position, d_embed]
    activations[:, pos, :] = corrupt_cache[layer][:, pos, :]
    return activations
# add_hook takes 2 args: Where to insert the patch,
# and the function providing the updated activations
model.add_hook("blocks.6.hook_resid_post", patch_residual_stream)
This is how you use hooks: You insert a function at the point specified by the 1st argument of add_hook, in this case just after layer (aka block) 6 (zero-indexed): blocks.6.hook_resid_post. [5] Now in that function you can do anything with the internal activations at this layer, it’s just a big array!
What we decide to do here is to overwrite the slice corresponding to position 5 with the alternative residual stream from the cache. The function returns [6] the updated values and the model will just continue its run with this modified residual stream.
utils.test_prompt("Her name was Alex Hart. Tomorrow at lunch time Alex", " Hart", model)
If we test the model now, we will see that the preferred answer has changed to “Carroll”. So the residual stream at position 5 after layer 6 indeed contained crucial last name information!
Keep in mind the relevant token positions for later: pos4 is the 1st first name, pos5 is the last name, and the final token at pos11 is the 2nd first name.
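(You can double-check these positions yourself; note that TransformerLens prepends a beginning-of-sequence token at position 0, which is why the 1st first name sits at position 4.)
for position, token in enumerate(model.to_str_tokens("Her name was Alex Hart. Tomorrow at lunch time Alex")):
    print(position, repr(token))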
PS: Any time you use model.add_hook() you always need to call model.reset_hooks() to remove the hook(s) again.
model.reset_hooks()
Now, this is a bit of a pain to do manually. So here is how I would do this in my day-to-day research.
How I do this in practice
This is the same as above, with three main improvements you probably have already thought of:
We run a loop over all layers and positions, and show the result as an image [for any plotting code see my notebook].
We directly use the logit difference (“logit diff” for short) [7] between the different answers to measure the change, rather than manually looking at utils.test_prompt.
We use the model.run_with_hooks function to skip manually adding & resetting hooks.
import torch
from functools import partial
# Clean and corrupt prompts in variables
clean_prompt = "Her name was Alex Hart. Tomorrow at lunch time Alex"
corrupt_prompt = "Her name was Alex Carroll. Tomorrow at lunch time Alex"
# Get the list of tokens the model will deal with
clean_tokens = model.to_str_tokens(clean_prompt)
# Cache the corrupt run's activations so we can patch them in later
_, corrupt_cache = model.run_with_cache(corrupt_prompt)
# List of hook points to iterate over: we want to patch before the first layer,
# and after every layer (13 patching locations in total for GPT-2 small's 12 layers).
layers = ["blocks.0.hook_resid_pre", *[f"blocks.{i}.hook_resid_post" for i in range(model.cfg.n_layers)]]
n_layers = len(layers)
n_pos = len(clean_tokens)
# Indices of the right and wrong predictions to automatically measure performance
clean_answer_index = model.tokenizer.encode(" Hart")[0]
corrupt_answer_index = model.tokenizer.encode(" Carroll")[0]
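# (Optional check that both last names tokenize to a single token each, so that
#  token positions line up between the clean and corrupt prompts; see footnote 3)
print(model.to_str_tokens(" Hart", prepend_bos=False),
      model.to_str_tokens(" Carroll", prepend_bos=False))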
# Test the effect of patching at any layer and any position
patching_effect = torch.zeros(n_layers, n_pos)
for l, layer in enumerate(layers):
    for pos in range(n_pos):
        fwd_hooks = [(layer, partial(patch_residual_stream, layer=layer, pos=pos))]
        prediction_logits = model.run_with_hooks(clean_prompt,
                                                 fwd_hooks=fwd_hooks)[0, -1]
        patching_effect[l, pos] = prediction_logits[clean_answer_index] \
                                  - prediction_logits[corrupt_answer_index]
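The notebook linked above has the plotting code used for these figures; if you just want a quick inline look, here is a rough matplotlib sketch of my own:
import matplotlib.pyplot as plt

plt.figure(figsize=(8, 6))
plt.imshow(patching_effect.detach(), cmap="RdBu", aspect="auto")
plt.xticks(range(n_pos), clean_tokens, rotation=90)
plt.yticks(range(n_layers), layers, fontsize=7)
plt.colorbar(label="logit diff (clean answer - corrupt answer)")
plt.title("Residual stream patching (corrupt last name)")
plt.show()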
So we see that the last name information (Hart vs Carroll) jumps from position 5 to 11, mostly around Layer 10 but also before & after that layer. Now we repeat the same exercise for the 1st and 2nd occurrences of the first name:
corrupt_prompt = "Her name was Sarah Hart. Tomorrow at lunch time Alex"  # corrupting the 1st first name
corrupt_prompt = "Her name was Alex Hart. Tomorrow at lunch time Sarah"  # corrupting the 2nd first name
This plot looks interesting! So apparently the literal first name affects not only position 4 (the actual first name token): around layer 4 the information also travels to position 5! So the model stores the information about the first name at the position of the last name token! Around layers 5 and 6 the information then seems to travel to the final token. After layer 7, first name patches anywhere left of position 11 no longer matter! By this point the model has already figured out what it needs to do in the end (copy the last name) and doesn’t care about the rest!
A good intuition to keep in mind: Information in this image can only travel rightward (since the model may only access earlier tokens) and downward (since all outputs accumulate in the residual stream layer by layer). This is useful when guessing how a transformer does a thing!
You might have noticed the patching color switched to white / light blue rather than red: This is because patching in a mismatched first name doesn’t flip the prediction but just leaves the model without a way to deduce the last name, so it doesn’t prefer either answer.
Okay, this one looks trivial, and in hindsight it is intuitive: We said information can only travel rightward, so patching at pos11 cannot affect anything other than the rightmost column.
A second intuition is that patching the residual stream at all positions is just like running the corrupt prompt: there is no way for the model to distinguish the two, since the residual stream is the only source of information. So in this case patching pos11 at any layer just gives you the result of the corrupted run, and thus always the same logit diff.
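You can confirm this with a quick sketch: the logit diff of the plain corrupt run should match the values in the pos11 column.
# Logit diff when simply running the corrupt prompt, no patching at all
corrupt_logits = model(corrupt_prompt)[0, -1]
print(corrupt_logits[clean_answer_index] - corrupt_logits[corrupt_answer_index])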
In the first two plots we saw that the behavior is kind of smeared over multiple layers; this is typical if your model is much larger than needed for a task. I guess that, in addition to the bare-minimum algorithm, the model adds an ensemble of slight variations of this algorithm to do better in various situations. Or it is because of training with dropout. Anyway, to make our life simpler let’s continue with a smaller 2-layer model (“gelu-2l”). It can just about do this task, and here are the equivalent patching plots:
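If you want to follow along, load the smaller model the same way as before, then re-run the setup and patching code above for it:
# Switch to the 2-layer model; the rest of the code stays the same
model = HookedTransformer.from_pretrained("gelu-2l")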
The situation is much clearer in the two-layer model (which is still really good at this task!): The first name information (upper plot) moves to pos5 in layer 0, and then moves to pos11 in layer 1. The last name information (lower plot) moves from pos5 to pos11 in layer 1.
Interlude: Overview of attention head hooks
Here’s a brief overview of the internal activations we can access in the attention module with TransformerLens:
blocks.0.attn.hook_q (as well as hook_k and hook_v): The Q, K, and V vectors of each attention head in a block / layer (block 0 here). Dimensions are [batch, position, head, d_head]. [8] So e.g. blocks.0.attn.hook_k[0, 5, 3, :] gives us the K vector of the 4th attention head (index 3, 0-indexed) for the first item of a batch (if you are not using batches – like in this tutorial – the batch index is just always 0), for the 6th token (index 5).
For example, let’s calculate the score attending from (query) Hart to (key) Alex for head 1 in layer 0:
import math
k = cache["blocks.0.attn.hook_k"][0, 4, 1, :]  # key vector at pos4 (" Alex"), head 1
q = cache["blocks.0.attn.hook_q"][0, 5, 1, :]  # query vector at pos5 (" Hart"), head 1
k @ q / math.sqrt(model.cfg.d_head)
blocks.0.attn.hook_attn_scores[batch, head, dst, src] gives you the attention scores for a head. The last two dimensions are the query position (dst) and key position (src). We attend from dst to src. [9] This is the same number we calculated above: cache["blocks.0.attn.hook_attn_scores"][0, 1, 5, 4]
blocks.0.attn.hook_pattern[batch, head, dst, src] is the final attention pattern, which is just the softmax of the previous values (applied to the last dimension).
blocks.0.attn.hook_z[batch, position, head, d_head] is the Z vector, i.e. the sum of V vectors of all previous tokens weighted by the attention pattern. We can also reproduce this from the previous values:
weighting = cache["blocks.0.attn.hook_pattern"][0, 1, 5, :]  # attention from pos5, head 1
v = cache["blocks.0.attn.hook_v"][0, :, 1, :]                # V vectors of head 1 at all positions
torch.allclose(weighting @ v, cache["blocks.0.attn.hook_z"][0, 5, 1, :], atol=1e-4)
blocks.0.hook_attn_out[batch, position, d_embed] contains the attention layer output; this is every head’s output matrix W_O applied to every head’s Z vector, and then summed. You might notice we’re missing the intermediate step which would give you the (very useful) residual stream output of every individual head. Usually this is skipped for performance reasons, but we can enable it with model.cfg.use_attn_result = True to get blocks.0.attn.hook_result[batch, position, head, d_embed], which gives you the output that each head writes into the residual stream at each position! This is what we need for my 2nd favorite patching method:
Method 2: Attention head patching
In Method 1: Residual stream patching we saw that the model is doing something in layer 1 that makes the information move to position 11. There might be quite a few heads per layer, so can we look at specifically which of the heads is doing this, before we look at particular attention patterns? Yes, of course we can!
We compare the residual stream outputs of every head using blocks.0.attn.hook_result [10], and observe which of these have an effect on the model’s performance, a pretty similar setup to above (showing the new parts here, full code in notebook):
model.cfg.use_attn_result = True
# Layer & head counts of the current (2-layer) model
n_layers = model.cfg.n_layers
n_heads = model.cfg.n_heads

def patch_head_result(activations, hook, layer=None, head=None):
    activations[:, :, head, :] = corrupt_cache[hook.name][:, :, head, :]
    return activations

patching_effect = torch.zeros(n_layers, n_heads)
for layer in range(n_layers):
    for head in range(n_heads):
        fwd_hooks = [(f"blocks.{layer}.attn.hook_result",
                      partial(patch_head_result, layer=layer, head=head))]
        prediction_logits = model.run_with_hooks(clean_prompt,
                                                 fwd_hooks=fwd_hooks)[0, -1]
        patching_effect[layer, head] = prediction_logits[clean_answer_index] \
                                       - prediction_logits[corrupt_answer_index]
So it appears that head 1.6 (it is common to denote attention heads in the “layer.head” format) is very important, and to a lesser extent heads 0.1, 0.2, and 1.1. You could go and look at these heads’ attention patterns, but personally I prefer to add an extra step first that will help us understand what the heads do:
Bonus: Attention head patching by position
In patch_head_result we patch each attention head’s output at all positions – but we could just look at the result for each position separately to see where each head is relevant! It’s really just the same as above with an extra for loop:
def patch_head_result(activations, hook, layer=None, head=None, pos=None):
    activations[:, pos, head, :] = corrupt_cache[hook.name][:, pos, head, :]
    return activations

patching_effect = torch.zeros(n_layers*n_heads, n_pos)
for layer in range(n_layers):
    for head in range(n_heads):
        for pos in range(n_pos):
            fwd_hooks = [(
                f"blocks.{layer}.attn.hook_result",
                partial(patch_head_result, layer=layer, head=head, pos=pos)
            )]
            prediction_logits = model.run_with_hooks(clean_prompt,
                                                     fwd_hooks=fwd_hooks)[0, -1]
            patching_effect[n_heads*layer+head, pos] = \
                prediction_logits[clean_answer_index] \
                - prediction_logits[corrupt_answer_index]
This is a 3 dimensional data set so let me combine layers and heads into a layer.head y-axis, and have token positions on the x-axis again (plotting code in notebook):
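Again, if you want a quick inline version, a rough sketch (the combined layer.head axis is just a small list comprehension):
import matplotlib.pyplot as plt

head_labels = [f"{layer}.{head}" for layer in range(n_layers) for head in range(n_heads)]
plt.figure(figsize=(8, 6))
plt.imshow(patching_effect.detach(), cmap="RdBu", aspect="auto")
plt.xticks(range(n_pos), clean_tokens, rotation=90)
plt.yticks(range(n_layers * n_heads), head_labels)
plt.colorbar(label="logit diff (clean answer - corrupt answer)")
plt.show()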
Look, isn’t this awesome? We can pinpoint exactly which part of the model is doing what, and where! So heads 1.1 and mostly 1.6 are responsible for writing the last name into pos11 (i.e. the prediction position), head 0.2 writes first name information into pos5 (the last name position), and heads 0.1 and 1.6 write something into pos11 that breaks if we change either occurrence of the first name.
Side-note: A clearer visualization is to plot the difference between the clean and patched logit diff, the logit difference difference if you will. Taking the same plot as above but subtracting the baseline logit diff gives this plot:
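In code this is just one extra baseline computation (a sketch; patching_effect_diff is my own name for the result):
# Baseline logit diff of the clean, un-patched run
clean_logits = model(clean_prompt)[0, -1]
clean_logit_diff = clean_logits[clean_answer_index] - clean_logits[corrupt_answer_index]
# Subtract it from every patched value to get the "logit difference difference"
patching_effect_diff = patching_effect - clean_logit_diff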
Interlude: Guessing the answers
I find it instructive to try and think like a transformer: how could you possibly solve this task? What information do you need, and what’s the shortest path to the answer? So we have a sentence Alex (pos4) Hart (pos5) … Alex (pos11), and the output at pos11 should predict the next token.
It’s fundamentally impossible to solve this in a single layer. All you have at the last position is the first name Alex, and there is no way to get to (read: attend to) the last name without additional information (which would need another layer). Every step of serial computation basically requires a layer.
So let’s think about doing this in 2 layers. There are two things that need to happen:
You need to somehow connect Alex and Hart. How do we know that Hart is Alex’s last name? Because they’re next to each other. So one step must use this fact and create that connection.
Thinking like a transformer, information can only travel “rightwards”, so we need to write “Alex vibes” into the Hart token position (not overwrite, just write into some subspace of the residual stream at this position). This needs to happen here (layer 0, position Hart), and you could write either the token or the positional embedding (let me call this “Alex vibes” or “pos4 vibes”) here. We tend to call such a head a “Previous Token Head”.
Next you need to connect the 2nd Alex occurrence to this. The key bit of information that we need to use is the token being the same! I can think of two ways to do this:
The first option works simply in layer 1: Have a head that looks for “current-token-vibes” and copies the value from there. I.e. at the last position the head would look for, and find, “Alex vibes” in the Hart token position, and just copy the Hart token. Done!
Thinking of a second option, is there any way a head knows it can look for “pos4 vibes” to find the answer? Yes! You can have a layer 0 head that looks for duplicates of the current token, and copies their position! So such a “Duplicate Token Head” in layer 0 at pos11 notices Alex occurred previously at pos4 and writes “hey this is like pos4” into the residual stream at pos11.
Then a layer 1 head can read “hey this is like pos4” and look for “pos4 vibes”, finding those at the Hart token position (as written by the Previous Token Head), and copy the Hart token. Done again!
PS: When thinking like transformers I tend to keep a picture in my mind of what information is stored at which position. Callum McDougall made an amazing graphic that matches pretty well with my intuition, check it out!
Let’s make some guesses which heads these are, based on the patching plots:
The Previous Token Head must write into pos5, and its output changes only if we patch from a prompt where the 1st first name is different (middle picture) → Only 0.2 matches this!
The Duplicate Token Head must write into pos11, and its output would change if the 1st and 2nd first name are not identical (middle and right picture) → 0.1 matches this!
The final head (in both hypotheses) would break (no useful output, i.e. approx. 0 logit diff, white) when the first names don’t match, and actively suggest the wrong name (negative logit diff, red) if we patch from the different-last-name run → 1.6 matches this well.
Method 3: Looking at attention patterns
So to actually check what the heads do, or at least where they attend to, we can look at the attention patterns cache["blocks.0.attn.hook_pattern"] for every head. There’s actually a neat tool that gives us an interactive view of the patterns for every head and every combination of positions, CircuitsVis. It’s really easy to use:
!pip install circuitsvis
import circuitsvis as cv
prompt = "Her name was Alex Hart. Tomorrow at lunch time Alex"
_, cache = model.run_with_cache(prompt)
cv.attention.attention_patterns(tokens=model.to_str_tokens(prompt),
                                attention=cache['blocks.0.attn.hook_pattern'][0])
Our layer 0 heads (0.1 and 0.2) look, as expected, like Duplicate Token Head and Previous Token Head:
For layer 1 (change the key to blocks.1) we see 1.6 attending to the last name position as expected! I can’t tell what 1.1 is doing.
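To reproduce that view, the only change is the hook name:
cv.attention.attention_patterns(tokens=model.to_str_tokens(prompt),
                                attention=cache['blocks.1.attn.hook_pattern'][0])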
Conclusion
I think we can be pretty certain that we have a simple induction circuit (sometimes called an induction “head”) consisting of the previous token head (PTH) 0.2 and the copying head 1.6. There is probably a second variant based on the duplicate token head (DTH) 0.1, although I would like to run patching experiments on the positional embeddings to confirm this.
Another experiment to check these hypotheses is computing composition scores between the different heads, and checking whether 0.2 K-composes with 1.6, and 0.1 Q-composes with 1.6. But this will have to wait until the next post.
PS: I keep a doc with common Transformer(Lens) gotchas I fell for in the past, mostly featuring the tokenizer, hooks, and the tokenizer.
What now?
Read more about Transformer Mech Int? See Neel’s overview here!
Get started by trying these methods on a different example, or pick one of Neel’s list of Concrete Open Problems (these are a bit harder)!
Sign-up for the next Alignment Jam Hackathon!
New to AI Alignment? Check out this intro course (101) (201)!
Feedback: I would love to hear any feedback on this post: how useful you found this tutorial, where you got stuck, suggested improvements, and anything else!
Changelog: 27.01.2023: I originally said I used an attention-only 2 layer model while the notebook was actually using a normal 2 layer model (with MLPs, “gelu-2l”); fixed the typo now. Also added the logit diff difference plot.
Footnotes
[1] Note the space in the answer token | Hart| — we transform a prompt into a sequence of tokens, and those don’t necessarily align with words. E.g. “ ethernet” splits into the tokens | ether| and |net|. Most simple words are single tokens, but do include a leading space (such as | Hart|). This makes sense as it allows the model to encode a 10-word sentence in ~10 tokens rather than 10 word tokens + 9 spaces, but it frequently annoys mechanistic interpretability researchers.
[2] There are also methods not based on patching, such as Direct Logit Attribution or Ablation. I won’t cover them (yet), and note that I rarely use (zero- or mean-) ablation methods, which can mess up your model and produce weird results.
[3] When doing this, try to make sure the replacement words are the same number of tokens, otherwise you have to carefully line everything up. Check with model.to_str_tokens (see below).
[4] The naming convention is tending towards “activation patching” for either direction, “resampling ablation” for overwriting activations in the clean run with corrupt-run activations, and “causal tracing” for overwriting activations in a corrupted run with clean-run activations.
[5] For a list of all possible things you can patch (hook points), look at the keys of the cache dict!
[6] Note that activations[:, pos, :] = … directly edits the activations and the return is not strictly necessary, but I tend to leave it there since it is necessary in cases like activations = ….
[7] Note that measuring the logit difference between Hart and Carroll is somewhat unprincipled since Carroll might not be involved at all; in practice we usually consider all symmetrical combinations. Other sensible and common choices are the probability or logprob of the right answer, or discrete measures such as top-k success rate if testing on a batch of examples.
[8] d_embed and d_head are the embedding dimensions used in the residual stream and the attention heads, respectively. The latter is significantly smaller, e.g. 64 compared to 768 dimensions (GPT2-small).
[9] There is an annoying ambiguity in naming: We tend to say the model attends from a later token to an earlier one, but it moves information to a later token from an earlier one. People frequently use the latter framing, calling the earlier token the source (src) of information moved to the later token = destination (dst). To be safe you can always refer to the later token (dst) as the query side and the earlier one as the key side. And if you notice your attention scores are −100000 then you probably mixed up the indices :)
[10] You could also use hook_attn_out here, but I want to use attn.hook_result later anyway.