TLDR: The model ignores weird tokens when learning the embedding, and never predicts them in the output. In GPT-3 this means the model breaks a bit when a weird token is in the input, and will refuse to ever output it, because it’s hard-coded the frequency statistics, and its “repeat this token” circuits don’t work on tokens it never needed to learn them for. In GPT-2, unlike GPT-3, embeddings are tied, meaning W_U = W_E.T, which explains much of the weird shit you see, because this is actually behaviour in the unembedding, not the embedding (weird tokens never come up in the text, and should never be predicted, so there’s a lot of gradient signal in the unembed, zero in the embed).
In particular, I think that your clustering results are an artefact of how GPT-2 was trained and do not generalise to GPT-3
Fun results! A key detail that helps explain these results is that in GPT-2 the embedding and unembedding are tied, meaning that the linear map from the final residual stream to the output logits logits = final_residual @ W_U is the transpose of the embedding matrix, ie W_U = W_E.T, where W_E[token_index] is the embedding of that token. But I believe that GPT-3 was not trained with tied embeddings, so will have very different phenomena here.
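(As a quick sanity check, here’s a minimal sketch of how you can verify the tying in GPT-2. I’m assuming the HuggingFace transformers implementation, which stores the unembed as lm_head with shape [vocab, d_model]:)

```python
# Minimal sketch (assumes HuggingFace transformers): GPT-2's unembedding is
# literally the same tensor as its embedding, i.e. W_U = W_E.T.
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
W_E = model.transformer.wte.weight   # [vocab, d_model] token embedding
W_U_T = model.lm_head.weight         # lm_head stores W_U transposed, also [vocab, d_model]

# Tied weights share storage, so this prints True for GPT-2.
print(W_E.data_ptr() == W_U_T.data_ptr())
```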
My mental model for what’s going on:
Let’s consider the case of untied embeddings first, so GPT-3:
For some stupid reason, the tokenizer has some irrelevant tokens that never occur in the training data. Your guesses seem reasonable here.
In OpenWebText, there are 99 tokens in GPT-2’s tokenizer that never occur, and a bunch that are crazy niche, like ‘ petertodd’
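(For reference, a sketch of how you might reproduce that kind of count. I’m assuming the HuggingFace Skylion007/openwebtext dataset and only streaming a small sample here, so the exact numbers will differ from a full pass over OpenWebText:)

```python
# Sketch: count GPT-2 token frequencies over a small streamed sample of OpenWebText.
from itertools import islice

import numpy as np
from datasets import load_dataset
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
counts = np.zeros(tokenizer.vocab_size, dtype=np.int64)

stream = load_dataset("Skylion007/openwebtext", split="train", streaming=True)
for doc in islice(stream, 10_000):  # small sample, not the full corpus
    counts += np.bincount(tokenizer(doc["text"])["input_ids"],
                          minlength=tokenizer.vocab_size)

print("tokens never seen in this sample:", int((counts == 0).sum()))
np.save("gpt2_token_counts.npy", counts)  # reused in the plotting sketch below
```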
Embed: Because these are never in the training data, the model completely doesn’t care about their embedding, and never changes them (or, if they occur very rarely, it does some random jank). This means they remain close to their random initialisation
Models are trained with weight decay, which incentivises these to be set to zero, but I believe that weight decay doesn’t apply to the embeddings
Models are not used to having tokens deleted from their inputs, and a (near-)zero embedding is effectively a deleted token, so this breaks things, which isn’t that surprising.
OTOH, if they genuinely do normalise to norm 1 (for some reason), the tokens probably just land in a weird bit of embedding space that the model doesn’t expect. I imagine this will still break things, but it might just let the model confuse it with a token that happens to be nearby? I don’t have great intuitions here
Unembed: Because these are never in the training data, the model wants to never predict them, ie have big negative logits. The two easiest ways to do this are to give them trivial weights and a big negative bias term, or big weights and align them with a bias direction in final residual stream space (ie, a direction that always has a high positive component, so it can be treated as approx constant).
Either way, the observed effect is that the model will never predict them, which totally matches what you see.
As a cute demonstration of this, we can plot a scatter graph of log(freq_in_openwebtext+1) against unembed bias (which comes from the folded layernorm bias) coloured by the centered norm of the token embedding. We see that the unembed bias is mostly used to give frequency, but that at the tail end of rare tokens, some have tiny unembed norm and big negative bias, and others have high unembed norm and a less negative bias.
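(Rough sketch of that plot, using the HuggingFace GPT-2 weights and the counts saved by the frequency sketch above; both tooling choices are my assumptions, and a sampled count only approximates the true OpenWebText frequencies:)

```python
# Sketch: scatter log(freq + 1) against the folded unembed bias, coloured by
# the centered norm of each token's embedding.
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
with torch.no_grad():
    W_E = model.transformer.wte.weight              # [vocab, d_model]
    beta = model.transformer.ln_f.bias              # final LayerNorm bias
    unembed_bias = (W_E @ beta).numpy()             # folded bias term: beta @ W_E.T
    centered_norm = (W_E - W_E.mean(0)).norm(dim=-1).numpy()

counts = np.load("gpt2_token_counts.npy")           # from the counting sketch above
plt.scatter(np.log(counts + 1), unembed_bias, c=centered_norm, s=2, cmap="viridis")
plt.xlabel("log(freq_in_openwebtext + 1)")
plt.ylabel("unembed bias (folded LayerNorm bias)")
plt.colorbar(label="centered embedding norm")
plt.show()
```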
The case of tied embeddings is messier, because the model wants to do these two very different things at once! But since, again, it doesn’t care about the embedding at all (it’s not that it wants the token’s embedding to be close to zero, it’s that there’s never any gradient signal to change it), the effect will be dominated by what the unembed wants, which is pushing these tokens’ logits down so they’re never predicted.
The unembed doesn’t care about the average token embedding, since adding a constant to every logit does nothing. The model wants a non-trivial average token embedding to use as a bias term (probably), so there’ll be a non-trivial average token embedding (as we see), but it’s boring and not relevant.
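(Trivial numpy check of that invariance, in case it’s not obvious why a uniform shift to every logit is a no-op:)

```python
# Adding the same constant to every logit leaves the softmax distribution unchanged.
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

logits = np.array([1.0, -2.0, 0.5])
print(np.allclose(softmax(logits), softmax(logits + 7.3)))  # True
```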
So the model’s embedding for the weird tokens will be optimised for giving a big negative logit in the unembedding, which is a weird and unnatural thing to do, and I expect is the seed of your weird results.
One important-ish caveat is that the unembed isn’t quite the transpose of the embed. There’s a LayerNorm immediately before the unembed, whose scale weights get folded into W_E.T to create an effective unembed (ie W_U_effective = w[:, None] * W_E.T), which breaks symmetry a bit. Hilariously, the model is totally accounting for this: if you plot norm of unembed and norm of embed against each other for each token, they track each other pretty well, except for the stupid rare tokens, which go wildly off to the side.
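(Sketch of that norm comparison, again assuming the HuggingFace GPT-2 weights and the W_U_effective = w[:, None] * W_E.T folding above:)

```python
# Sketch: compare each token's embedding norm with its effective unembedding norm
# after folding in the final LayerNorm scale.
import torch
import matplotlib.pyplot as plt
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
with torch.no_grad():
    W_E = model.transformer.wte.weight       # [vocab, d_model]
    w = model.transformer.ln_f.weight        # final LayerNorm scale, [d_model]
    W_U_effective = w[:, None] * W_E.T       # [d_model, vocab]
    embed_norm = W_E.norm(dim=-1).numpy()
    unembed_norm = W_U_effective.norm(dim=0).numpy()

plt.scatter(embed_norm, unembed_norm, s=2)
plt.xlabel("embedding norm")
plt.ylabel("effective unembedding norm")
plt.show()
```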
Honestly, I’m most surprised that GPT-3 uses the same tokenizer as GPT-2! There’s a lot of random jank in there, and I’m surprised they didn’t change it.
Another fun fact about tokenizers (god I hate tokenizers) is that they’re formed recursively by finding the most common pair of existing tokens and merging those into a new token. Which means that if you get eg common triples like ABC, where AB is essentially never followed by anything other than C, you’ll add in token AB, and then token ABC, and retain the vestigial token AB, which could also create the stupid token behaviour. Eg “ The Nitrome” is token 42,089 in GPT-2 and “ TheNitromeFan” is token 42,090, not that either actually comes up in OpenWebText!
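(You can check those ids directly with the GPT-2 tokenizer; the expected strings are just the ones quoted above:)

```python
# Decode the two token ids quoted above with the HuggingFace GPT-2 tokenizer.
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
for token_id in (42089, 42090):
    print(token_id, repr(tokenizer.decode([token_id])))
```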
To check this, you’d want to look at a model trained with untied embeddings. Sadly, all the ones I’m aware of (Eleuther’s Pythia, and my interpretability-friendly models) were trained on the GPT-NeoX tokenizer or variants, which doesn’t seem to have stupid tokens in the same way.
GPT-J uses the GPT-2 tokenizer and has untied embeddings.
Why do you think that GPT-3 has untied embeddings?
Personal correspondence with someone who worked on it.