We started out with the question: How does GPT-2 know when to use the word "an" over "a"? The choice depends on whether the word that comes after starts with a vowel or not, but GPT-2 can only output one word at a time.
We still don't have a full answer, but we did find a single MLP neuron in GPT-2 Large that is crucial for predicting the token " an". We also found that the weights of this neuron correspond to the embedding of the " an" token, which led us to find other neurons that predict a specific token.
Discovering the Neuron
Choosing the prompt
It was surprisingly hard to think of a prompt where GPT-2 would output " an" (the leading space is part of the token) as the top prediction. Eventually we gave up on GPT-2_small and switched to GPT-2_large. As we'll see later, even GPT-2_large systematically under-predicts the token " an". This may be because smaller language models lean on the higher frequency of " a" to make a best guess. The prompt we finally found that gave a high (64%) probability for " an" was:
“I climbed up the pear tree and picked a pear. I climbed up the apple tree and picked”
The first sentence was necessary to push the model towards an indefinite article; without it, the model would make other predictions such as "[picked] up".
Before we proceed, here's a quick overview of the transformer architecture: each attention block and MLP reads its input from the residual stream and adds its output back to it.
Logit Lens
Using the logit lens technique, we decoded the residual stream between each pair of layers into logits and plotted the difference between logit(" an") and logit(" a"). We found a big spike after Layer 31's MLP.
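A minimal sketch of the logit-lens computation, assuming the residual stream after each layer has already been cached. The unembedding vectors, token ids and residuals below are made up, and the LayerNorm here omits GPT-2's learned gain and bias; the point is only the shape of the calculation: normalize each intermediate residual, unembed it, and take the " an" minus " a" difference.

```python
import math


def layer_norm(x, eps=1e-5):
    # Simplified LayerNorm: centre and rescale only. GPT-2's real final
    # LayerNorm also has a learned gain and bias, which this toy leaves out.
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def logit_lens_diff(residuals, w_u, tok_an, tok_a):
    """Decode each cached residual vector as if it were the final one and
    return logit(tok_an) - logit(tok_a) for each.

    residuals: one residual vector (length d_model) per layer boundary
    w_u: mapping token id -> unembedding vector (length d_model)
    """
    diffs = []
    for resid in residuals:
        normed = layer_norm(resid)
        logit_an = sum(r * w for r, w in zip(normed, w_u[tok_an]))
        logit_a = sum(r * w for r, w in zip(normed, w_u[tok_a]))
        diffs.append(logit_an - logit_a)
    return diffs


# Hypothetical ids: 0 plays the role of " an", 1 plays " a".
W_U = {0: [1.0, 0.0, -1.0], 1: [-1.0, 0.0, 1.0]}
residuals = [[0.0, 1.0, 2.0], [2.0, 1.0, 0.0]]
diffs = logit_lens_diff(residuals, W_U, tok_an=0, tok_a=1)
print(diffs)  # first entry is negative (favours " a"), second is positive (favours " an")
```

Plotting `diffs` against layer index gives the kind of curve described above; a sudden jump between two entries marks the layer whose output moved the residual towards " an".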
Activation Patching by the Layer
Activation patching is a technique introduced by Meng et al. (2022) to analyze the significance of a single layer in a transformer. First, we saved the activation of each layer when running the original prompt through the model (the "clean activation").
We then ran a corrupted prompt through the model:
“I climbed up the pear tree and picked a pear. I climbed up the lemon tree and picked”
By replacing the word "apple" with "lemon", we induce the model to predict the token " a" instead of " an". With the model predicting " a" over " an", we can replace a layer's corrupted activation with its clean activation and see how much the model shifts towards the " an" token, which indicates that layer's significance for predicting " an". We repeat this process for every layer of the model.
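To make the procedure concrete, here is a minimal sketch of layer-wise activation patching on a hypothetical toy model: a 2-dimensional residual stream, three invented layers, and a fixed direction standing in for logit(" an") - logit(" a"). It is not the authors' code; it only shows the loop of caching the clean outputs, overwriting one layer's output during the corrupted run, letting later layers recompute, and scoring the result as relative logit-difference recovery.

```python
def run(x, layers, patch=None):
    """Run a toy residual-stream model.

    Each layer reads the current residual and returns a vector that is added
    back to it. If patch=(i, cached_output), layer i's output is overwritten
    with cached_output; later layers still recompute on the patched residual.
    Returns (final residual, list of each layer's output).
    """
    resid = list(x)
    outputs = []
    for i, layer in enumerate(layers):
        out = layer(resid)
        if patch is not None and patch[0] == i:
            out = list(patch[1])
        outputs.append(out)
        resid = [r + o for r, o in zip(resid, out)]
    return resid, outputs


def logit_diff(resid, direction=(1.0, -1.0)):
    # Toy stand-in for logit(" an") - logit(" a"): projection onto a fixed direction.
    return sum(r * d for r, d in zip(resid, direction))


def layerwise_patching(clean_x, corrupted_x, layers):
    """Patch each layer's clean output into the corrupted run, one layer at a
    time, and return the relative logit-difference recovery for each layer."""
    clean_resid, clean_outputs = run(clean_x, layers)
    corrupted_resid, _ = run(corrupted_x, layers)
    clean_ld = logit_diff(clean_resid)
    corrupted_ld = logit_diff(corrupted_resid)
    recoveries = []
    for i, cached in enumerate(clean_outputs):
        patched_resid, _ = run(corrupted_x, layers, patch=(i, cached))
        recoveries.append((logit_diff(patched_resid) - corrupted_ld) / (clean_ld - corrupted_ld))
    return recoveries


# Hypothetical layers on a 2-d residual: dim 0 ~ "next word starts with a vowel",
# dim 1 ~ "default to ' a'".
layers = [
    lambda r: [0.5 * r[0], 0.5 * r[1]],  # amplifies whatever the input says
    lambda r: [2.0 * r[0], 0.0],         # writes " an" evidence if the vowel feature is present
    lambda r: [0.0, 1.0],                # constant push towards " a"
]
clean_x = [1.0, 0.0]      # stands in for the "apple" prompt
corrupted_x = [0.0, 1.0]  # stands in for the "lemon" prompt
print(layerwise_patching(clean_x, corrupted_x, layers))  # ≈ [0.333, 0.5, 0.0]
```

In this toy, layer 0 gets partial credit even though it knows nothing about vowels, because it amplifies the input that layer 1 later reads: patching credits a layer with its downstream effects as well as its direct ones.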
We're mostly going to ignore attention for the rest of this post, but these results indicate that Layer 26 is where " picked" starts thinking a lot about " apple", which is obviously required to predict " an".
Note: the scale on these patching graphs is the relative logit difference recovery (i.e. "what proportion of logit(" an") - logit(" a") in the clean prompt did this patch recover?").
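Written as a formula (the same ratio the authors give in their reply in the comments), where a value of 1 means the patch fully restores the clean logit difference and 0 means it leaves the corrupted one unchanged:

$$
\text{recovery} = \frac{\Delta_{\text{patched}} - \Delta_{\text{corrupted}}}{\Delta_{\text{clean}} - \Delta_{\text{corrupted}}},
\qquad
\Delta = \operatorname{logit}(\text{" an"}) - \operatorname{logit}(\text{" a"})
$$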
The two MLP layers that stand out are Layer 0 and Layer 31. We already know that Layer 0's MLP is generally important for GPT-2 to function (although we're not sure why attention in Layer 0 is important).[1] The effect of Layer 31 is more interesting. Our results suggest that Layer 31's MLP plays a significant role in predicting the " an" token. (See this comment if you're confused about how this result fits with the logit lens above.)
Finding 1: We can discover predictive neurons by activation patching individual neurons
Activation patching has been used to investigate transformers layer by layer, but can we push this technique further and apply it to individual neurons? Since each MLP in a transformer only has one hidden layer, each neuron's activation does not affect any other neuron in the MLP. So we should be able to patch individual neurons, because they are independent of each other in the same sense that the attention heads in a single layer are independent of each other.
We run neuron-wise activation patching for Layer 31's MLP in a similar fashion to the layer-wise patching above: while running the corrupted prompt through the model, we restore the clean activation of one neuron at a time and look at how much each restored neuron contributes to the logit difference between " a" and " an".
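A minimal sketch of neuron-wise patching inside one toy MLP (weights invented, biases omitted, GELU in its tanh form). Because the MLP's output is a weighted sum of its hidden activations, each neuron can be restored on its own; within the MLP's direct output, the per-neuron effects add up exactly to the effect of patching the whole layer. In the real model, later layers respond nonlinearly, so per-neuron recoveries need not sum to the layer's recovery.

```python
import math


def gelu(x):
    # tanh approximation of GELU (the variant GPT-2 uses)
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def mlp_hidden(x, w_in):
    """Hidden activations: one GELU neuron per row of w_in (biases omitted)."""
    return [gelu(sum(w * xi for w, xi in zip(row, x))) for row in w_in]


def mlp_out(h, w_out):
    """MLP output: sum over neurons of activation * that neuron's output row."""
    d_model = len(w_out[0])
    return [sum(h[n] * w_out[n][d] for n in range(len(h))) for d in range(d_model)]


def neuronwise_effects(x_clean, x_corrupted, w_in, w_out, direction):
    """Restore one neuron at a time to its clean activation during the corrupted
    run and measure how far the MLP output moves along `direction`."""
    h_clean = mlp_hidden(x_clean, w_in)
    h_corr = mlp_hidden(x_corrupted, w_in)
    base = mlp_out(h_corr, w_out)

    def projected_change(h):
        out = mlp_out(h, w_out)
        return sum((o - b) * d for o, b, d in zip(out, base, direction))

    effects = []
    for i in range(len(h_corr)):
        h = list(h_corr)
        h[i] = h_clean[i]  # patch neuron i only
        effects.append(projected_change(h))
    whole_layer = projected_change(h_clean)  # patch every neuron at once
    return effects, whole_layer


w_in = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 hypothetical neurons, d_model = 2
w_out = [[1.0, 0.0], [0.0, 1.0], [0.5, -0.5]]
effects, whole = neuronwise_effects([1.0, 0.0], [0.0, 1.0], w_in, w_out, direction=[1.0, -1.0])
print(effects, whole)  # neuron 2 has the same activation in both runs, so its effect is 0
```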
We see that patching Neuron 892 alone recovers 50% of the clean prompt's logit difference, while patching the whole layer actually does slightly worse, at 49%.
Finding 2: The activation of the “an-neuron” correlates with the “ an” token being predicted.
Neuroscope Layer 31 Neuron 892 Maximum Activating Examples
Neuroscope is an online tool that shows the top activating examples in a large dataset for each neuron in GPT-2. When we look at Layer 31 Neuron 892, we see that the neuron maximally activates on tokens where the subsequent token is " an".
But Neuroscope only shows us the top 20 most activating examples. Would there be a correlation for a wider range of activations?
Testing the neuron on a larger dataset
To check this, we ran the pile-10k dataset through the model. This is a diverse set of about 10 million tokens taken from The Pile, split into prompts of 1,024 tokens. Across the range of neuron activations, we plotted the proportion of predictions for which " an" was the top token:
We see that the " an" predictions increase as the neuron's activation increases, to the point where " an" is always the top prediction. The trend is somewhat noisy, which suggests that there might be other mechanisms in the model that also contribute towards the " an" prediction. Or maybe when the " an" logit increases, other logits increase at the same time.
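A sketch of how a plot like this could be computed, assuming the neuron's activation and the model's logits have already been recorded at every token position. The bucket width of 0.2 follows the authors' description in their reply in the comments; the function names, the token id and the toy data are hypothetical.

```python
import math
from collections import defaultdict


def was_top_prediction(logits, token_id):
    """True if token_id has the largest logit (ties go to the lowest index)."""
    return max(range(len(logits)), key=lambda i: logits[i]) == token_id


def top_rate_by_activation(activations, is_top, bin_width=0.2):
    """Bucket activations into bins of width bin_width and return
    {bin start: fraction of positions in that bin where the token was top}."""
    totals = defaultdict(int)
    hits = defaultdict(int)
    for act, top in zip(activations, is_top):
        b = math.floor(act / bin_width)  # integer bin index avoids float-key issues
        totals[b] += 1
        hits[b] += int(top)
    return {round(b * bin_width, 6): hits[b] / totals[b] for b in sorted(totals)}


# Toy data: token id 1 plays the role of " an" in a 2-token vocabulary.
acts = [0.05, 0.1, 0.25, 0.3, 0.35]
logits_per_position = [[2.0, 1.0], [0.5, 3.0], [0.0, 1.0], [1.0, 4.0], [5.0, 0.0]]
is_top = [was_top_prediction(l, 1) for l in logits_per_position]
rates = top_rate_by_activation(acts, is_top)
print(rates)  # {0.0: 0.5, 0.2: 0.6666666666666666}
```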
Note that the model only predicted " an" 1,500 times, even though it actually occurred 12,000 times in the dataset. No wonder it was so hard to find a good prompt!
The neuron’s output weights have a high dot-product with the “ an” token
How does the neuron influence the model's output? Well, the neuron's output weights have a high dot product with the embedding for the token " an". We call this the congruence of the neuron with the token. Compared to other random tokens like " any" and " had", the neuron's congruence with " an" is very high:
In fact, when we calculate the neuron’s congruence with all of the tokens, there are a few clear outliers:
It seems like the neuron basically adds the embedding of " an" to the residual stream, which increases the output probability for " an", since the unembedding step consists of taking the dot product of the final residual stream with each token's embedding (GPT-2 uses the same matrix for embedding and unembedding).[2]
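A sketch of congruence as a plain dot product, using a made-up 4-token vocabulary in a 3-dimensional model. It also checks the claim in the paragraph above under one stated assumption: if the final LayerNorm is ignored (the real model normalizes the residual before unembedding), adding activation × (output weights) to the residual shifts every token's logit by exactly activation × congruence.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def congruence_with_all_tokens(w_out_row, embed):
    """Dot product of one neuron's output weights with every token embedding.
    With tied weights, embed[t] is also the unembedding vector for token t."""
    return [dot(w_out_row, e) for e in embed]


def top_k(scores, k):
    """Indices of the k largest scores, largest first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def logits(resid, embed):
    # Unembedding as a dot product with each token (final LayerNorm ignored here).
    return [dot(resid, e) for e in embed]


# Toy vocabulary and neuron; all numbers are made up.
embed = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [0.9, 0.1, 0.0]]
w_out_row = [2.0, 0.0, 0.1]
cong = congruence_with_all_tokens(w_out_row, embed)
print(cong, top_k(cong, 2))  # [2.0, 0.0, 0.1, 1.8] [0, 3]

# Writing activation * w_out_row into the residual shifts each logit by activation * congruence.
resid, activation = [0.3, -0.2, 0.5], 1.5
before = logits(resid, embed)
after = logits([r + activation * w for r, w in zip(resid, w_out_row)], embed)
shift = [a - b for a, b in zip(after, before)]
```

Token 3, whose embedding points almost the same way as token 0, shows up as a second outlier; this is the same kind of effect as the " an" neuron also being congruent with related tokens.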
Are there other neurons that are also congruent to " an"? To find out, we plotted the congruence of all neurons with the " an" token:
Our neuron is way above the rest, but there are some other neurons with a fairly high congruence. These other neurons could be part of the reason why the correlation between the " an" neuron's activation and the prediction of the " an" token isn't perfect: there may be prompts where " an" is predicted, but the model uses these other neurons to do it.
If this is the case, could we use congruence to find a neuron that is perfectly correlated with a single token prediction?
Finding 3: We can use neurons’ output congruence to find specific neurons that predict a token
Finding a token-associated neuron
We can try to find a neuron that is associated with a specific token by running the following search:
For each token, find the neuron with the highest output congruence.
For each of these neurons, find how much more congruent they are than the 2nd most congruent neuron for the same token.
Take the neuron(s) that are the most exclusively congruent.
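The search above, sketched over a small made-up congruence matrix (rows are neurons, columns are tokens, entries are output-weight/embedding dot products). In GPT-2_large the matrix would span every MLP neuron and the whole vocabulary, so in practice it would have to be computed in chunks; that detail is left out here.

```python
def best_neuron_and_gap(cong, t):
    """For token t, return (most congruent neuron, its lead over the runner-up).
    cong[n][t] is the congruence of neuron n with token t."""
    ranked = sorted(range(len(cong)), key=lambda n: cong[n][t], reverse=True)
    first, second = ranked[0], ranked[1]
    return first, cong[first][t] - cong[second][t]


def most_exclusively_congruent(cong, k=1):
    """Rank tokens by how far their top neuron leads the 2nd-best neuron and
    return the top k as (token, neuron, gap)."""
    n_tokens = len(cong[0])
    rows = [(t, *best_neuron_and_gap(cong, t)) for t in range(n_tokens)]
    rows.sort(key=lambda row: row[2], reverse=True)
    return rows[:k]


# Toy congruence matrix: 3 neurons (rows) x 3 tokens (columns), values made up.
cong = [
    [5.0, 1.0, 2.0],
    [1.0, 1.5, 1.9],
    [0.5, 1.2, 0.0],
]
print(most_exclusively_congruent(cong, k=3))
```

Note that neuron 0 tops both token 0 and token 2 here. This search alone does not notice that, which is the weakness the "cleanly associated" criteria below are meant to address.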
With this search, we wanted to find neurons that were uniquely responsible for a token. Our conjecture was that these neurons’ activations would be more correlated with their tokens’ prediction, since any prediction of that token would “rely” on that neuron.
Let's try running the " though" neuron (Layer 28 Neuron 1921) through the dataset and see whether we get a cleaner graph.
Woah, that is much messier than the graph for the " an" neuron. What is going on?
Looking at Neuroscope's data for the neuron reveals that it predicts both the tokens " though" and " however". This complicates things: it seems that this neuron is correlated with a group of semantically similar tokens (conjunctive adverbs).[3]
When we calculate the neuron’s congruence for all tokens, we find that the same tokens pop up as outliers:
In our large-dataset correlation graph above, instances where the neuron activates and " however" is predicted over " though" would be counted as negative examples, since " though" was not the top prediction. This could also explain some of the noise in the " an" correlation, where the neuron is also congruent with "An", " An" and "an".[4]
Can we find a simpler neuron to look at — preferably a neuron that only predicts for one token?
Finding a cleanly associated neuron
For a neuron to be ‘cleanly associated’ with a token, their congruence with each other should be mutually exclusive, meaning:
The neuron is much more congruent with the token than any other neuron.
The neuron is much more congruent with the token than any other token.
(Remember, ‘congruence’ is just our term for the dot product.)
Both criteria help to simplify the relationship between the neuron and its token. If a neuron's congruence with a token represents how much it contributes to that token's prediction, the first criterion makes sure that only this neuron is responsible for predicting that token, while the second criterion makes sure that this neuron is responsible for predicting only that token.
Our search then is as follows:
For each token, find the most congruent neuron.
For each neuron, find the most congruent token.[5]
Find the token-neuron pairs that are on both lists — that is, the pairs where the neuron’s most congruent token is a token which is most congruent with that neuron!
Calculate how distinct each pair is by multiplying the gap between the neuron's top two token congruences by the gap between the token's top two neuron congruences.
Find the pairs with the highest mutual exclusive congruence.
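A sketch of this mutual exclusive congruence search, using the same convention as before (cong[n][t] is the dot product of neuron n's output weights with token t's embedding). The matrix is made up so that one neuron wins its token clearly, one wins narrowly, and one is left unpaired because its favourite token prefers a different neuron.

```python
def argmax_and_gap(values):
    """Index of the largest value and how far it leads the second largest."""
    ranked = sorted(range(len(values)), key=lambda i: values[i], reverse=True)
    return ranked[0], values[ranked[0]] - values[ranked[1]]


def mutual_exclusive_congruence(cong):
    """Keep (neuron, token) pairs where each is the other's most congruent
    partner, scored by (lead of the neuron's top token over its 2nd token) *
    (lead of the token's top neuron over its 2nd neuron). Highest score first."""
    n_neurons, n_tokens = len(cong), len(cong[0])
    best_for_token = [argmax_and_gap([cong[n][t] for n in range(n_neurons)]) for t in range(n_tokens)]
    pairs = []
    for n in range(n_neurons):
        t, token_gap = argmax_and_gap(cong[n])  # this neuron's most congruent token
        top_neuron, neuron_gap = best_for_token[t]
        if top_neuron == n:  # ...and that token's most congruent neuron is this one
            pairs.append((token_gap * neuron_gap, n, t))
    pairs.sort(reverse=True)
    return pairs


cong = [
    [5.0, 1.0, 2.0],
    [1.0, 1.5, 1.9],
    [0.5, 1.6, 0.0],
]
for score, neuron, token in mutual_exclusive_congruence(cong):
    print(f"neuron {neuron} <-> token {token}: score {score:.2f}")
```

Multiplying the two gaps (rather than adding them) means a pair only scores well if it is exclusive in both directions; a large gap on one side cannot compensate for a near-tie on the other.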
For GPT-2_large, Layer 33 Neuron 4142 paired with "i" scores the highest on this metric. Looking at Neuroscope[6] confirms the connection:
And when we plot the graph of top prediction proportion over activation for the top 5 highest scorers:[7]
We do indeed see strong correlations for each pair!
What Does This All Mean?
Does the congruence of a neuron with a token actually measure the extent to which the neuron predicts that token? We don’t know. There could be several reasons why even token-neuron pairs with high mutual exclusive congruence may not always correlate:
The token could also be predicted by a combination of less congruent neurons
The token could be predicted by attention heads
Even if a neuron's activation has a high correlation with a token's logit, it may also indirectly correlate with other tokens' logits, such that the neuron's activation does not correlate with the token's probability.
There may be later layers which add the opposite direction to the residual stream, cancelling the effect of a neuron.
However, we've found that the token-neuron pairs with the top 5 highest mutual exclusive congruence do in fact have a strong correlation.
TL;DR
We used activation patching on a neuron level to find a neuron that's important for predicting the token " an" in a specific prompt.
The " an" neuron's activation correlates with " an" being predicted in general.
This may be because the neuron's output weights have a high dot product with the " an" token (the neuron is highly congruent with the token).
Moreover, this neuron has a higher dot product with this token than with any other token. And this neuron has a higher dot product with this token than the token has with any other neuron (they have high mutual exclusive congruence).
The congruence between a neuron and a token is cool. We find the top 5 neuron-token pairs by mutual exclusive congruence. The activations of these neurons strongly correlate with the prediction of their respective tokens.
The code to reproduce our results can be found here.
This is a write-up and extension of our winning submission to Apart Research’s Mechanistic Interpretability Hackathon. Thanks to the London EA Hub for letting us use their co-working space, Fazl Barez for his comments and Neel Nanda for his feedback and for creating Neuroscope, the pile-10k dataset and TransformerLens.
[1] Neel Nanda's take on MLP 0:
“It's often observed on GPT-2 Small that MLP0 matters a lot, and that ablating it utterly destroys performance. My current best guess is that the first MLP layer is essentially acting as an extension of the embedding (for whatever reason) and that when later layers want to access the input tokens they mostly read in the output of the first MLP layer, rather than the token embeddings. Within this frame, the first attention layer doesn't do much.
In this framing, it makes sense that MLP0 matters on the second subject token, because that's the one position with a different input token!
I'm not entirely sure why this happens, but I would guess that it's because the embedding and unembedding matrices in GPT-2 Small are the same. This is pretty unprincipled, as the tasks of embedding and unembedding tokens are not inverses, but this is common practice, and plausibly models want to dedicate some parameters to overcoming this.
I only have suggestive evidence of this, and would love to see someone look into this properly!”
[2] What else could it have done? It might have suppressed the logit for " a", which would have had the same impact on the logit difference. Or it might have added some completely different direction to the residual stream, which would cause a neuron in a later layer to increase the " an" logit.
[3] Note that the " though" neuron is congruent to a group of semantically similar tokens, while the " an" neuron is correlated with a group of syntactically similar tokens (e.g. " an" and " Ancients").
[4] Why does " an" have a cleaner correlation despite the other congruent tokens? We're not sure. One possible explanation is that "An" and " An" are simply much less common tokens, so they make little impact on the correlation, while "an" has a significantly lower congruence with the neuron than the top 3. In general, we expect that neurons found by only looking at the top 2 neuron difference for each token will not often have clean correlations with their respective tokens, because these neurons may be congruent with multiple tokens.
[5] When we look at the most congruent neuron for each token, we see some familiar troublemakers showing up with very high congruence:
At first, it looks like these 'forbidden tokens' are all associated with a 'forbidden neuron' (Layer 35 Neuron 3354) which they are all very congruent with. But actually if we plot the most congruent tokens of many other neurons we also see some of these weird tokens near the top. Our tentative hypothesis is that this has something to do with the hubness effect.
[6] Neuroscope data wasn't available for this neuron, so we took the max activating dataset examples from the pile-10k dataset. Texts 1, 2, 3 are prompts 1755, 8528 and 6375 respectively.
[7] Note that one of the top 5 tokens is "an", but this is different from the " an" that we were talking about earlier, and it will rarely be used as the start of a word or a word on its own. Similarly, the neuron with which it is paired, Layer 34 Neuron 4549, is not the " an" neuron named earlier.
I notice that the token in question happens to be segmented as “_an” and “_a” and not “_an_” or “_a_”.
So continuations like [“_a”, “moral”, “_fruit”] or [“_an”, “tagonist”, “ic”, “_monster”, “s”] could be possible (assuming those are all legal tokens).
I am reminded of the wonderful little nugget in linguistics, where people are supposed to have said something like “a narange” (because that kind of fruit came from the Spanish province of “naranja”). The details on these claims are often not well documented.
Relatively more scholarly analysis of such “junctural resegmentation” issues does exist though!
Something that might be tried is to look for neurons that are very likely to be active two or three steps before a given token is actually produced, which might represent a sort of “intent for the token to show up eventually”.
If you find such things, I might suggest that they be named “Wernicke Neurons”, and I wouldn’t be surprised if they ended up being highly/visibly related to “an intelligible semantic intent” in sketches of active speaking personas who probably want to eventually/potentially touch on >1 keywords.
The thing I would expect is that many related keywords would have prioritization before they are uttered, and they would hold each other somewhat in abeyance, with jostling and deal making amongst themselves, maybe with some equivalent to Broca’s Area as the judge of who should win when various Wernicke Neurons try to make their ideas “go out the door” first (or at all)?
Cool results!
A few questions:
The total logit diff between “a” and “an” contributed by layer 31 seems to be ~1.5 based on your logit lens figure, but your neuron only contributes ~0.4 -- do you have a sense of how exactly the remaining 1.1 is divided?
What’s going on with the negative logit lens layers?
Is there a reason you focus on output congruence as opposed to cosine similarity (which is just normalized congruence)? Intuitively, it seems like the scale of the output vector of an MLP neuron can be relatively arbitrary (only constrained to some extent by weight decay), since you can always scale the input instead. Do you expect the results to be different if you used that metric instead?
(I’m guessing not because of the activation patching experiment)
Sorry we didn't explain what the scales are on the figures! I've added a clarification in the post. The first graph is the absolute logit difference between " a" and " an". For each of the activation patching figures, the metric shown is the relative logit recovery:
$$\frac{\text{PatchedLogitDiff} - \text{CorruptedLogitDiff}}{\text{CleanLogitDiff} - \text{CorruptedLogitDiff}}$$
So 1 means the patch recovered the same logit diff as the clean prompt, 0 means patch didn’t change the corrupted prompt’s logit diff, <0 means the patch made the logit diff worse than the corrupted prompt etc.
We can see from the MLP layer patching figure that patching MLP 31 recovers 49% of the performance of the clean prompts (you can see the exact numbers on the interactive figures in the linkpost). And from the neuron patching figure we see that patching just Neuron 892 recovers 50% of the clean prompt performance, so actually the rest of the layer is entirely unhelpful.
The next question might be: “Why does patching MLP 31 only recover 49% of the performance when the logit lens makes it look like it's doing all the work?” I'm not sure what the answer to this is, but I also don't think it's particularly surprising. It may be that when running the corrupted activation, MLP 33 adds a bunch of the " a" logit to the residual, which patching MLP 31 doesn't change very much.
I think this just means that for the first 30 layers the model moves towards " a" being a better guess than " an". I expect a lot of the computation in these layers is working out that an indefinite article is required, of which " a" is more likely a priori. Only at layer 31 does it realize that " an" is actually more appropriate in this situation than " a".
It seems to me like using cosine similarity could give different and misleading results. Imagine if " an" pointed in the exact same direction (cosine similarity = 1) as two neurons. If one of the two neurons has a magnitude 100× bigger than the other, then it will have 100× more impact on the " an" logit.
I don't understand what you mean by “you can always scale the input instead”. And as the input to the MLP is the LayerNorm of the residual up to that point, the magnitude of the input is always the same.
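The cosine-similarity argument can be checked with made-up 2-d vectors: two neuron output vectors pointing exactly along a token's direction have the same cosine similarity with it, but the larger one has a 100× larger dot product, and hence (for the same activation, and ignoring the final LayerNorm) a 100× larger effect on that token's logit.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def cosine(u, v):
    return dot(u, v) / (math.sqrt(dot(u, u)) * math.sqrt(dot(v, v)))


token_an = [0.6, 0.8]        # made-up unit-length direction for " an"
neuron_small = [0.06, 0.08]  # same direction as the token...
neuron_big = [6.0, 8.0]      # ...and 100x larger

for name, w in [("small", neuron_small), ("big", neuron_big)]:
    print(name, round(cosine(w, token_an), 6), round(dot(w, token_an), 6))
# both cosines print as 1.0; the dot products print as 0.1 and 10.0
```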
Thanks for clarifying the scales!
I might be misremembering the GPT2 architecture, but I thought the output of the MLP layer was something like $W_O \cdot \mathrm{ReLU}(W_{in} \cdot \mathrm{LN}(x_{residual}))$? So you can just scale $W_{in}$ up when you scale $W_O$ down. (Assuming I'm remembering the architecture correctly,) if you're concerned about the scale of the output, then I think it makes sense to look at the scale of $W_{in}$ as well.
We took dot product over cosine similarity because the dot product is the neuron’s effect on the logits (since we use the dot product of the residual stream and embedding matrix when unembedding).
I think your point on using the scale of $W_{in}$ if we are concerned about the scale of $W_{out}$ is fair: we didn't really look at how the rest of the network interacted with this neuron through its input weights, but perhaps an input-scaled congruence score (e.g. output congruence * average of squared input weights) could give us a better representation of a neuron's relevance for a token.
I do agree that looking at $W_O$ alone seems a bit misguided (unless we're normalizing by looking at cosine similarity instead of dot product). However, the extent to which this is true is a bit unclear. Here are a few considerations:
At first blush, the thing you said is exactly right; scaling $W_{in}$ up and $W_O$ down will leave the implemented function unchanged.
However, this'll affect the L2 regularization penalty. All else equal, we'd expect to see $\|W_{in}\| = \|W_O\|$, since that minimizes the regularization penalty.
However, this is all complicated by the fact that you can also alternatively scale the LayerNorm's gain parameter, which (I think) isn't regularized.
Lastly, I believe GPT2 uses GELU, not ReLU? This is significant, since it no longer allows you to scale $W_{in}$ and $W_O$ without changing the implemented function.
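A scalar toy check of that last point, with made-up numbers and a single neuron with a scalar input assumed for simplicity: because ReLU(c·z) = c·ReLU(z) for c > 0, scaling W_in up by c and W_O down by c leaves a ReLU neuron's output unchanged, whereas GELU has no such property.

```python
import math


def relu(x):
    return max(0.0, x)


def gelu(x):
    # tanh approximation of GELU, as used in GPT-2
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def neuron_output(x, w_in, w_out, act):
    """Single MLP neuron with scalar input: w_out * act(w_in * x)."""
    return w_out * act(w_in * x)


x, w_in, w_out, c = 0.7, 1.3, 2.0, 10.0
for name, act in [("ReLU", relu), ("GELU", gelu)]:
    original = neuron_output(x, w_in, w_out, act)
    rescaled = neuron_output(x, c * w_in, w_out / c, act)  # scale W_in up, W_O down
    print(name, round(original, 4), round(rescaled, 4))
# the two ReLU numbers match; the two GELU numbers do not
```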
Nice work, thanks for sharing! I really like the fact that the neurons seem to upweight different versions of the same token (_an, _An, an, An, etc.). It's curious because the semantics of these tokens can be quite different (compared to the though, tho, however neuron).
Have you looked at all into what parts of the model feed into (some of) the cleanly associated neurons? It was probably out of scope for this but just curious.
One reason the neuron is congruent with multiple of the same tokens may be because those token embeddings are similar (you can test this by checking their cosine similarities).
Yup! I think that’d be quite interesting. Is there any work on characterizing the embedding space of GPT2?
Adam Scherlis did some preliminary exploration here:
https://www.lesswrong.com/posts/BMghmAxYxeSdAteDc/an-exploration-of-gpt-2-s-embedding-weights
Here’s a more thorough investigation of the overall shape of said embeddings with interactive figures:
https://bert-vs-gpt2.dbvis.de/
There’s also a lot of academic work on the geometry of LM embeddings, e.g.:
https://openreview.net/forum?id=xYGNO86OWDH (BERT, ERNIE)
https://arxiv.org/abs/2209.02535 (GPT-2-medium)
(Plus a mountain more on earlier text/token embeddings like Word2Vec.)
https://www.lesswrong.com/posts/aPeJE8bSo6rAFoLqg/solidgoldmagikarp-plus-prompt-generation is related to the embedding space
We did look very briefly at this for the " an" neuron. We plotted the residual stream congruence with the neuron input weights throughout the model. The second figure shows the difference from each layer.
Unfortunately I can't seem to comment an image. See it here.
We can't tell that much from this but I think there are three takeaways:
The model doesn't start 'preparing' to activate the " an" neuron until layer 16.
No single layer stands out a lot as being particularly responsible for the " an" neuron's activation (which is part of why we didn't investigate this further).
The congruence increases a lot after MLP 31. This means the output of layer 31 is very congruent with the input weights of the " an" neuron (which is in MLP 31). I think this is almost entirely the effect of the " an" neuron, partly because the input of the " an" neuron is very congruent with the " an" token (although not as much as the neuron output weights). This makes me think that this neuron is at least partly a 'signal boosting' neuron.
The first time I saw this post, I didn't understand. Then I had an insight, and now I get a pun
For clarifying my own understanding:
The dot product of a neuron's output weight vector (i.e. a row in W_out) with the unembedding matrix (in this case embedding.T, because GPT-2 uses tied embeddings) is what directly contributes to the logit outputs.
If the neuron activation is relatively very high, then this swamps the direction of your activations. So, artificially increasing W_in’s neurons to eg 100 should cause the same token to be predicted regardless of the prompt.
This means that neuron A could be more congruent than neuron B, but B contribute more to the logits of their token simply because B is activated more.
This is useful for mapping features to specific neurons if those features can be described as using a single token (like “ an”). I’d like to think more later about finding neurons for groups of speech, like a character’s catch phrase.
This seems all correct to me except possibly this:
W_in contains the input weights for each neuron. So you could increase the activation of the " an" neuron by multiplying the input weights of that neuron by 100 (i.e. W_in.T[892] *= 100). And if you increase the " an" neuron's activation, you will increase " an"'s logit. Our data suggests that if the activation is >10 then it will almost always be the top prediction.
I think this is true but not necessarily relevant. On the one hand, this neuron's activation will increase the logit of " an" regardless of what the other activations are. On the other hand, if the other activations are high, then this may reduce the probability of " an" by either increasing other logits or activating other neurons in later layers that output the opposite direction to " an" to the residual stream.
Hello! A great write-up and fascinating investigation. Well done with such a great result from a hackathon.
I’m trying to understand your plot titled ‘Proportion of Top Predictions that are ” an” by Layer 31 Neuron 892 Activation’. Can you explain what the y-axis is in this plot? It’s not clear what the y-axis is a proportion of.
I read through the code, but couldn't quite follow the logic for this plot. It seems that the y-axis is computed with these lines:
But I'm not sure what the numerator v["count"] from within logit_bins corresponds to.
Thank you :)
Aaron
Hi!
For each token prediction we record the activation of the neuron and whether or not " an" has a greater logit than any other token (i.e. whether it was the top prediction).
We group the activations into buckets of width 0.2. For each bucket we plot
$$\frac{\text{number of times " an" was the top prediction (for activations in this bucket)}}{\text{number of activations in this bucket}}$$
Does that clarify things for you?
How was the " a" vs. " an" task chosen? It seems quite convenient to probe for and also the kind of thing that could result from p-hacking over a set of similar simple tasks.
The prompt was in a style similar to the [Interpretability In The Wild](https://arxiv.org/abs/2211.00593) paper, where one token (" an") would be the top answer for the pre-patched prompt (the one with 'apple'), and the other token (" a") would be the top answer for the patched prompt (the one with 'lemon'). The idea with these prompts is that we know the top prediction is either " an" or " a", and we can measure the effect of each individual part of the model by seeing how much patching that part of the model sways the prediction towards the " a" token.
To be clear, this can only tell us the significance of this neuron in this particular prompt, which is why we also looked at the behaviour of this neuron from other perspectives: its activation over a larger, diverse dataset, and its output weights.
Thanks, but I’m asking more about why you chose to study this particular thing instead of something else entirely. For example, why not study “this” versus “that” completions or any number of other simple things in the language model?
I don't think there was much reason for choosing " a" vs. " an" to study over something else. This was the first thing we investigated and we were excited to see a single neuron mechanism, so we kept going. Bear in mind this project originated in a 48 hour hackathon :)