Chris Olah, Neel Nanda, Catherine Olsson, Nelson Elhage, and a bunch of other people at Anthropic just published “Transformer Circuits,” an application of the Circuits-style interpretability paradigm to transformer-based language models. From their very top level summary:
Can we reverse engineer transformer language models into human-understandable computer programs? Inspired by the Distill Circuits Thread, we’re going to try.
They’ve chosen to publish their work in an interestingly novel format, publishing their first paper, “A Mathematical Framework for Transformer Circuits,” alongside a set of YouTube videos that go into even more detail on their findings. I watched the full YouTube playlist, and I found it absolutely fascinating and would highly recommend it as a way to engage with this research.
Some of my high-level takeaways:
Their signature finding, I think, is that of induction heads. I’ll let you get the explanation of what they are directly from Chris’s video, but essentially they’re a very basic mechanism, apparently used by all transformer-based language models, that drives their ability to “meta-learn,” that is, to improve their accuracy over the course of seeing a larger context. Similarly to the exploration of early vision in Circuits, induction heads give us insight into the basic building blocks of large language models. While I think the authors did a great job uncovering this basic building block, I think their explanation of meta-learning as primarily being driven by induction heads mostly leads us to a big open research question: how exactly induction heads can be put together to produce the more complex phenomena we see in large language models.[1]
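To make the mechanism a little more concrete, here is a minimal plain-Python sketch of the input/output behavior usually attributed to an induction head: if the current token A appeared earlier and was followed by B, predict B ([A][B] … [A] → [B]). This is a toy rule written for illustration, not code from the paper; a real induction head implements something like it softly through attention (the paper describes it as working in composition with a previous-token head), and the function name `induction_predict` is hypothetical.

```python
def induction_predict(tokens):
    """Toy induction rule: [A][B] ... [A] -> predict [B].

    Finds the most recent earlier occurrence of the last token and returns
    the token that followed it, or None if the last token has not been seen.
    (Hypothetical illustration, not a trained model.)
    """
    if not tokens:
        return None
    current = tokens[-1]
    # Scan earlier positions from most recent to oldest, skipping the last token itself.
    for i in range(len(tokens) - 2, -1, -1):
        if tokens[i] == current:
            return tokens[i + 1]
    return None


print(induction_predict(["Mr", "D", "urs", "ley", "said", "Mr", "D"]))  # urs
print(induction_predict("the cat sat on the".split()))  # cat
```

Repeated names are the classic case: once the token “D” has been followed by “urs” earlier in a text, the rule predicts “urs” the next time “D” appears, which is exactly the kind of prediction that only gets easier as the context grows.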
Probably their most fascinating finding, in my opinion, was their discovery of the “induction bump.” I won’t try to re-explain exactly what the induction bump is, since I think Catherine’s video does such a good job of that, but I will say that, to my knowledge, the authors’ exploration of the induction bump is the first time there’s ever been a detailed, circuits-level analysis of a phase change that occurs over the course of training in a discrete way. I think this is especially interesting because a lot of stories that people like to tell about how particular safety problems might arise in neural networks often rely on these sorts of phase-change-style transitions (e.g. the development of deception). It seems to me that the existence of something like the induction bump is suggestive that there might be more of these sorts of phase changes hiding in future large training runs, or even in existing large training runs; as Catherine notes, the bump is pretty easy to miss until you break down the loss into individual tokens.[2]
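One way to do this kind of breakdown is to stop averaging the loss over the whole context and instead average it separately at each token position, then compare a late position with an early one. The sketch below assumes you already have a matrix of per-token losses (one row per sequence, one column per position) from some evaluation loop; the function names and the choice of positions 50 and 500 are illustrative assumptions. Computing this number at each saved checkpoint is one way a discrete change during training could stand out more clearly than it does in the overall loss.

```python
import numpy as np


def per_position_loss(token_losses):
    """Mean loss at each context position.

    token_losses: array-like of shape (num_sequences, context_length) holding
    per-token cross-entropy losses (assumed to come from your own eval loop).
    """
    return np.asarray(token_losses, dtype=float).mean(axis=0)


def late_minus_early(token_losses, early=50, late=500):
    """Loss at a late position minus loss at an early one.

    More negative means the model benefits more from extra context.
    The default positions are illustrative, not prescribed.
    """
    curve = per_position_loss(token_losses)
    return curve[late] - curve[early]


# Synthetic stand-in data: loss falls linearly from 5.0 to 3.0 along a
# 600-token context, identically for each of 4 sequences.
fake_losses = np.tile(np.linspace(5.0, 3.0, 600), (4, 1))
print(per_position_loss(fake_losses).shape)  # (600,)
print(late_minus_early(fake_losses))  # negative: later tokens are easier
```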
Overall, I think this is clearly the most exciting progress in transparency and interpretability in general since Circuits, and I’m really happy to see it happening in language models, which, as I’ve previously emphasized, I think is important for us to focus on. One thing that I think really sets this sort of transparency and interpretability work apart from the rest is the authors’ emphasis on understanding the mechanistic building blocks underlying their models (with the hope of eventually being able to reverse-engineer them) rather than just, for example, trying to give humans tools to predict what models are doing, without the output of those tools necessarily having any correspondence to what the models are actually doing.
[1]
Some possible future research directions related to understanding how induction heads compose:
I wonder if there’s any sort of general theory of what sorts of computations can be built up exclusively from many layers of induction heads; e.g., as a very simple question, is an induction-heads-only model of computation Turing complete?
I also wonder if induction heads can help us understand why language models often get stuck in loops, as induction heads seem like exactly the sorts of things that would be very prone to looping.
One very concrete question here is how the vowel/consonant neuron (the one that looks for “an”) is able to meta-learn when vowels/consonants are supposed to appear, rather than just looking for hard-coded “an” tokens. That’s an example of a really interesting meta-learning behavior that the model is able to do, and I’d be interested in seeing if there’s a way to understand how it could possibly be built up just using induction. That is, if we suppose that induction is all there is, how could induction be put together to produce an effect like that?
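The looping intuition is easy to see in a toy setting. The sketch below feeds a hard-coded induction rule (a hypothetical stand-in, not a trained model) its own predictions; once any token repeats, the copied continuation keeps pointing back into the earlier copy, so generation cycles indefinitely. Real models mix many heads and MLPs, so this only illustrates why a pure copying mechanism would be loop-prone.

```python
def induction_predict(tokens):
    """Toy induction rule: return the token that followed the most recent
    earlier occurrence of the last token, or None if there is none."""
    current = tokens[-1]
    for i in range(len(tokens) - 2, -1, -1):
        if tokens[i] == current:
            return tokens[i + 1]
    return None


def greedy_generate(prompt, steps):
    """Append the toy induction prediction to the context `steps` times,
    stopping early if the rule has nothing to copy."""
    tokens = list(prompt)
    for _ in range(steps):
        nxt = induction_predict(tokens)
        if nxt is None:
            break
        tokens.append(nxt)
    return tokens


print("".join(greedy_generate(list("ABCA"), 5)))  # ABCABCABC
```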
[2]
One thing I’m still struggling to understand related to the induction bump is how it can be the case that, after the induction bump, large models don’t do relatively better at meta-learning compared to small models. I found that extremely surprising and I feel like I almost don’t believe that that can be the end of the story—qualitatively, we certainly observe much more interesting meta-learning in larger models, and it seems really strange for all of that to just be reflected in an overall loss decrease rather than an increase in the amount of meta-learning. At the very least, I feel like this fact deserves some sort of further explanation—perhaps there are other interesting phase changes hiding in later parts of the loss function that might help explain what’s going on here, or perhaps the claim that the whole phenomenon of “meta-learning” is just task recognition could help shed some light on this result.
I expect to have more detailed thoughts worth sharing as I spend more time with this content, but one thing stands out brightly as a first: This is, head-and-shoulders, the best language model interpretability work to date. I’m impressed at the thoroughness of the theory combined with detailed real examples.
This also seems like a good motivation to go back and study layer reordering (à la Sandwich Transformers) as a treatment affecting the induced circuits of a model.
(h/t Kevin Wang for pointing out the sandwich transformer paper to me recently)
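For readers who haven’t seen it: the Sandwich Transformers paper writes a layer ordering as a string of self-attention (`s`) and feedforward (`f`) sublayers. With n of each, the standard interleaved model is (sf)^n, and a sandwich model with coefficient k moves k attention sublayers to the bottom and k feedforward sublayers to the top, keeping the totals fixed: s^k (sf)^(n-k) f^k. A small helper (the name is hypothetical) that produces these orderings, e.g. to enumerate variants for a circuits comparison:

```python
def sandwich_ordering(n, k):
    """Sublayer ordering for a sandwich transformer: s^k (sf)^(n-k) f^k.

    n: number of self-attention ('s') and of feedforward ('f') sublayers.
    k: sandwich coefficient, 0 <= k <= n; k = 0 gives the usual interleaved model.
    """
    if not 0 <= k <= n:
        raise ValueError("sandwich coefficient k must satisfy 0 <= k <= n")
    return "s" * k + "sf" * (n - k) + "f" * k


for k in range(3):
    print(k, sandwich_ordering(3, k))  # sfsfsf, ssfsff, sssfff
```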
Thanks for the writeup! The first paper covers the first half of the video series, more or less. I’ve been working on a second paper which will focus primarily on the induction bump phenomenon (and other things described in the second half of the video series), so much more to come there!
Second paper is out: https://twitter.com/catherineols/status/1501250025661206529
This looks really interesting. Is there any intention to use these insights to design models that are even more interpretable than transformers? I’ve had the feeling that transformer models may be too general-purpose for their own good, in terms of both training efficiency and interpretability. By that I mean the following: fully connected neural networks technically have at least as much computational/representational power as convolutional neural networks, yet they are much harder to train for general image processing than their more constrained counterparts, which take full advantage of translational equivariance. In the same way, transformer-type language models might not have enough constraints to make them efficient enough for AGI.
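The constraint being appealed to here can be checked numerically. A convolution reuses one small kernel at every position, so shifting the input just shifts the output; an unconstrained dense layer of the same size has no such guarantee. The NumPy sketch below uses a circular 1-D convolution and random weights (names are hypothetical, and this is not how CNN libraries implement convolution) to check both:

```python
import numpy as np


def circular_conv1d(x, kernel):
    """1-D circular cross-correlation: the same kernel is applied at every position."""
    n = len(x)
    return np.array([
        sum(kernel[j] * x[(i + j) % n] for j in range(len(kernel)))
        for i in range(n)
    ])


def shift(v):
    """Translate a signal one step to the right, wrapping around."""
    return np.roll(v, 1)


rng = np.random.default_rng(0)
x = rng.normal(size=8)
kernel = rng.normal(size=3)
dense = rng.normal(size=(8, 8))  # unconstrained fully connected layer, same input/output size

# Convolution: shifting then convolving equals convolving then shifting.
conv_equivariant = np.allclose(circular_conv1d(shift(x), kernel), shift(circular_conv1d(x, kernel)))
# Dense layer: generically not equivariant.
dense_equivariant = np.allclose(dense @ shift(x), shift(dense @ x))
print(conv_equivariant, dense_equivariant)
```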
In these models, some representation of every token is compared against a representation of every other token encountered so far, which gives quadratic complexity for every attention layer at runtime. This then leads to further transformation of the data after each attention block, creating what is effectively a new string of abstract tokens, each of which is some hard-to-interpret combination of the token representations in the level below. The only information added to the vector representation of each token, as far as I understand it, is some vector representing the relative position of the tokens within the string (which itself necessitates a special type of normalization step later on). Otherwise, it’s up to the model to learn to assign implicit roles/functions to each token through the attention module. This hides away the information of what each token is doing, which a more constrained model could instead represent explicitly.
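A minimal single-head causal self-attention in NumPy makes the quadratic cost in this description concrete: the score matrix compares every token with every other token, so for n tokens it is n × n, and the causal mask keeps only the tokens encountered so far. This is a sketch with random weights that leaves out positional information, multiple heads, layer norm, and the MLP blocks; the function name is hypothetical.

```python
import numpy as np


def causal_self_attention(x, w_q, w_k, w_v):
    """Single-head causal self-attention.

    x: (n_tokens, d_model). Returns the output (n_tokens, d_head) and the
    (n_tokens, n_tokens) attention pattern.
    """
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    # Every token is scored against every token: an n x n matrix, hence O(n^2).
    scores = q @ k.T / np.sqrt(k.shape[-1])
    # Causal mask: position i may only attend to positions <= i.
    future = np.triu(np.ones(scores.shape, dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    # Row-wise softmax (subtract the max for numerical stability).
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights


rng = np.random.default_rng(0)
n_tokens, d_model = 6, 4
x = rng.normal(size=(n_tokens, d_model))
w_q, w_k, w_v = (rng.normal(size=(d_model, d_model)) for _ in range(3))
out, pattern = causal_self_attention(x, w_q, w_k, w_v)
print(pattern.round(2))  # lower-triangular; each row sums to 1
```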
It seems to me that we could do better. For instance, suppose we had a model that had “slots” (I’m thinking something like CPU registers) that it would fill in with token vectors as it went along. The LM would learn to assign functions like “subject”, “verb”, “direct object modifier”, etc., with one part learning which tokens should get routed to which slots, another part learning to predict which slot (e.g., part of speech) will get routed to next based on what information has already been filled in and on the learned rules of syntax, and another part predicting what information should go into the empty slots (allowing it to “read between the lines”). That last part could also be hooked up to a long-memory database of learned relations that it could fill in and update as it accumulates training data (something like what DeepMind published recently: https://deepmind.com/research/publications/2021/improving-language-models-by-retrieving-from-trillions-of-tokens).
Although the role of each slot will be arbitrary and assigned only through training, I think this type of architecture would make it easier to extract semantic roles for the tokens that it reads in, since these semantic roles have explicit locations where they can always be found. In other words, you can use the same method to find out what the LM thinks about the who, what, when, where, why, and how of what it reads or says (along with what it thinks about everything it doesn’t read or say by looking into the “unused” slots). With transformers, this would be much more difficult, since semantic roles are assigned much more implicitly and a lot could be hiding in the weights.
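A toy sketch of the readout side of this proposal, with everything learned replaced by random stand-ins: a fixed bank of slots, each token vector hard-routed to the slot whose key it matches best, and slots that can be read, or listed as unused, by index. The class and its routing rule are illustrative assumptions, not an existing architecture, and the parts that predict which slot comes next or what belongs in empty slots are omitted.

```python
import numpy as np


class SlotMemory:
    """Hypothetical 'slots as registers' memory.

    Each slot would acquire a role (subject, verb, ...) through training;
    here the routing keys are random placeholders for learned weights.
    """

    def __init__(self, n_slots, d_model, seed=0):
        rng = np.random.default_rng(seed)
        self.slot_keys = rng.normal(size=(n_slots, d_model))  # stand-in for a learned router
        self.slots = np.zeros((n_slots, d_model))
        self.filled = np.zeros(n_slots, dtype=bool)

    def write(self, token_vec):
        """Hard-route a token vector to its best-matching slot, overwriting it."""
        slot = int(np.argmax(self.slot_keys @ token_vec))
        self.slots[slot] = token_vec
        self.filled[slot] = True
        return slot

    def read(self, slot):
        """A role lives at a fixed index, so inspecting it is a plain lookup."""
        return self.slots[slot]

    def unused(self):
        """Indices of slots nothing has been routed to yet."""
        return np.flatnonzero(~self.filled)


mem = SlotMemory(n_slots=6, d_model=8)
token = np.ones(8)
slot = mem.write(token)
print(slot, mem.unused())
```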
That was just an idea, but I think that interpretability will come more easily the more we constrain the language model with both functional and representational modularity. Perhaps the work you do could help inform what sorts of constraints would be most effective to that end.