Grokking is a recent phenomenon discovered by OpenAI researchers that, in my opinion, is one of the most fascinating mysteries in deep learning: models trained on small algorithmic tasks like modular addition will initially memorise the training data, but after a long time will suddenly learn to generalise to unseen data.
This is a write-up of an independent research project I did into understanding grokking through the lens of mechanistic interpretability. My most important claim is that grokking has a deep relationship to phase changes. Phase changes, ie a sudden change in the model’s performance for some capability during training, are a general phenomenon that occurs when training models, and have also been observed in large models trained on non-toy tasks. For example, the sudden change in a transformer’s capacity to do in-context learning when it forms induction heads. In this work I examine several toy settings where a model trained to solve them exhibits a phase change in test loss, regardless of how much data it is trained on. I show that if a model is trained on these tasks with limited data and high regularisation, then the model shows grokking.
One of the core claims of mechanistic interpretability is that neural networks can be understood, that rather than being mysterious black boxes they learn interpretable algorithms which can be reverse engineered and comprehended. This work serves as a proof of concept of that, and that reverse engineering models is key to understanding them. I fully reverse engineer the inferred algorithm from a transformer that has grokked how to do modular addition (which somehow involves Discrete Fourier Transforms and trig identities?!), and use this as a concrete example to analyse what happens during training to understand what happened during grokking. I close with discussion and thoughts on the alignment relevance of these results.
This is accompanied by a paper in the form of a Colab notebook containing the code for this project, a lot of interactive graphics, and much more in-depth discussion and technical details. In this write-up I try to give a high-level conceptual overview of the claims and the most compelling results and evidence; I refer you to the notebook if you want the full technical details.
This write-up ends with a list of ideas for future directions of this research. I think this is a particularly exciting problem to start with if you want to get into mechanistic interpretability since it’s concrete, only involves tiny models, and is easy to do in a Colab notebook. If you might want to work on some of these, please reach out! In particular, I’m looking to hire an intern/research assistant, and if you’re excited about these future directions you might be a good fit.
Key Claims
Grokking is really about phase changes: To exhibit grokking, we train a model on a problem that exhibits phase changes even when given infinite training data, and train it with regularisation and limited data. If we choose the amount of data such that the regularisation only marginally favours the generalising solution over the memorised solution, we see grokking.
Intuition: Regularisation makes the model ultimately prefer the generalising solution to the memorised solution, but the phase change indicates that the generalising solution is “hard to reach” in some sense. The memorised solution is “easier to reach”, and so is reached first. But due to the regularisation, the model still prefers the generalising solution to the memorised solution, and thus gets to the generalising solution eventually—the grokking result just shows that reaching the memorised solution first does not change this, and that there is a path in model space interpolating between the memorising and generalising solution.
Map inputs $x, y \to \cos(wx), \cos(wy), \sin(wx), \sin(wy)$ with a Discrete Fourier Transform, for some frequency $w$
Multiply and rearrange to get $\cos(w(x+y)) = \cos(wx)\cos(wy) - \sin(wx)\sin(wy)$ and $\sin(w(x+y)) = \cos(wx)\sin(wy) + \sin(wx)\cos(wy)$
By choosing a frequency $w = \frac{2\pi k}{n}$ we get a period dividing $n$, so this is a function of $x + y \pmod{n}$
Map to the output logits $z$ with $\cos(w(x+y))\cos(wz) + \sin(w(x+y))\sin(wz) = \cos(w(x+y-z))$ - this has the highest logit at $z \equiv x + y \pmod{n}$, so softmax gives the right answer.
To emphasise, this algorithm was purely learned by gradient descent! I did not predict or understand this algorithm in advance and did nothing to encourage the model to learn this way of doing modular addition. I only discovered it by reverse engineering the weights.
In particular, the circuits smoothly develop well before grokking, disproving the ‘grokking is a random walk in model-space’ hypothesis.
This is nicely demonstrated by the metric of excluded loss—which roughly shows how much model performance on the training data depends on the generalising algorithm vs memorising algorithm. We see the use of the generalising algorithm to improve training performance rises smoothly over training, well before the grokking point.
Phase Changes
Epistemic status: I feel confident in the empirical results, but the generalisation to non-toy settings is more speculative
For example, the composition of the previous token head and the induction head in an induction circuit. The previous token head will only reduce loss if the induction head is there and vice versa. So, initially, the gradients creating each component will be weak to non-existent. But once each component starts to form, the gradients on the other component will become stronger. These effects reinforce each other, creating a feedback loop that eventually accelerates and results in a phase change.
Intuitive explanation of grokking: Regularisation incentivises the model to be simpler, so the model prefers the generalising solution to the memorised solution. The generalising solution is “hard to reach” and the memorising solution is not, so the memorising solution is reached first. But the incentive to find the generalising solution is still there, so the underlying mechanism to induce the phase change is still going on while grokking—memorisation doesn’t change this
Fundamentally, understanding grokking is about understanding phase changes—I don’t claim to fully understand phase changes or grokking, but I claim to have reduced my confusion about grokking to my confusion about phase changes.
I observe several small phase changes in my toy tasks, eg there’s a separate phase change for each digit when learning 5 digit addition.
What Is A Phase Change?
By a phase change, I mean a reverse-S shaped curve[1] in the model’s loss on some dataset, or the model’s capacity for some specific capability. That is, the model initially has poor performance, this performance plateaus or slowly improves, and suddenly the performance accelerates and rapidly improves (ie, loss rapidly decreases), and eventually levels off.
A particularly well-studied motivating example of this is Anthropic’s study of induction heads. Induction heads are a circuit in transformers used to predict repeated sequences of tokens. They search the context for previous copies of the current token, and then attend to the token immediately after that copy, and predict that that subsequent token will come next. Eg, if the 3 token surname D|urs|ley is earlier in the context, and the model wants to predict what comes after D, it will attend to urs and predict that that comes next.
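As a rough illustration of the behaviour described above (not of how attention heads actually implement it), the induction rule can be written as a plain lookup over the context. The function name `induction_predict` and the list-of-strings token representation are assumptions made for this sketch.

```python
from typing import Optional, Sequence


def induction_predict(context: Sequence[str]) -> Optional[str]:
    """Predict the next token by copying whatever followed the most recent
    earlier copy of the current (last) token. Returns None if the current
    token has not appeared earlier in the context."""
    current = context[-1]
    # Scan backwards over earlier positions, skipping the current token itself.
    for i in range(len(context) - 2, -1, -1):
        if context[i] == current:
            return context[i + 1]  # the token right after the earlier copy
    return None


# The surname D|urs|ley appears earlier, and the context ends on "D" again.
tokens = ["Mr", "D", "urs", "ley", "said", "that", "D"]
prediction = induction_predict(tokens)
```

Here `prediction` is `"urs"`, matching the example in the text. A real induction circuit does this softly, through attention over the whole context, rather than with a hard lookup.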
The key fact about induction heads is that there is a fairly narrow band of training time where they undergo a phase change, and go from not existing/functional to being fully functional. This is particularly striking because induction heads are the main mechanism behind how Large Language Models do in-context learning, ie using tokens from far back in the context to usefully predict the next token. This means that there is a clear phase change in a model’s ability to do in-context learning, as shown below, and that this corresponds to the formation of a specific circuit via a phase change within the model.
My goal here is to convey the intuition of what I mean by a phase change, rather than give a clear and objective definition. At this stage of research, I favour an informal “I know it when I see it” style definition, over something more formal and brittle.
Empirical Observations
Motivation: Modular addition shows clear grokking on limited data[2], but given much more data it shows a phase change in both train and test loss. This motivated the hypothesis that grokking could be observed any time we took a problem with a phase change and reduced the amount of data while regularising:
In several toy algorithmic tasks I observe phase changes (in both test and train loss) when models are trained on sufficient data[3].
Trained on 95% of the training data, 1L Transformer
5 digit addition (1L full transformer)
1L Transformer, data format 1|3|4|5|2|+|5|8|3|2|1|=|0|7|1|7|7|3
Predicting Repeated Subsequences (2L attn-only transformer—task designed to need induction heads)
Data format: 7 2 8 3 1 9 3 8 3 1 9 9 2 5 END - we take a uniform random sequence of tokens, randomly choose a subsequence to repeat, and train the model to predict the repeated tokens.
Finding the max element in a sequence (1L attn-only transformer—task designed to need skip trigrams)
Concretely, the data format is START 0 4 7 2 4 14 9 7 2 5 3 END with exactly one entry ≥10, and the model is trained to output that entry after END. The solution is to learn 10 skip trigrams of the form 14 .. END 14
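To make the three token formats above concrete, here is a minimal sketch of generators for them. The function names, the fixed repeat length `rep_len`, the sequence lengths, and the choice of 10..19 for the large entry are assumptions for illustration; the notebook's actual data code may differ.

```python
import random


def format_addition(a: int, b: int, n_digits: int = 5) -> list[str]:
    """One token per digit; the sum is zero-padded to n_digits + 1 digits."""
    width = n_digits + 1
    return (list(f"{a:0{n_digits}d}") + ["+"]
            + list(f"{b:0{n_digits}d}") + ["="]
            + list(f"{a + b:0{width}d}"))


def repeated_subsequence(rng: random.Random, seq_len: int = 14,
                         vocab_size: int = 10, rep_len: int = 4):
    """Uniform random tokens with one span copied to a later, non-overlapping position."""
    seq = [rng.randrange(vocab_size) for _ in range(seq_len)]
    src = rng.randrange(0, seq_len - 2 * rep_len + 1)
    dst = rng.randrange(src + rep_len, seq_len - rep_len + 1)
    seq[dst:dst + rep_len] = seq[src:src + rep_len]
    return [str(t) for t in seq] + ["END"], src, dst


def max_element(rng: random.Random, n_entries: int = 11):
    """Entries in 0..9 except exactly one in 10..19, which is the target after END."""
    entries = [rng.randrange(10) for _ in range(n_entries)]
    big = rng.randrange(10, 20)
    entries[rng.randrange(n_entries)] = big
    return ["START"] + [str(t) for t in entries] + ["END"], str(big)


rng = random.Random(0)
addition_tokens = format_addition(13452, 58321)
rep_tokens, src, dst = repeated_subsequence(rng)
max_tokens, max_target = max_element(rng)
```

Joining `addition_tokens` with `|` reproduces the 5 digit addition example given above.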
Grokking = Phase Changes + Regularisation + Limited Data
For all of the above tasks, we can induce grokking by adding regularisation (specifically weight decay) and limiting the amount of training data. Eg for 5 digit addition on 700 examples (see the notebook for the rest):
With enough data (eg single epoch training), the model generalises easily, and with sufficiently little data (eg a single data point), the model memorises. But there is a crossover point, and we identify grokking when training with slightly more data than the crossover point, analogous to the findings of Liu et al that grokking is an intermediate state between comprehension and memorisation. I found the crossover points (here, 700) by binary searching by hand. Intuitively, this is consistent with the idea that grokking occurs because of regularisation favouring the simpler solution. Memorisation complexity increases with the amount of data (approximately continuously) while generalising complexity does not, so there must eventually be a crossover point.
This is particularly interesting, because these tasks are a non-trivial extension of the original grokking. Previous demonstrations of grokking had an extremely small universe of data and were trained on a substantial fraction of it, suggesting that the model may be doing some kind of clever interpolation. Here, by contrast, the universe of data is much larger (eg $10^{10}$ possible pairs of 5 digit numbers), and the model is trained on a tiny fraction of that, yet it still groks.
Explaining Grokking
Epistemic status: The following two sections are fairly speculative—they’re my current best explanation of my empirical findings, but are likely highly incomplete and could easily be totally wrong.
Speculation: Phase Changes Are Inherent to Composition
To understand the link between phase changes and grokking, it’s worth reflecting on why circuits form at all. A priori, this is pretty surprising! To see this, let’s focus on the example of an induction circuit, and speculate on how it could be formed. The induction circuit is made up of two heads, a previous token head and an induction head, which interact by K-composition. Together, these heads significantly improve loss, but only in the context of the other head being there. Thus, naively, when we have neither head, there should be no gradient encouraging the formation of either head.
At initialisation, we have neither head, and so gradient descent should never discover this circuit. Naively, we might predict that neural networks will only produce circuits analogous to linear regression, where each weight will clearly marginally improve performance as it continuously improves. And yet in practice, neural networks empirically do form sophisticated circuits like this, involving several parts interacting in non-trivial, algorithmic ways.[4]
I see a few different possible explanations for this:
A lottery ticket hypothesis-inspired explanation:[5] Initially, each layer of the network is the superposition of many different partial circuit components, and the output of each layer is the average of the output of each component. The full output of the network is the average of many different circuits. Some of these circuits are systematically useful to reducing loss, and most circuits aren’t. SGD will reinforce the relevant circuits and suppress the useless circuits, so the circuits will gradually form.
A random walk explanation: The network wanders randomly around the loss landscape, until it happens to get lucky and find a half-formed previous token head and induction head that somewhat compose. Once it has this, this half-formed circuit is useful for reducing loss and gradient descent can take over and make a complete circuit.
An evolutionary explanation: There’s a similar mystery for how organisms develop sophisticated machinery like the human eye, where each part is only useful in the context of other parts. The explanation I find most compelling is that we first developed one component that was somewhat useful on its own, eg a light-detecting membrane. This component was useful in its own right, and so was reinforced, and later components could develop that depend on it, eg the lenses in our eye.
The evolutionary explanation is a natural hypothesis, but we can see from my toy tasks that it cannot be the whole story. In the repeated subsequence task, we have a sequence of uniform randomly generated tokens, apart from a repeated subsequence at an arbitrary location, eg 7 2 8 3 1 9 3 8 3 1 9 9 2 5 END. This means that all pairs of tokens are independent, apart from the pairs of equal tokens in the repeated subsequence. In particular, this means that a previous token head can never reduce loss for the current token—the previous token will always be independent of the next token. So a previous token head is only useful in the context of an induction-like head that completes the circuit. Likewise, an induction head relies on K-composition with a previous token head, and so cannot be useful on its own. Yet the model eventually forms an induction circuit![6]
A priori, the random walk story seems unlikely to be sufficient on its own: an induction circuit is pretty complicated, and it likely represents a very small region in model space, and so seems unlikely to be stumbled upon by a random walk[7]. Thus my prediction is that the lottery ticket hypothesis is most of what’s going on[8] - an induction head will be useless without a previous token head, but may be slightly useful when composing with, say, a head that uniformly attends to prior tokens, since part of its output will include the previous token! I expect that all explanations are part of the picture though, eg this seems more plausible if the uniform head just so happens to attend a bit more to the previous token via a random walk, etc.
Drawing this back to phase changes, the lottery ticket-style explanation suggests that we might expect to see phase changes as circuits form. Early on in circuit formation, each part of the circuit is very rough, so the effect on the loss of improving any individual component is weak, which means the gradients will be small. But as each component develops, each other component will become more useful, which means that all gradients will increase together in a non-linear way. So as the circuit becomes closer to completion we should expect an acceleration in the loss curve for this circuit, resulting in a phase change.
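The feedback loop described above can be seen in a deliberately tiny toy model. This is my own illustrative sketch, not one of the experiments in the notebook: the "circuit" is just the product of two scalar components `a` and `b`, the loss is `(1 - a*b)^2`, and the initial values and learning rate are arbitrary assumptions.

```python
def train_toy_circuit(a: float = 1e-3, b: float = 1e-3,
                      lr: float = 0.05, steps: int = 200) -> list[float]:
    """Gradient descent on loss = (1 - a*b)^2, recording the loss at each step."""
    losses = []
    for _ in range(steps):
        err = 1.0 - a * b
        losses.append(err ** 2)
        # Each component's gradient is proportional to the *other* component,
        # so both gradients start tiny and grow together as the circuit forms.
        grad_a = -2.0 * err * b
        grad_b = -2.0 * err * a
        a, b = a - lr * grad_a, b - lr * grad_b
    return losses


losses = train_toy_circuit()
early, late = losses[30], losses[-1]
```

With these settings the loss sits on a plateau for the first few dozen steps and then falls quickly once `a` and `b` are large enough to reinforce each other, which is the reverse-S shape discussed earlier. Real circuits involve many more parameters, so this only illustrates the multiplicative-gradient intuition.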
An Intuitive Explanation of Grokking
With this explanation, we can now try to answer the question of why grokking occurs! To recap the problem setting, we are training our model on a problem with two possible solutions—the memorising algorithm and the generalising algorithm. We apply regularisation and choose a limited amount of data, such that the generalising solution is marginally simpler than the memorising solution[9] and so our training setup marginally prefers the generalising solution over the memorising solution. Naively, we expect the model to learn the generalising solution.
But we are training our model on a problem whose solution involves multiple components interacting to form a complete circuit. So, early in training, the gradients incentivising each component of the generalising solution are weak, because they need the parts to all be formed and lined up properly. Memorisation, however, does not require several components to be lined up in a careful and coordinated way[10], so it does not have artificially weak gradients at the start. Thus, at the start, memorisation is incentivised more strongly than generalisation and the model memorises.
So, why does the model shift from memorisation to generalisation? Eventually the training loss plateaus post-memorisation: loss is falling and total weights are rising, so eventually the gradients towards lower loss (ie to memorise better) balance with the gradients towards lower weights (ie to be simpler) and cancel out. But they don’t perfectly cancel out. If there’s a direction in model space that allows it to memorise more efficiently[11], then both gradients will encourage this direction. And the natural way to do this is by picking up on regularities in the data, eg you can memorise modular addition twice as efficiently by recognising that x+y=y+x. This is the same process that leads the model to generalise in the infinite data case: it wants to pick up on patterns in the data.[12]
So the model is still incentivised to reach the generalising solution, just as in the infinite data case. But rather than moving from the randomly initialised model to the generalising model (as in the infinite data case), it interpolates between the memorising solution and the generalising solution. Throughout this process test loss remains high—even a partial memorising solution still performs extremely badly on unseen data! But once the generalising solution gets good enough, the incentive to simplify by deleting the remnants of the memorising solution dominates, and the model clears up the memorising solution, finally resulting in good test performance. Further, as in the infinite data case, the closer we get to the generalising solution the more the rate of change of the loss accelerates. So this final shift happens extremely abruptly—manifesting as the grokking phase change!
As a final bit of evidence, once the model has fully transitioned to the generalising solution it is now inherently simpler, and the point where the incentive to improve loss balances with the incentive to be simpler is marginally lower—we can observe in the graphs here that the model experiences a notable drop in train loss post grokking.
My explanation above asserted that phase changes are a natural thing to expect with the formation of specific circuits in models. If we buy the hypothesis that most things models do are built up out of many interpretable circuits, then shouldn’t we expect to see phase changes everywhere whenever we train models, rather than smooth and convex curves?
My prediction is that yes, we should, and that in fact we do. But that larger models are made up of many circuits and, though each circuit may form in a phase change, the overall loss is made up out of the combination of many different capabilities (and thus many different circuits). And circuits of different complexity/importance likely form at different points in training. So the overall loss curve is actually the sum of many tiny phase changes at different points in training, and this overall looks smooth and convex. Regularities in loss curves like scaling laws may be downstream of statistical properties of the distribution of circuits, which become apparent at large scales. We directly observe the phase change-y ness from the loss curves in these simple toy problems because the problems are easy enough that only one/a few circuits are needed.
Some evidence for this hypothesis:
5 digit addition (toy problem): We can decompose the loss into the sum of 6[13] components, namely the loss on each of the 6 digits in the sum. When we do this, we observe separate phase changes in each digit, resulting in the many small non-convexities in the overall loss curve.
The ordering of phase changes is not stable between runs, though token 0 and token 1 tend to be first[14]
This isn’t specific to 5 digit, eg 15 digit addition shows 16 separate phase changes
Skip Trigrams (toy problem): The model learns 10 different skip trigrams, 10 .. END 10, 11 .. END 11, etc. Each skip trigram shows a separate phase change
Notably, each phase change happens at approximately the same time, so the overall curve looks less bumpy than 5 digit addition. This makes sense, because each skip trigram is “as complex” as the others, while learning to add some digits is much harder than others.
Induction Heads: Induction heads are the best studied example of a specific circuit through training, and there we see a clear phase change in LLMs up to a 13B transformer. Each head should be “as complex” as the others, so it makes sense that they all occur at approximately, but not exactly, the same time.[15]
AlphaZero Capabilities: One finding in DeepMind’s AlphaZero Interpretability paper was that there is a phase change in the model’s capabilities, where it learns to represent a lot of chess concepts around step 32,000.
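For the 5 digit addition evidence above, the per-digit decomposition of the loss can be sketched as follows. The array shapes, the function name `per_digit_loss`, and the random stand-in logits are assumptions; in practice the logits would come from the trained model at the 6 answer positions.

```python
import numpy as np


def per_digit_loss(logits: np.ndarray, targets: np.ndarray) -> np.ndarray:
    """logits: (batch, n_digits, vocab); targets: (batch, n_digits) integer labels.
    Returns the mean cross-entropy at each digit position, shape (n_digits,)."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    picked = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return -picked.mean(axis=0)


rng = np.random.default_rng(0)
logits = rng.normal(size=(32, 6, 10))        # stand-in for model outputs
targets = rng.integers(0, 10, size=(32, 6))  # stand-in for the 6 answer digits
digit_losses = per_digit_loss(logits, targets)
total_loss = digit_losses.sum()              # the overall loss as a sum of 6 components
```

Logging `digit_losses` at each checkpoint, rather than only `total_loss`, is what makes the separate phase change for each digit visible.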
Summary
I think this is highly suggestive evidence that there is a deep relationship between grokking and phase changes, and that grokking occurs when models with a phase change are trained with regularisation and limited data. I present some compelling (to me) explanations of what might be behind the phase change behaviour, and show how this model explains grokking and predicts several specific empirical observations about grokking. I don’t claim to fully understand phase changes or grokking, but I do claim to have substantially reduced my confusion about grokking to my confusion about phase changes.
Modular Addition
Epistemic status: I feel pretty confident that I have fully reverse engineered this network, and have enough different lines of evidence that I am confident in how it works. My explanation of how and why it develops during training is shakier.
This algorithm operates by using trig identities and Discrete Fourier Transforms to map $x, y \to \cos(w(x+y)), \sin(w(x+y))$, and then extracting $x + y \pmod{p}$
This algorithm can be clearly read off from the weights. If we apply a Discrete Fourier Transform to the input space and apply the Transformer Circuits framework, the structure of the network and its resulting algorithm becomes clear.
The model naturally forms several sub-networks that calculate $\cos(w(x+y-z))$ in different frequencies $w$ and add these to form the logits. This can be seen by a clear clustering of the neurons
Within a cluster, individual neurons clearly represent interpretable features for a single frequency.
To emphasise, this algorithm was discovered purely via gradient descent, not by me. I didn’t think of this algorithm until I reverse engineered it from the weights!
The evolution of this algorithm can be clearly seen during training, and systematic progress towards the generalising circuit can be seen well before the grokking point
With excluded loss, we can see the model interpolate between memorisation and generalisation. Train loss performance depends substantially on the generalising circuit well before we see a significant change in test loss.
Model Details
In this section I dive deeply into one specific and well-checkpointed model trained to do modular addition. See model training code for more details, but here are the key points:
The model is trained to map $x, y$ to $z \equiv x + y \pmod{113}$ (henceforth 113 is referred to as $p$)
1L Transformer
Learned positional embeddings
Width 128
No LayerNorm
ReLU activations
Input format is x|y|= , where x,y are one-hot encoded inputs, and = is an extra token.
Trained with AdamW, with weight decay 1 and learning rate $10^{-3}$
Full batch training, trained on 30% of the data (ie the $113^2$ pairs of inputs) for 40,000 epochs[16]
This is a 1L transformer, with no LayerNorm and learned positional embeddings, trained with AdamW with weight decay 1, and full batch training on 30% of the data (the data is the $113^2$ pairs of numbers mod $p$).
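As a minimal sketch of the data side of this setup, the full universe of $113^2$ pairs and a 30% training split can be built as below. The seed, the token id used for `=`, and the variable names are assumptions, and the model and optimiser (AdamW, weight decay 1, learning rate $10^{-3}$) are not shown; see the notebook's training code for the real thing.

```python
import numpy as np

P = 113               # modulus, from the text
TRAIN_FRACTION = 0.3  # from the text
SEED = 0              # assumption: the seed is not stated in this write-up

# All P^2 input pairs (x, y) and their labels z = (x + y) mod P.
xs, ys = np.meshgrid(np.arange(P), np.arange(P), indexing="ij")
pairs = np.stack([xs.ravel(), ys.ravel()], axis=1)   # shape (P^2, 2)
labels = (pairs[:, 0] + pairs[:, 1]) % P             # shape (P^2,)

# Token sequence x | y | =, using token id P for the extra "=" token (assumption).
tokens = np.concatenate([pairs, np.full((P * P, 1), P)], axis=1)

# Random split into training pairs and held-out test pairs.
perm = np.random.default_rng(SEED).permutation(P * P)
n_train = int(TRAIN_FRACTION * P * P)
train_idx, test_idx = perm[:n_train], perm[n_train:]
train_tokens, train_labels = tokens[train_idx], labels[train_idx]
test_tokens, test_labels = tokens[test_idx], labels[test_idx]
```

Because training is full batch, `train_tokens` is passed through the model in one go at every epoch.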
Overview of the Inferred Algorithm
The key feature of the algorithm is calculating $\cos(w(x+y)), \sin(w(x+y))$ with $w = \frac{2\pi k}{p}$ - this is a function of $x+y$ and can be mapped to $x+y$, and because $\cos(wx)$ has period $\frac{p}{k}$ we get the $\pmod{p}$ part for free.
More concretely:
Inputs $x, y$ are given as one-hot encoded vectors in $\mathbb{R}^p$
Calculates $\cos(wx), \cos(wy), \sin(wx), \sin(wy)$ via a Discrete Fourier Transform (This sounds complex but is just a change of basis on the inputs, and so is just a linear map)
$w = \frac{2\pi k}{p}$, $k$ is arbitrary, we just need period dividing $p$
Calculates $\cos(wx)\cos(wy), \cos(wx)\sin(wy), \sin(wx)\cos(wy), \sin(wx)\sin(wy)$ by multiplying pairs of waves in $x$ and in $y$
Calculates $\cos(w(x+y)) = \cos(wx)\cos(wy) - \sin(wx)\sin(wy)$ and $\sin(w(x+y)) = \sin(wx)\cos(wy) + \cos(wx)\sin(wy)$ by rearranging and taking differences
Calculates $\cos(w(x+y-z)) = \cos(w(x+y))\cos(wz) + \sin(w(x+y))\sin(wz)$ via a linear map to the output logits $z$
This has an argmax at $z \equiv x + y \pmod{p}$, so post softmax we’re done!
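The steps above can be checked directly in NumPy. This is a hand-written sketch of the algorithm as described, not the trained network: the particular frequencies in `ks` are an arbitrary assumption (the model chooses its own), and the function name is mine.

```python
import numpy as np

p = 113
ks = [1, 5, 42]  # assumption: an arbitrary handful of frequencies k


def fourier_mod_add_logits(x: int, y: int) -> np.ndarray:
    """Logits over z = 0..p-1, summing cos(w(x + y - z)) over the chosen frequencies."""
    z = np.arange(p)
    logits = np.zeros(p)
    for k in ks:
        w = 2 * np.pi * k / p
        # Step 1: waves in x and y (a change of basis on the one-hot inputs).
        cx, sx, cy, sy = np.cos(w * x), np.sin(w * x), np.cos(w * y), np.sin(w * y)
        # Steps 2-3: multiply pairs of waves, then combine with the trig identities.
        cos_sum = cx * cy - sx * sy   # cos(w(x + y))
        sin_sum = sx * cy + cx * sy   # sin(w(x + y))
        # Step 4: linear map to the logits, giving cos(w(x + y - z)).
        logits += cos_sum * np.cos(w * z) + sin_sum * np.sin(w * z)
    return logits


all_correct = all(
    int(np.argmax(fourier_mod_add_logits(x, y))) == (x + y) % p
    for x in range(p) for y in range(p)
)
```

`all_correct` is `True`: because $p$ is prime, every term reaches its maximum of 1 only when $z \equiv x + y \pmod{p}$. A single frequency already gives the right argmax; summing several frequencies widens the gap between the correct logit and the rest.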
There are a few adjustments to implement this algorithm in a neural network:
The model’s activations at any point are vectors(/tensors). To represent several variables, such as cos(wx),sin(wx), these are stored as different directions in activation space. When the vector is projected onto those dimensions, the coefficient is the relevant variable (eg cos(wx))
The model runs the algorithm in parallel for several different frequencies[17] (different frequencies correspond to different clusters of neurons, different subspaces of the residual stream, and sometimes different attention heads)
Background on Discrete Fourier Transforms
A key technique in all that follows is Discrete Fourier Transforms (DFT). I give a more in-depth explainer in the colab, but here’s a rough outline—I expect this requires familiarity with linear algebra to really get your head around. The key motivating observation is that most activations inside the network are periodic and so techniques designed to represent periodic functions nicely are key. Eg the attention patterns:
In $\mathbb{R}^p$, we have a standard basis of the $p$ unit vectors. But we can also take a basis of $p$ cosine and sine waves, $F \in \mathbb{R}^{p \times p}$, where $F_0 = (1, 1, \dots, 1)$ is the constant vector, and $F_{2k-1} = \cos(\frac{2\pi k}{p} x)$ and $F_{2k} = \sin(\frac{2\pi k}{p} x)$ are the cosine and sine wave of frequency $w = \frac{2\pi k}{p}$ (henceforth referred to as frequency $w = k$ and represented as $\cos kx$ for brevity) for $k = 1, \dots, \frac{p-1}{2}$. Every pair of waves has dot product zero, unless they’re the same wave (ie it’s an orthogonal basis). If we normalise these rows, we get an orthonormal basis of cosine and sine waves (so $F^{-1} = F^T$). We refer to these normalised waves as Fourier Components and this overall basis as the 1D Fourier Basis.
We can apply a change of basis to the 1D input space $\mathbb{R}^p$ to $F$, and this turns out to be a much more natural way to represent the input space for this problem, as the network learns to operate in terms of sine and cosine waves. Eg, the fourth column of $W_E F^T$ is the direction corresponding to $\sin 2x$ in the embedding for $W_E$. If we apply this change of basis to both the input space for $x$ and for $y$ we apply a 2D DFT, and can represent any function as the linear combination of terms of the form $\sin(w_1 x)\cos(w_2 y)$ (or $\cos(w_1 x)\cos(w_2 y)$, $\mathrm{Const} \cdot \cos(w_2 y)$, etc). This is just a change of basis on $\mathbb{R}^{p \times p} = \mathbb{R}^{p^2}$, and terms of the form $\sin(w_1 x)\cos(w_2 y)$ (ie the outer product of each pair of rows in $F$) form an orthogonal basis of $p^2$ vectors (henceforth referred to as the 2D Fourier Basis).
Importantly, this is just a change of basis! If we choose any single activation in the network, this is a real valued function on pairs of inputs $x, y$, ie an element of $\mathbb{R}^{p \times p}$, and so is equivalent to specifying a $p^2$ dimensional vector. And we can apply an arbitrary change of basis to this vector. So we can always write it as a linear combination of terms in the 2D Fourier Basis. And any vector of activations is a linear combination of 2D Fourier terms times fixed vectors in activation space. If a function is periodic, this means that it is sparse in the 1D or 2D Fourier Basis, and this is what tells us about the structure of the algorithm and weights of the network.
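Here is a small sketch of the 1D and 2D Fourier bases described above, checking that $F$ is orthonormal and that a periodic function of $x+y$ is sparse in the 2D basis. The helper name `fourier_basis` and the example function are my own choices for illustration.

```python
import numpy as np


def fourier_basis(p: int) -> np.ndarray:
    """Rows: constant, cos(1x), sin(1x), cos(2x), sin(2x), ..., each normalised."""
    x = np.arange(p)
    rows = [np.ones(p)]
    for k in range(1, (p - 1) // 2 + 1):
        rows.append(np.cos(2 * np.pi * k * x / p))
        rows.append(np.sin(2 * np.pi * k * x / p))
    F = np.stack(rows)
    return F / np.linalg.norm(F, axis=1, keepdims=True)


p = 113
F = fourier_basis(p)
is_orthonormal = np.allclose(F @ F.T, np.eye(p))  # so the inverse of F is F^T

# A periodic function of (x, y): f(x, y) = cos(w(x + y)) with frequency k = 42.
k = 42
x, y = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")
f = np.cos(2 * np.pi * k * (x + y) / p)

# 2D DFT as a change of basis: coeffs[i, j] multiplies outer(F[i], F[j]).
coeffs = F @ f @ F.T
n_nonzero = int((np.abs(coeffs) > 1e-6).sum())
```

Of the $p^2 = 12769$ coefficients only two are non-zero, the $\cos 42x \cos 42y$ and $\sin 42x \sin 42y$ terms, which is the sparsity that makes the Fourier basis useful for reading off the algorithm.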
Reverse Engineering the Algorithm
Here, I present a case for how I was able to reverse engineer the algorithm from the weights. See the Colab and appendices (attention and neuron) for full details; my goal in this section is to roughly sketch what’s going on and why I’m confident that this is what’s going on.
Theory: Naively, this seems like the hard part, but is actually extremely easy. The key is that we just need to learn the discretised wave on $x \in [0, 1, \dots, p-1]$, not for arbitrary $x \in \mathbb{R}$. $x$ is input into the network as a one-hot encoded vector, and then multiplied by a learned matrix $W_E$. We can in fact learn any function $f: [0, 1, \dots, p-1] \to \mathbb{R}$[18]
Conveniently, $F$, the matrix of waves $\cos(wx), \sin(wx)$, is an orthonormal basis. So $W_E F^T$ will recover the direction of the embedding corresponding to each wave $\mathrm{Const}, \cos x, \sin x, \cos 2x, \dots$ - in other words, extracting $\cos(wx), \sin(wx)$ is just a rotation of the input space.
Evidence: We can use the norm of the embedding of each wave to get an indicator of how much the network “cares” about each wave[19], and when we do this we see that the plot is extremely sparse. The model has decided to throw away all but a few frequencies[20]. This is very strong evidence that the model is working in the Fourier Basis—we would expect to see a basically uniform plot if this was not a privileged basis.
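The norm-per-wave measurement can be sketched as below. I do not have the trained weights here, so `W_E` is a synthetic stand-in built from a few hypothetical frequencies plus small noise; the frequencies, noise scale, and the $d_{model} \times p$ orientation of `W_E` are all assumptions for illustration.

```python
import numpy as np


def fourier_basis(p: int) -> np.ndarray:
    """Normalised rows: constant, cos(1x), sin(1x), cos(2x), sin(2x), ..."""
    x = np.arange(p)
    rows = [np.ones(p)]
    for k in range(1, (p - 1) // 2 + 1):
        rows += [np.cos(2 * np.pi * k * x / p), np.sin(2 * np.pi * k * x / p)]
    F = np.stack(rows)
    return F / np.linalg.norm(F, axis=1, keepdims=True)


p, d_model = 113, 128
rng = np.random.default_rng(0)
F = fourier_basis(p)

# Synthetic embedding: random directions for the cos/sin waves of a few
# hypothetical frequencies, plus small dense noise (NOT the trained weights).
freqs = [3, 17, 42]
wave_rows = [i for k in freqs for i in (2 * k - 1, 2 * k)]
directions = np.zeros((d_model, p))
directions[:, wave_rows] = rng.normal(size=(d_model, len(wave_rows)))
W_E = directions @ F + 0.01 * rng.normal(size=(d_model, p))

# Norm of the embedding direction for each Fourier component.
wave_norms = np.linalg.norm(W_E @ F.T, axis=0)
top = sorted(np.argsort(wave_norms)[-len(wave_rows):].tolist())
```

On this synthetic matrix, `top` recovers exactly the planted cos/sin rows and every other norm is close to zero. For a matrix with no privileged Fourier structure, `wave_norms` would instead be roughly uniform.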
Calculating 2D products of waves cos(wx)cos(wy),cos(wx)sin(wy),sin(wx)cos(wy),sin(wx)sin(wy)
Theory: A good mental model for neural networks is that they are really good at matrix multiplication and addition, and anything else takes a lot of effort[21]. That holds here! As we saw above, creating cos(wx),sin(wx) is just a rotation, and the later rearranging and map to the logits is another linear map, but multiplying the terms together is hard and non-linear.
There are three non-linear operations in a 1L transformer—the attention softmax, the element-wise product of attention and the value vectors, and the ReLU activations in the MLP layer. Here, the model uses both ReLU activations and element-wise products with attention to multiply terms[22].
The neurons form 5[23] distinct clusters for each frequency, and each neuron in the cluster for frequency $w$ has its activation as a linear combination of $1, \cos(wx), \sin(wx), \cos(wy), \sin(wy), \cos(wx)\cos(wy), \cos(wx)\sin(wy), \sin(wx)\cos(wy), \sin(wx)\sin(wy)$.[24] Note that, as explained above, the neuron activation in any network can be represented as a linear combination of products of Fourier terms in $x$ and Fourier terms in $y$ (because they form a basis of $\mathbb{R}^{p \times p}$). The surprising fact is that this representation is sparse! This can be visually seen as neuron activations being periodic:
Evidence: The details of how the terms are multiplied together are highly convoluted[25], and I leave them to the Colab notebook appendices. But the neurons do in fact have the structure I described, and this can be directly observed by looking at their values. And thus, by this point in the network it has computed the product terms.
For example, the activations for neuron 0 (as plotted above) are approximately 109−39(cos42x+cos42y)−76(sin42x+sin42y)+36(cos42xsin42y+sin42xcos42y)−10cos42xcos42y+38sin42xsin42y (these coefficients can be calculated by mapping the neuron activation into the 2D Fourier Basis). This approximation explains >90% of the variance in this neuron[26]. We can plot this visually with the following heatmap:
Zooming out, we can apply a 2D DFT to all neuron activations, ie writing all of the neuron activations as a linear combination of terms of the form cos42xcos42y times vectors, and plotting the norm of each vector of coefficients. Heuristically, this is telling us what terms are represented by the network at the output of the neurons. We see that the non-trivial terms are in the top row (of the form coswx,sinwx) or the left column (of the form coswy,sinwy) or in a 2x2 block of cells along the diagonal (of the form cos(wx)cos(wy),cos(wx)sin(wy),sin(wx)cos(wy),sin(wx)sin(wy) - notably, a product term where both terms have the same frequency).
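To make the 2D Fourier decomposition concrete, here is a minimal NumPy sketch. It is not the notebook's code: the neuron is synthetic (a few terms at frequency 42 plus noise, loosely modelled on the neuron 0 example above), and I am assuming that "frequency k" means the wave cos(2πkx/p) with p = 113. The names `fourier_basis` and `frac_explained` are made up for illustration.

```python
import numpy as np

def fourier_basis(p):
    """Orthonormal 1D Fourier basis over Z_p: constant, then cos/sin pairs."""
    x = np.arange(p)
    rows = [np.ones(p) / np.sqrt(p)]
    for k in range(1, p // 2 + 1):
        rows.append(np.cos(2 * np.pi * k * x / p) * np.sqrt(2 / p))
        rows.append(np.sin(2 * np.pi * k * x / p) * np.sqrt(2 / p))
    return np.stack(rows)  # shape (p, p) for odd p

p, k = 113, 42                      # assumed modulus and frequency index
w = 2 * np.pi * k / p
F = fourier_basis(p)

# Synthetic neuron activation over all (x, y): hypothetical coefficients plus noise.
x = np.arange(p)[:, None]
y = np.arange(p)[None, :]
rng = np.random.default_rng(0)
act = (109 - 39 * (np.cos(w * x) + np.cos(w * y))
       + 38 * np.sin(w * x) * np.sin(w * y)
       + rng.normal(scale=1.0, size=(p, p)))

# 2D DFT: coeffs[i, j] is the weight on (basis term i in x) * (basis term j in y).
coeffs = F @ act @ F.T

# Keep only the 9 terms built from {1, cos(wx), sin(wx)} x {1, cos(wy), sin(wy)}.
keep = [0, 2 * k - 1, 2 * k]
sparse = np.zeros_like(coeffs)
sparse[np.ix_(keep, keep)] = coeffs[np.ix_(keep, keep)]
approx = F.T @ sparse @ F

frac_explained = 1 - ((act - approx) ** 2).sum() / ((act - act.mean()) ** 2).sum()
print(f"fraction of variance explained by frequency {k}: {frac_explained:.4f}")
```

For a real neuron you would replace `act` with its p×p table of activations; the fraction printed here only says something about the synthetic example.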
Calculating cos(w(x+y)),sin(w(x+y)) and calculating logits
Theory: The operations mapping cos(w(x+y))=cos(wx)cos(wy)−sin(wx)sin(wy) and sin(w(x+y))=cos(wx)sin(wy)+sin(wx)cos(wy) are linear, and the operations mapping this to cos(w(x+y−z))=cos(w(x+y))cos(wz)+sin(w(x+y))sin(wz) are also linear. So their composition is linear, and can be represented by a single matrix multiplication. The neurons are mapped to the logits by L=WUWoutN, and so the effective weight matrix Wlogit=WUWout must represent both of these operations (if my hypothesis is correct). Note that Wlogit is a p×dmlp matrix, mapping from MLP-space to the output space.
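As a quick sanity check of the linearity claim, the sketch below builds the two linear maps explicitly for a single frequency and composes them into one matrix. The matrix names (`A`, `B`, `W_toy`) and the convention w = 2πk/p are my assumptions for illustration; the real Wlogit acts on neurons rather than directly on these four product terms.

```python
import numpy as np

p, k = 113, 42                     # assumed modulus and frequency index
w = 2 * np.pi * k / p
x, y = 17, 100                     # an arbitrary input pair

# The four product terms the MLP is hypothesised to have computed.
prod = np.array([
    np.cos(w * x) * np.cos(w * y),
    np.cos(w * x) * np.sin(w * y),
    np.sin(w * x) * np.cos(w * y),
    np.sin(w * x) * np.sin(w * y),
])

# First linear map: product terms -> [cos(w(x+y)), sin(w(x+y))].
A = np.array([[1.0, 0.0, 0.0, -1.0],
              [0.0, 1.0, 1.0, 0.0]])

# Second linear map: -> cos(w(x+y-z)) for every output z.
z = np.arange(p)
B = np.stack([np.cos(w * z), np.sin(w * z)], axis=1)   # shape (p, 2)

# Their composition is a single (p, 4) matrix.
W_toy = B @ A
logits = W_toy @ prod

print(np.allclose(logits, np.cos(w * (x + y - z))))
```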
Evidence: We draw upon several different lines of evidence here.
We show that the terms cos(w(x+y)),sin(w(x+y)) are computed as follows: repeating the above analysis to find the terms represented by the neurons on the logits, we find that the terms in the top row and left column cancel out. This leaves just diagonal terms, corresponding to products of waves of the same frequency in x and y, exactly the terms we need. We also see that the 2x2 blocks are uniform, showing that cos(w(x+y)) and sin(w(x+y)) have the same coefficient. Further analysis shows that everything other than cos(w(x+y)),sin(w(x+y)) for these 5 frequencies is essentially zero.
We now show that the output logits produce cos(w(x+y−z))=cos(w(x+y))cos(wz)+sin(w(x+y))sin(wz) for each of the 5 represented frequencies (where z is the term corresponding to the output logits). The neurons form clusters for each frequency, and when we plot the columns of Wlogit corresponding to those frequencies, and apply a 1D DFT to the output space of Wlogit, we see that the only non-trivial terms are cos(wz),sin(wz) - ie the output logits coming from these neuron clusters are a linear combination of cos(wz),sin(wz).
We can more directly verify this by approximating the output logits as a sum ∑w∈[14,35,41,42,52]Awcos(w(x+y−z)) and fitting the coefficients Aw. When we do this, the resulting approximated logits explain 95% of the variance in the original logits. If we evaluate loss on this approximation to the logits, we actually see a significant drop in loss, from 2×10^−7 to 4.7×10^−8.
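To see why a sum of cos(w(x+y−z)) terms is enough to do modular addition, here is a small sketch that builds idealised logits directly from the five frequencies, with all coefficients Aw set to 1 (an assumption; the fitted values are in the notebook). I again assume that frequency k corresponds to w = 2πk/p with p = 113. It checks that the argmax over z is x+y (mod p), and compares the gap between the best and second-best logit against using a single frequency.

```python
import numpy as np

p = 113
freqs = [14, 35, 41, 42, 52]       # frequency indices from the fit above

x = np.arange(p)[:, None, None]
y = np.arange(p)[None, :, None]
z = np.arange(p)[None, None, :]

def ideal_logits(ks):
    """Sum of cos(w(x+y-z)) over the given frequencies, with unit coefficients."""
    return sum(np.cos(2 * np.pi * k * (x + y - z) / p) for k in ks)

def margin(logits):
    """Smallest gap between the top logit and the runner-up, over all (x, y)."""
    top2 = np.sort(logits, axis=-1)[..., -2:]
    return (top2[..., 1] - top2[..., 0]).min()

multi = ideal_logits(freqs)
single = ideal_logits(freqs[:1])

pred = multi.argmax(axis=-1)
target = (np.arange(p)[:, None] + np.arange(p)[None, :]) % p

print("all predictions correct:", (pred == target).all())
print("margin with one frequency:", margin(single))
print("margin with five frequencies:", margin(multi))
```

The larger margin with several frequencies is the constructive interference at z = x+y (mod p) described in the footnotes.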
Evolution of Circuits During Training
Note: For this section in particular, I recommend referring to the Colab! That contains a bunch of interactive graphics that I can’t include here, where we can observe the development of circuits during training.
Now that we understand what the model is doing during training on a mechanistic level, we can directly observe the development of these circuits during training. The key observation is that the circuits develop smoothly, and make clear and systematic progress towards the generalising circuit well before there is notable test performance. For example, the graph of the norms of the embedding of different waves becomes sparse. In particular, this disproves the natural hypothesis that grokking is the result of a random walk on the manifold of models that perform optimally.
Sum of squared weights—the main regularisation is weight decay, so paying attention to the change in sum of squared weights is key to understand the effect of the regulariser
This uses the transformer circuits notation - Win,Wout are MLP weights, WU is the unembed, WE is the embed, WQ,WK,WO,WV are the attention weights
Excluded loss: Intuitively, this metric tracks how much the model’s performance on the training data comes from the generalising algorithm vs the memorisation algorithm, by ablating the model’s ability to use a specific frequency w. This serves as a progress measure, quantitatively tracking the model’s progress towards the correct algorithm.
Formally, we calculate the logits, use a 2D DFT to write them as a linear combination of waves in x and waves in y, and set the terms corresponding to cos(w(x+y)),sin(w(x+y)) to zero and keep the other terms the same. Then measure the train loss.
If the model is using the addition identities (ie, with the generalising algorithm), this will heavily damage performance. If it isn’t (ie, the memorising algorithm), then this is just deleting 2 directions among 113^2, and so should have no effect.
Trig loss: Another metric is to take the output logits, and only allow them to use the directions corresponding to cos(w(x+y)),sin(w(x+y)) for the 5 relevant frequencies. Intuitively, this represents the model’s performance with the generalising algorithm preserved but the memorising algorithm removed.
Here we evaluate performance on both the train and test set—trig loss for the train and test set is almost equal (the lines are identical), showing that it’s purely tracking the generalising algorithm
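The two logit-space metrics above can be sketched as projections. This is a simplified stand-in rather than the notebook's implementation: it treats the logits as a (p·p, p) array over every input pair, builds orthonormal directions for cos(w(x+y)),sin(w(x+y)), and either keeps only those directions (trig) or removes them for one frequency (excluded). The function names, the synthetic `clean` logits and the convention w = 2πk/p are all assumptions, and the real metrics are then evaluated on the train or test subset of inputs.

```python
import numpy as np

p = 113
key_freqs = [14, 35, 41, 42, 52]   # assumed frequency indices, w = 2*pi*k/p
s = (np.arange(p)[:, None] + np.arange(p)[None, :]).reshape(-1)  # x + y per input
labels = s % p

def sum_directions(ks):
    """Orthonormal columns spanning cos(w(x+y)), sin(w(x+y)) over all inputs."""
    cols = []
    for k in ks:
        w = 2 * np.pi * k / p
        cols += [np.cos(w * s), np.sin(w * s)]
    Q = np.stack(cols, axis=1)
    return Q / np.linalg.norm(Q, axis=0)

def trig_logits(logits, ks=key_freqs):
    """Keep only the components along the cos/sin(w(x+y)) directions."""
    Q = sum_directions(ks)
    return Q @ (Q.T @ logits)

def excluded_logits(logits, k):
    """Remove the cos/sin(w(x+y)) components for a single frequency k."""
    return logits - trig_logits(logits, [k])

def cross_entropy(logits, labels):
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

# Synthetic "purely generalising" logits: sum of cos(w(x+y-z)) over key frequencies.
z = np.arange(p)
clean = sum(np.cos(2 * np.pi * k * (s[:, None] - z[None, :]) / p) for k in key_freqs)

print("loss, full logits:        ", cross_entropy(clean, labels))
print("loss, frequency 42 removed:", cross_entropy(excluded_logits(clean, 42), labels))
```

On these synthetic logits the trig projection leaves everything unchanged and excluding a frequency raises the loss; a memorising model would show the opposite pattern.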
We can now use these metrics to analyse the separate stages of training:
Memorisation (Epoch 0-1K): The model memorises. We see a dramatic drop in train loss, and a rise in test loss.
We see that the incentive to memorise is far stronger than the incentive to have low weights, as total weights rise a lot
We see that it’s memorising rather than generalising as excluded loss is very close to train loss
Interpolation (Epoch 1K-9K): The model is interpolating between memorisation and generalisation.
We see this in the smooth divergence of excluded loss from train loss—this shows that the model is deriving more and more of its training performance from the generalising algorithm.
Intuitively, the model is trying to memorise “more efficiently” by picking up on patterns in the data. Some directions in neuron space are functions of x+y(modp), namely cos(w(x+y)),sin(w(x+y)) terms, so it learns to use these directions to represent more and more information.
Train loss mostly plateaus in this period—this shows that although the model is changing its underlying algorithm, it does this by interpolating between the purely memorising circuit and the purely generalising circuit
It surprises me that the model can interpolate between two circuits while retaining good performance!
This stage begins when the incentive for low train loss and the incentive for low weights balance, and train loss abruptly levels out.
This is particularly clear from epoch 1K-3K, when the weights for the MLP layer Win,Wout drop in norm and meet the other weight matrices. Given a matrix product AB, we can always scale up A→kA and scale down B→(1/k)B and keep the product AB the same. The sum of squared weights is minimised when |A|=|B|, so all weight matrices between the embed and unembed should have equal norm: this minimises total weights and keeps loss the same. The fact that this does not happen in the first stage shows that the incentive for small weights is too weak and has little effect.
Further, this is suggestive evidence that most of the memorisation comes from changes in the MLP—this matches the hypothesis that memorisation involves “simpler” circuits that don’t involve complex composition, and so doesn’t need a phase change.
Generalisation (Epoch 8K-12K): We see a phase change in trig loss. Trig loss represents “model performance if we delete the memorisation circuit”, so this shows that the generalising circuit rapidly develops in a phase change.
This matches my (admittedly post-hoc) hypothesis that all circuits see phase changes—as each component of the circuit forms, it reinforces the gradients on all other components.
We further see a rapid acceleration in the decrease in total weights, supporting the idea that there is some cascade as the circuit forms.
Cleaning Up Noise (Epoch 9K-13K): Somewhat in parallel with the above phase change in trig loss, there is a slightly delayed phase change in test loss. The slight delay indicates that good test loss performance requires both a strong generalising circuit, and the removal of the memorising circuit.
The rapid removal of the memorisation circuit presumably comes from weight decay—once the memorisation circuit’s contribution to loss is low enough, the total weight benefits of removing it outweigh the contribution to loss.
Stability (Epoch 13K-end)[27]: The model has now transitioned to using the generalising circuit and all our metrics settle down
The model still has some remnants of the memorisation circuit, as train loss is still less than test loss. This must come from directions other than cos(w(x+y)),sin(w(x+y)) in logit space, as trig loss is identical on the train and test data.
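The rescaling argument from the Interpolation stage (A→kA, B→(1/k)B leaves AB fixed, and the sum of squared weights is smallest when the norms match) is easy to check numerically. The sketch below uses random matrices with deliberately mismatched norms; the shapes and scales are arbitrary choices of mine. Setting the derivative of k²|A|² + |B|²/k² to zero gives k = sqrt(|B|/|A|).

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(8, 16)) * 5.0     # deliberately large norm
B = rng.normal(size=(16, 4)) * 0.2     # deliberately small norm

def total_sq(k):
    """Sum of squared weights after rescaling A -> kA, B -> B/k."""
    return np.sum((k * A) ** 2) + np.sum((B / k) ** 2)

# Analytic minimiser of k^2 |A|^2 + |B|^2 / k^2.
k_star = np.sqrt(np.linalg.norm(B) / np.linalg.norm(A))

# Brute-force check over a grid of k values.
ks = np.linspace(0.05, 2.0, 2000)
k_grid = ks[np.argmin([total_sq(k) for k in ks])]

print("analytic k*:", k_star, " grid minimiser:", k_grid)
print("norms at k*:", np.linalg.norm(k_star * A), np.linalg.norm(B / k_star))
print("sum of squares before/after:", total_sq(1.0), total_sq(k_star))
```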
Discussion
Alignment Relevance
Ultimately, my goal is to do research that is relevant to aligning future powerful AI systems. Here I’ve studied tiny models
(Note—the true reason I did this project was that I was on holiday, and got extremely nerd-sniped, so this is all post-hoc reasoning!)
Emergent capabilities:
Is grokking relevant to large models? Probably not—models like GPT-3 and PaLM are trained on functionally infinite data, and rarely see a data point twice. In fact, there’s a strong incentive to avoid this, as repeated data can severely damage performance.
The main way I think this is directly relevant to large models is the indication that phase changes are everywhere. If most circuits emerge as phase changes during training, this suggests we should expect emergent capabilities to be a frequent occurrence. In particular, this may occur for particularly dangerous capabilities, such as deception or situational awareness.
This would make alignment a harder problem, as it is less likely that we’ll see warning shots. It also limits our ability to, eg, continuously test a system for misalignment or deception throughout training. This strategy might work if the system will first try to be deceptive badly before it is superhuman at lying, but will fail if the system rapidly jumps from not even trying to be deceptive to being superhuman.
However, this is highly speculative—even if this hypothesis holds, these capabilities may be implemented by many different circuits, and so still see gradual development
A laboratory for training dynamics + generalisation: There are many open questions about exactly why neural networks work, why they learn to generalise to unseen data, what kinds of generalising algorithms they learn, and how this can be influenced. These are extremely hard questions to study, and I don’t feel that I’ve understood them as a result of this work. But I think the modular addition transformer may be a good laboratory to examine these questions, as it’s a concrete, small model that can be easily trained and whose circuits can be clearly inspected. Further, answering these questions is key for alignment: aligning systems will likely require predicting when they would learn to be deceptive vs aligned and how to influence them.
A particularly important direction may be directly training systems on interpretability-based metrics. Naively, if we could create an automated tool to detect deception and put this in the loss function, our system may never learn to be deceptive. But by default, I would expect this kind of strategy to fail badly, because gradient descent will just find a way to Goodhart the tool. I expect a good (though still highly limited) in-road into this question would be to train the modular addition system on interpretability-based metrics, and see the effects on the training dynamics and the resulting circuits.
Importance of interpretability: A high-level motivation behind this work is a belief that mechanistic interpretability is key to understanding how neural networks work. They seem to learn interpretable circuits, and these determine both the model’s capabilities, and how those capabilities are implemented. It seems very difficult to get traction on what’s really going on in models without engaging with these circuits, analogous to doing chemistry without discovering the periodic table. I take this work as a proof of concept of that: I took the weird and confusing phenomenon of grokking, and feel that I have significantly de-mystified it by fully reverse engineering the underlying circuits.
This is particularly notable with grokking, as here we studied the competition between two circuits with comparable loss. The standard way to predict what a model will learn is by expecting it to have low loss, but this breaks down if multiple mechanistically different solutions have comparable performance. This is extremely relevant for alignment as, for example, a deceptive system will be incentivised to output exactly what we’d expect from an aligned solution, and so get the same loss.
This work also reinforces my belief that the circuits agenda is possible at all: neural networks seem to want to be understood, if we can just find the right way to reverse engineer them! In particular, I discovered an unexpected and new[28] algorithm for doing modular addition via reverse engineering.
Limitations
There are several limitations to this work and how confidently I can generalise from it to a general understanding of deep learning/grokking, especially the larger models that will be alignment relevant. In particular:
The modular addition transformer is a toy model, which only needs a few circuits to solve the problem fully (unlike eg an LLM or an image model, which needs many circuits and may lead to different behaviour)
The model is over-parametrised, as is shown by eg it learning many redundant neurons doing roughly the same task. It has many more parameters than it needs for near-perfect performance on its task, unlike real models.
In particular, models like language models could always learn more circuits to achieve marginally better task performance, and so need to use their parameters efficiently. This may act as a form of implicit regularisation.
I only study weight decay as the form of regularisation.
My explanations of what’s going on rely heavily on fairly fuzzy and intuitive notions that I have not explicitly defined: crisp circuits, simplicity, phase changes, what it means to memorise vs generalise, model capacity, etc.
Of my four tasks exhibiting phase changes, I have only fully interpreted modular addition.[29]
Future Directions
I see a lot of future directions building on this work. If you’d be interested in working on any of these, please reach out! I’d love to chat and possibly advise or collaborate on a project. I’m also currently looking to hire an intern/research assistant to help me work on mechanistic interpretability projects (including but not limited to these ideas), and if you feel excited about any of these directions you may be a good fit.
Investigating the ‘phase changes are everywhere’ hypothesis.
Focusing on larger models on complex tasks
Training a larger model that we know how to interpret, making many checkpoints and inspecting circuits during training
Looking for emergent capabilities in open source well-checkpointed training runs (Eg Mistral and GPT-J)
Performance on benchmarks or specific questions from a benchmark
Simple algorithmic tasks—eg few shot learning, addition, sorting words into alphabetical order, completing a rhyming couplet, matching open and close brackets
Look at attention heads on various texts and see if any have recognisable attention patterns (eg start of word, adjective describing current word, syntactic features of code like indents or variable definitions, most recent open bracket, etc).
More elaborate toy problems that involve several different tasks
Eg, training a model with separate heads for modular addition, subtraction, multiplication, etc. Do we see separate phase changes?
Training models for the other tasks in the grokking paper with more data, and seeing if we see phase changes but no grokking.
Investigating the ‘phase changes are inherent to composition’ hypothesis
If we directly look at the gradients on different components of the circuit, do we see this mutual reinforcement and cascading?
Can we show this analytically in some toy setting? What are the minimal properties we need to see phase changes?
I predict the most tractable will be skip trigrams, given how simple they are. We can take even simpler versions, by making the only parameters a QK and OV matrix, removing the attention softmax, etc.
Are phase changes specific to weirdnesses of the Adam optimizer? Do we get them for SGD or Momentum?
Note: Once a model has memorised, loss can be around 1e-7. SGD gradients scale with loss while Adam is normalised, so you’ll need to increase your learning rate a lot.
Barak et al may be another good setting to analyse—looking for phase changes in the k-bit parity problem
Loose threads in this project
Interpreting my other tasks with phase changes, and examining the circuits
Preliminary investigation shows that 5 digit addition is using a variant of the trig based algorithm
Interpreting the memorisation circuit in the modular addition transformer—understanding memorisation in transformers seems exciting to me!
This is likely easiest on random labels with no structure.
Why is there another phase change at epoch 43K?! What’s the algorithm after this?
There’s some circuit competition between eg frequency 25 and 31 from epoch 12K to 14K, or between multiplying terms with the attention patterns or the ReLUs. What’s up with that?
Why do the given frequencies form, and not others? Why 5 or 6 frequencies, and not more or less?
Train the model for a bunch of different random seeds, and see how consistent these observations are.[30]
Training dynamics. Note that a 2L MLP can also grok modular addition, and may be easier to study.
What happens when we train on interpretability-inspired metrics like excluded loss[31]? Can we accelerate the grokking point? Can we disincentivise it? Can we incentivise or disincentivise specific frequencies?
How do forms of regularisation other than weight decay affect the model and its inductive biases?
How and why does the model learn to form separate clusters of neurons? Can we predict neuron clustering in advance? Can we manually shift the initialisation to change the clustering?
Can we replicate a lottery ticket-style approach here? What subnetworks are useful for solving this problem? And if only some subnetworks are useful, why do all neurons ultimately end up being used?
How general is this algorithm? Can we find Fourier-style algorithms in LLMs? Especially when they learn to do addition
Note, I expect this to be much easier to detect in models that tokenize each digit separately—the GPT-3 tokenizer completely breaks place value: 28|79|23|598|23|45|987|249|234|0000|23|47|89|03|00000|700|9|02|10. Sadly, the only tokenizer I know of like this is PaLM’s.
The GPT-3 tokenizer mostly creates 2 digit tokens, so I’d start by trying to use a linear probe to extract cos(w(z)),sin(w(z)) in base 100 from the model for each frequency, where z is the correct output digit (or possibly the output if you incorrectly carry a 1)
Getting into the field: I think this problem is notable for being fairly simple and involving small models that it’s tractable to completely understand. If you’re interested in getting into mechanistic interpretability and are new to the field, I think this may be a great place to start! As a concrete first project, I recommend training your own transformer for modular subtraction, figuring out the analogous algorithm it should be using, and see if you can replicate my analysis to reverse engineer that.
Meta
Notable Observations
Some scattered thoughts that feel exciting to me, but didn’t naturally fit into this write-up:
Slingshot mechanism: (see colab) The Slingshot Mechanism was a fascinating paper from Thilak et al, which found that grokking seemed tied to ‘slingshots’: dramatic loss spikes that were hypothesised to fling the model to different regions of the loss landscape.
I found that these slingshots were not necessary to achieve grokking[32], and further that the slingshots are the result of a float32 precision error in log_softmax.
This can be fixed by casting your logits to float64 before taking log_softmax. If I keep my training code exactly the same but remove this casting, my loss curve instead looks like this monstrosity:
Explanation: Concretely, the smallest non-zero value that log_softmax can output in float32 is 1.19e-7. So once the model has memorised sufficiently well that loss approaches the scale of 1e-7, some data points will have loss rounding to zero. This also rounds the gradients of their loss to zero, equivalent to them not being in the training set. When a model memorises data, it wants to repurpose all other parameters and set unseen data points to extreme values, and this precision error means this includes some training data points. This introduces a dodgy gradient, and the EWMA of Adam ensures it persists for long enough for a significant loss spike
Note that this is an inherent problem of float32, not a flaw in log_softmax. For small values of x, log_softmax(x) is equal to log(1+x), and this is approximately x. 1.19e-7 is the smallest value such that 1+x!=1 in float32.
Float32 isn’t the whole story. With minibatch training and float64, we also see loss spikes, though rarer and less severe
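The float32 rounding described above can be reproduced in a few lines. This is a NumPy sketch with a hand-written `log_softmax` (my own naive implementation, not PyTorch's), on a made-up example where the correct logit leads all 112 others by 25. In float32 the sum inside the log rounds to exactly 1, so the loss on this data point is exactly zero; in float64 it is small but non-zero.

```python
import numpy as np

def log_softmax(logits):
    """Naive but max-shifted log-softmax; keeps the dtype of its input."""
    shifted = logits - logits.max()
    return shifted - np.log(np.exp(shifted).sum())

p = 113
logits64 = np.zeros(p, dtype=np.float64)
logits64[0] = 25.0                      # the "memorised" correct answer
logits32 = logits64.astype(np.float32)

loss32 = -log_softmax(logits32)[0]
loss64 = -log_softmax(logits64)[0]
eps32 = np.finfo(np.float32).eps        # gap between 1.0 and the next float32

print("float32 machine epsilon:", eps32)
print("loss in float32:", loss32)
print("loss in float64:", loss64)
```

Casting the logits to float64 before the log-softmax, as suggested above, is what keeps this data point contributing a gradient.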
Weight Decay: Naively, for a fixed training dataset size, the model should only grok for sufficiently high weight decay. But this is not actually true! Lower weight decay just delays grokking rather than stopping it.
Intuition: The model will have an incentive to switch to the generalising circuit so long as it has lower total weight than the memorising circuit. Importantly, this holds for any non-zero weight decay. Smaller weight decay just leads to a smaller incentive
Explaining circle plots of WE: A notable finding of Liu et al and the original grokking paper is that the embedding matrix of modular addition can be plotted as a circle, by the first two principal components or t-SNE respectively. Now that we understand the underlying circuits, we can see that this clearly follows from the fact that the embedding matrix has orthogonal rows for sin(wx),cos(wx) - these dimensionality reduction techniques are just finding two directions of the same frequency, and observing that a scatter plot of sin(wx),cos(wx) is a circle
2L MLPs can grok modular addition: A surprising (to me) result from Liu et al is that a 2L MLP can grok modular addition[33]. I replicated this result, and found that it learned the same algorithm (here, purely using the ReLU/GeLU activations to multiply the waves in x and y). The neurons formed a cluster for every single frequency, with 3-7 neurons per cluster (width 256), and the input and output weights were sparse in the 1D Fourier basis.
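The circle-plot observation above can be illustrated with a toy embedding. The sketch below is an assumption-laden stand-in for a trained WE: I build a (p, d) matrix whose only content is cos/sin pairs for three frequencies with different amplitudes, embedded along random orthonormal directions, and then take the first two principal components. The frequencies, amplitudes and the (token, dimension) layout are all illustrative choices.

```python
import numpy as np

p, d = 113, 32
freqs = [14, 35, 42]          # hypothetical frequency indices, w = 2*pi*k/p
amps = [3.0, 2.0, 1.0]        # hypothetical amplitudes, largest first

x = np.arange(p)
cols = []
for k, a in zip(freqs, amps):
    w = 2 * np.pi * k / p
    cols += [a * np.cos(w * x), a * np.sin(w * x)]
F = np.stack(cols, axis=1)                               # (p, 6) Fourier features

# Place each feature along its own random orthonormal direction in d dimensions.
rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(d, F.shape[1])))    # (d, 6), orthonormal columns
W_E_toy = F @ Q.T                                        # (p, d) toy embedding

# PCA via SVD of the centred matrix; project onto the top two components.
centered = W_E_toy - W_E_toy.mean(axis=0)
_, _, Vt = np.linalg.svd(centered, full_matrices=False)
coords = centered @ Vt[:2].T
radii = np.linalg.norm(coords, axis=1)

print("min/max radius in the PCA plane:", radii.min(), radii.max())
```

Every token lands at the same radius in the PCA plane, which is the circle; the top two components are just the cos/sin pair of the highest-amplitude frequency.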
Acknowledgements
This work was done as independent research after leaving Anthropic, but benefitted greatly from my work with the Anthropic interpretability team and skills gained, especially from an extremely generous amount of mentorship from Chris Olah. It relies heavily on the interpretability techniques and framework developed by the team in A Mathematical Framework for Transformer Circuits, and a lot of the analysis was inspired by the induction head bump result.
Thanks to Noa Nabeshima for pair programming with me on an initial replication of grokking, and to Vlad Mikulik for pair programming with me on the grokked induction heads experiment.
Thanks to Jacob Hilton, Alex Ray, Xander Davies, Lauro Langosco, Kevin Wang, Nicholas Turner, Rohin Shah, Vlad Mikulik, Janos Kramar, Johannes Treutlein, Arthur Conmy, Noa Nabeshima, Eric Michaud, Tao Lin, John Wentworth, Jeff Wu, David Bau, Martin Wattenberg, Nick Cammarata, Sid Black, Michela Paganini, David Lindner, Zac Kenton, Vikrant Varma, Evan Hubinger (and likely many others) for helpful clarifying discussions about this research, feedback, and helping me identify my many ill-formed explanations!
Generously supported by the FTX regrantor program
Author Contributions
Tom Lieberum significantly contributed to the early stages of this project—replicating grokking for modular subtraction, discovering it was possible in a 1L transformer with no LayerNorm, and observing the strongly periodic behaviour of activations and weights
Neel Nanda led the rest of this project—fully reverse engineering the modular addition transformer, analysing it during training, and discovering and analysing the link between grokking and phase changes. He wrote this write-up, and it is written from his perspective.
Feedback
I’d love to hear feedback on this work—parts you find compelling, parts you’re skeptical of, parts I explained poorly, cases for why Colab notebooks are a terrible format for papers, etc. You can comment here, or reach me at neelnanda27@gmail.com.
Citation Info
Please cite this work as:
@misc{https://doi.org/10.48550/arxiv.2301.05217,
doi = {10.48550/ARXIV.2301.05217},
url = {https://arxiv.org/abs/2301.05217},
author = {Nanda, Neel and Chan, Lawrence and Lieberum, Tom and Smith, Jess and Steinhardt, Jacob},
keywords = {Machine Learning (cs.LG), Artificial Intelligence (cs.AI), FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {Progress measures for grokking via mechanistic interpretability},
publisher = {arXiv},
year = {2023},
copyright = {arXiv.org perpetual, non-exclusive license}
}
Where possible, I train the models on infinite data (ie, the data is randomly generated each time, and there are enough possible data points that nothing gets repeated) - in this case I do not distinguish between train and test loss, as train loss is always on unseen data.
This is even more mysterious when you consider how each head has both an attention (QK) circuit and a value (OV) circuit. The previous token head’s QK circuit needs to compose with the positional embeddings and its OV circuit needs to compose with the embedding’s output. The induction head’s QK circuit needs to compose with the previous token head’s output on the K-side, and the embedding on the Q-side, and the OV circuit needs to compose with the embedding and unembedding to be the identity. This is a lot of synchronised and non-trivial interaction! But we know that real networks form circuits with as many moving parts as this, or more, eg curve detecting circuits.
Eric Michaud gives some evidence for this by finding that, if we project the model’s initialised embedding onto the principal components of the final embedding, we still see a circle-like thing—suggesting that the model reinforced directions in embedding space that luckily had good properties.
I have not verified that the model actually forms an induction circuit of the form described in this particular task. But the structure of attention means it must be doing some kind of composition, because a single attention head cannot attend to a previous token based on its neighbours
I later show in the case of modular addition that it is definitively not a random walk, as we can see clear progress towards the true answer before grokking
This is a more complex statement than it looks, since each solution can always become more complex to result in lower loss. Eg doubling WU will always decrease loss if there’s 100% accuracy. I operationalise this as ‘which solution is simpler if we set them both to achieve the same level of loss’
I am assuming this fact about memorisation as part of this story. I don’t understand how memorisation is implemented so I cannot be confident, but this seems plausible given that, eg a logistic/linear regression model with enough parameters can memorise random data
In some sense, single epoch training is the limiting case of this. The model tries to memorise every data point it sees, but there must eventually be a cap on how much data it can memorise, so it must pick up on patterns. We never run the model on the same data point twice, so we don’t notice this attempt at memorisation. Plausibly, models that never learn to generalise keep trying to memorise recent training data and forget about old data.
The training run was up to 50K epochs, but there’s a mysterious bonus tiny phase change at epoch 43K that I don’t yet understand, and so I snap the model before then, as I don’t think that shift is central to understand the first grok.
This is useful because, though z=x+y(modp) is an argmax of cos(w(x+y−z)), the difference between it and the second largest value is tiny. While if we take the sum of cos(w(x+y−z)) for several different values of w, the waves constructively interfere at z=x+y(modp), and destructively interfere everywhere else, significantly increasing the difference.
By having a column of WE be (f(0),f(1),...,f(p−1)) - the dot product (ie the first coordinate of the embedding) is f(x). In fact, so long as WE is invertible, we can recover any function f(x) by projecting onto the direction f(x)W−1E
The frequencies here are 14,31,35,41,42,52 - as far as I can tell, the specific frequencies and the number of frequencies are arbitrary, and vary between random seeds
The attention pattern multiplication is the obvious way to do it—I found the use of ReLUs very surprising! But it turns out that ReLU(0.5+cos(kx)+cos(ky)) is approx 0.71+0.695cos(kx)+0.695cos(ky)+0.43cos(kx)cos(ky) (more details)
There is actually a sixth cluster not aligned with any particular frequency—neurons in this cluster always fire, and thus ReLU is the identity. I’m not entirely sure why this is useful to the network
In particular, understanding ReLU in the Fourier basis takes it from a nice basis aligned operation to the operation “rotate out of the Fourier basis into the standard basis, take the elementwise ReLU, and then rotate back into the Fourier basis”
We can validate that the remaining 10% is just noise by ablating—replacing each neuron by this approximation (ie removing all terms of all other frequencies) leaves loss unchanged.
This isn’t quite true—at epoch 43K, there is a tiny phase change which can be seen if you zoom in on the trig loss graph (trig loss and train loss fall, test loss rises). I haven’t investigated this properly, but the model goes from never attending to the = token to significantly attending to it. Presumably there’s some further circuit development here?
Note that excluded loss specifically is weird, because we fit the coefficients of cos(w(x+y)) to the logits across all of the data, so it’s technically also a function of the test data. Probably linear regression on just the train logits would suffice though.
A Mechanistic Interpretability Analysis of Grokking
A significantly updated version of this work is now on Arxiv and was published as a spotlight paper at ICLR 2023
aka, how the best way to do modular addition is with Discrete Fourier Transforms and trig identities
If you don’t want to commit to a long post, check out the Tweet thread summary
Introduction
Grokking is a recent phenomenon discovered by OpenAI researchers that, in my opinion, is one of the most fascinating mysteries in deep learning: models trained on small algorithmic tasks like modular addition will initially memorise the training data, but after a long time will suddenly learn to generalise to unseen data.
This is a write-up of an independent research project I did into understanding grokking through the lens of mechanistic interpretability. My most important claim is that grokking has a deep relationship to phase changes. Phase changes, ie a sudden change in the model’s performance for some capability during training, are a general phenomenon that occurs when training models, and have also been observed in large models trained on non-toy tasks. For example, the sudden change in a transformer’s capacity to do in-context learning when it forms induction heads. In this work I examine several toy settings where a model trained to solve them exhibits a phase change in test loss, regardless of how much data it is trained on. I show that if a model is trained on these tasks with limited data and high regularisation, then the model shows grokking.
One of the core claims of mechanistic interpretability is that neural networks can be understood, that rather than being mysterious black boxes they learn interpretable algorithms which can be reverse engineered and comprehended. This work serves as a proof of concept of that, and that reverse engineering models is key to understanding them. I fully reverse engineer the inferred algorithm from a transformer that has grokked how to do modular addition (which somehow involves Discrete Fourier Transforms and trig identities?!), and use this as a concrete example to analyse what happens during training to understand what happened during grokking. I close with discussion and thoughts on the alignment relevance of these results.
This is accompanied by a paper in the form of a Colab notebook containing the code for this project, a lot of interactive graphics, and much more in-depth discussion and technical details. In this write-up I try to give a high-level conceptual overview of the claims and the most compelling results and evidence; I refer you to the notebook if you want the full technical details.
This write-up ends with a list of ideas for future directions of this research. I think this is a particularly exciting problem to start with if you want to get into mechanistic interpretability since it’s concrete, only involves tiny models, and is easy to do in a Colab notebook. If you might want to work on some of these, please reach out! In particular, I’m looking to hire an intern/research assistant, and if you’re excited about these future directions you might be a good fit.
Key Claims
Grokking is really about phase changes: To exhibit grokking, we train a model on a problem that exhibits phase changes even when given infinite training data, and train it with regularisation and limited data. If we choose the amount of data such that the regularisation only marginally favours the generalising solution over the memorised solution, we see grokking.
Intuition: Regularisation makes the model ultimately prefer the generalising solution to the memorised solution, but the phase change indicates that the generalising solution is “hard to reach” in some sense. The memorised solution is “easier to reach”, and so is reached first. But due to the regularisation, the model still prefers the generalising solution to the memorised solution, and thus gets to the generalising solution eventually—the grokking result just shows that reaching the memorised solution first does not change this, and that there is a path in model space interpolating between the memorising and generalising solution.
The algorithm learned to do modular addition can be fully reverse engineered. The algorithm is roughly:
Map inputs x,y→cos(wx),cos(wy),sin(wx),sin(wy) with a Discrete Fourier Transform, for some frequency w
Multiply and rearrange to get cos(w(x+y))=cos(wx)cos(wy)−sin(wx)sin(wy) and sin(w(x+y))=cos(wx)sin(wy)+sin(wx)cos(wy)
By choosing a frequency w=2πk/n we get period dividing n, so this is a function of x+y (mod n)
Map to the output logits z with cos(w(x+y))cos(wz)+sin(w(x+y))sin(wz)=cos(w(x+y−z)) - this has the highest logit at z≡x+y (mod n), so softmax gives the right answer.
To emphasise, this algorithm was purely learned by gradient descent! I did not predict or understand this algorithm in advance and did nothing to encourage the model to learn this way of doing modular addition. I only discovered it by reverse engineering the weights.
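The four steps above can be sketched end to end in plain Python. This is only an illustration of the arithmetic, not the trained network: the function name, the choice of a single frequency k=1, and the use of an explicit argmax in place of softmax are all assumptions made for the sketch.

```python
import math

def modular_add_via_trig(x, y, n=113, k=1):
    """Compute (x + y) mod n using only waves of one frequency w = 2*pi*k/n."""
    w = 2 * math.pi * k / n
    # Step 1: represent each input as a cosine and a sine wave.
    cos_x, sin_x = math.cos(w * x), math.sin(w * x)
    cos_y, sin_y = math.cos(w * y), math.sin(w * y)
    # Step 2: the trig addition identities give waves in x + y.
    cos_sum = cos_x * cos_y - sin_x * sin_y
    sin_sum = cos_x * sin_y + sin_x * cos_y
    # Step 3: one "logit" per candidate answer z, equal to cos(w * (x + y - z)).
    logits = [cos_sum * math.cos(w * z) + sin_sum * math.sin(w * z) for z in range(n)]
    # Step 4: the largest logit sits at z = (x + y) mod n (argmax stands in for softmax).
    return max(range(n), key=lambda z: logits[z])
```

For example, `modular_add_via_trig(100, 50)` returns 37, which is 150 mod 113.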
The generalising circuits used to do modular addition can be seen to smoothly evolve over the course of training, as the model interpolates between purely memorising and purely generalising
In particular, the circuits smoothly develop well before grokking, disproving the ‘grokking is a random walk in model-space’ hypothesis.
This is nicely demonstrated by the metric of excluded loss—which roughly shows how much model performance on the training data depends on the generalising algorithm vs memorising algorithm. We see the use of the generalising algorithm to improve training performance rises smoothly over training, well before the grokking point.
Phase Changes
Epistemic status: I feel confident in the empirical results, but the generalisation to non-toy settings is more speculative
Key Takeaways
Phase changes are surprisingly common
Empirically, given a problem with a phase change, training it with weight decay and restricted data gives us grokking
Speculation: Phase changes occur in any circuit that involves multiple components composing.
For example, the composition of the previous token head and the induction head in an induction circuit. The previous token head will only reduce loss if the induction head is there and vice versa. So, initially, the gradients creating each component will be weak to non-existent. But once each component starts to form, the gradients on the other component will become stronger. These effects reinforce each other, creating a feedback loop that eventually accelerates and results in a phase change.
Intuitive explanation of grokking: Regularisation incentivises the model to be simpler, so the model prefers the generalising solution to the memorised solution. The generalising solution is “hard to reach” and the memorising solution is not, so the memorising solution is reached first. But the incentive to find the generalising solution is still there, so the underlying mechanism to induce the phase change is still going on while grokking—memorisation doesn’t change this
Fundamentally, understanding grokking is about understanding phase changes—I don’t claim to fully understand phase changes or grokking, but I claim to have reduced my confusion about grokking to my confusion about phase changes.
Speculation: Phase changes occur for any specific circuit in a model, and so occur all the time in larger models. Smooth loss curves are actually the average of many small phase changes
I observe several small phase changes in my toy tasks, eg there’s a separate phase change for each digit when learning 5 digit addition.
What Is A Phase Change?
By a phase change, I mean a reverse-S shaped curve[1] in the model’s loss on some dataset, or the model’s capacity for some specific capability. That is, the model initially has poor performance, this performance plateaus or slowly improves, and suddenly the performance accelerates and rapidly improves (ie, loss rapidly decreases), and eventually levels off.
A particularly well-studied motivating example of this is Anthropic’s study of induction heads. Induction heads are a circuit in transformers used to predict repeated sequences of tokens. They search the context for previous copies of the current token, and then attend to the token immediately after that copy, and predict that that subsequent token will come next. Eg, if the 3 token surname D|urs|ley is earlier in the context, and the model wants to predict what comes after D, it will attend to urs and predict that that comes next.
The key fact about induction heads is that there is a fairly narrow band of training time where they undergo a phase change, and go from not existing/functional to being fully functional. This is particularly striking because induction heads are the main mechanism behind how Large Language Models do in-context learning, ie using tokens from far back in the context to usefully predict the next token. This means that there is a clear phase change in a model’s ability to do in-context learning, as shown below, and that this corresponds to the formation of a specific circuit via a phase change within the model.
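To make the behaviour concrete, here is a minimal sketch of what an idealised induction head computes, written as a hard lookup over a token list. Real induction heads implement this softly via attention, so the function below (whose name is hypothetical) is an idealisation, not a description of the circuit's weights.

```python
def induction_predict(context):
    """Predict the next token the way an idealised induction head would.

    Look backwards for the most recent earlier copy of the current (last) token
    and return the token that followed that copy; return None if there is none.
    """
    current = context[-1]
    # Scan earlier positions from right to left, skipping the current token itself.
    for pos in range(len(context) - 2, -1, -1):
        if context[pos] == current:
            return context[pos + 1]
    return None
```

With the context `["Mr", "D", "urs", "ley", "said", "D"]` this returns `"urs"`, matching the surname example above.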
My goal here is to convey the intuition of what I mean by a phase change, rather than give a clear and objective definition. At this stage of research, I favour an informal “I know it when I see it” style definition, over something more formal and brittle.
Empirical Observations
Motivation: Modular addition shows clear grokking on limited data[2], but given much more data it shows a phase change in both train and test loss. This motivated the hypothesis that grokking could be observed any time we took a problem with a phase change and reduced the amount of data while regularising:
In several toy algorithmic tasks I observe phase changes (in both test and train loss) when models are trained on sufficient data[3].
The tasks—see the Colab for more details
Modular addition mod 113 (1L full transformer)
Trained on 95% of the training data, 1L Transformer
5 digit addition (1L full transformer)
1L Transformer, data format: 1|3|4|5|2|+|5|8|3|2|1|=|0|7|1|7|7|3
Predicting Repeated Subsequences (2L attn-only transformer, task designed to need induction heads)
Data format: 7 2 8 3 1 9 3 8 3 1 9 9 2 5 END - we take a uniform random sequence of tokens, randomly choose a subsequence to repeat, and train the model to predict the repeated tokens.
Finding the max element in a sequence (1L attn-only transformer, task designed to need skip trigrams)
Concretely, the data format is START 0 4 7 2 4 14 9 7 2 5 3 END with exactly one entry ≥10, and the model is trained to output that entry after END. The solution is to learn 10 skip trigrams of the form 14 .. END 14
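As a rough sketch of two of these data formats, the helpers below produce examples shaped like the ones shown. The function names, the sequence length of 11, and the assumption that the large tokens are 10..19 are inferred from the examples rather than taken from the notebook.

```python
import random

def format_addition(a, b, digits=5):
    """Render a + b in the token format shown above, with the sum zero-padded."""
    lhs = list(str(a).zfill(digits)) + ["+"] + list(str(b).zfill(digits)) + ["="]
    # The sum of two `digits`-digit numbers needs at most digits + 1 digits.
    rhs = list(str(a + b).zfill(digits + 1))
    return "|".join(lhs + rhs)

def make_max_element_example(rng, length=11, n_small=10, n_large=10):
    """Sample one sequence with exactly one entry >= n_small; the target is that entry."""
    tokens = [rng.randrange(n_small) for _ in range(length)]
    target = rng.randrange(n_small, n_small + n_large)
    # Overwrite one random position with the single large token.
    tokens[rng.randrange(length)] = target
    return ["START"] + [str(t) for t in tokens] + ["END"], str(target)

print(format_addition(13452, 58321))
print(make_max_element_example(random.Random(0)))
```

The first call reproduces the 5 digit addition example above, since 13452 + 58321 = 71773 is padded to six digits.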
Grokking = Phase Changes + Regularisation + Limited Data
For all of the above tasks, we can induce grokking by adding regularisation (specifically weight decay) and limiting the amount of training data. Eg for 5 digit addition on 700 examples (see the notebook for the rest):
With enough data (eg single epoch training), the model generalises easily, and with sufficiently little data (eg a single data point), the model memorises. But there is a crossover point, and we identify grokking when training with slightly more data than the crossover point, analogous to the findings of Liu et al that grokking is an intermediate state between comprehension and memorisation. I found the crossover points (here, 700) by binary searching by hand. Intuitively, this is consistent with the idea that grokking occurs because of regularisation favouring the simpler solution. Memorisation complexity increases with the amount of data (approximately continuously) while generalising complexity does not, so there must eventually be a crossover point.
This is particularly interesting, because these tasks are a non-trivial extension of the original grokking. Previous demonstrations of grokking had an extremely small universe of data and were trained on a substantial fraction of it, suggesting that the model may be doing some kind of clever interpolation. While here, the universe of data is much larger (eg 10^10 possible pairs of 5 digit numbers), and the model is trained on a tiny fraction of that, yet it still groks.
Explaining Grokking
Epistemic status: The following two sections are fairly speculative—they’re my current best explanation of my empirical findings, but are likely highly incomplete and could easily be totally wrong.
Speculation: Phase Changes Are Inherent to Composition
I recommend reading the section of A Mathematical Framework for Transformer Circuits on Induction Heads to fully follow this section
To understand the link between phase changes and grokking, it’s worth reflecting on why circuits form at all. A priori, this is pretty surprising! To see this, let’s focus on the example of an induction circuit, and speculate on how it could be formed. The induction circuit is made up of two heads, a previous token head and an induction head, which interact by K-composition. Together, these heads significantly improve loss, but only in the context of the other head being there. Thus, naively, when we have neither head, there should be no gradient encouraging the formation of either head.
At initialisation, we have neither head, and so gradient descent should never discover this circuit. Naively, we might predict that neural networks will only produce circuits analogous to linear regression, where each weight will clearly marginally improve performance as it continuously improves. And yet in practice, neural networks empirically do form sophisticated circuits like this, involving several parts interacting in non-trivial, algorithmic ways.[4]
I see a few different possible explanations for this:
A lottery ticket hypothesis-inspired explanation:[5] Initially, each layer of the network is the superposition of many different partial circuit components, and the output of each layer is the average of the output of each component. The full output of the network is the average of many different circuits. Some of these circuits are systematically useful to reducing loss, and most circuits aren’t. SGD will reinforce the relevant circuits and suppress the useless circuits, so the circuits will gradually form.
A random walk explanation: The network wanders randomly around the loss landscape, until it happens to get lucky and find a half-formed previous token head and induction head that somewhat compose. Once it has this, this half-formed circuit is useful for reducing loss and gradient descent can take over and make a complete circuit.
An evolutionary explanation: There’s a similar mystery for how organisms develop sophisticated machinery like the human eye, where each part is only useful in the context of other parts. The explanation I find most compelling is that we first developed one component that was somewhat useful on its own, eg a light-detecting membrane. This component was useful in its own right, and so was reinforced, and later components could develop that depend on it, eg the lenses in our eye.
The evolutionary explanation is a natural hypothesis, but we can see from my toy tasks that it cannot be the whole story. In the repeated subsequence task, we have a sequence of uniform randomly generated tokens, apart from a repeated subsequence at an arbitrary location, eg 7 2 8 3 1 9 3 8 3 1 9 9 2 5 END. This means that all pairs of tokens are independent, apart from the pairs of equal tokens in the repeated subsequence. In particular, this means that a previous token head can never reduce loss for the current token, since the previous token will always be independent of the next token. So a previous token head is only useful in the context of an induction-like head that completes the circuit. Likewise, an induction head relies on K-composition with a previous token head, and so cannot be useful on its own. Yet the model eventually forms an induction circuit![6]
A priori, the random walk story seems unlikely to be sufficient on its own: an induction circuit is pretty complicated, and it likely represents a very small region in model space, and so seems unlikely to be stumbled upon by a random walk[7]. Thus my prediction is that the lottery ticket hypothesis is most of what’s going on[8] - an induction head will be useless without a previous token head, but may be slightly useful when composing with, say, a head that uniformly attends to prior tokens, since part of its output will include the previous token! I expect that all explanations are part of the picture though, eg this seems more plausible if the uniform head just so happens to attend a bit more to the previous token via a random walk, etc.
Drawing this back to phase changes, the lottery ticket-style explanation suggests that we might expect to see phase changes as circuits form. Early on in circuit formation, each part of the circuit is very rough, so the effect on the loss of improving any individual component is weak, which means the gradients will be small. But as each component develops, each other component will become more useful, which means that all gradients will increase together in a non-linear way. So as the circuit becomes closer to completion we should expect an acceleration in the loss curve for this circuit, resulting in a phase change.
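The feedback loop described here can be seen in a two-parameter toy. This is a deliberately crude sketch, not one of the tasks above: the "circuit" is just a product a*b that must reach 1, and the learning rate, initial values and step count are arbitrary choices. A single weight that helps on its own is included for contrast with the linear-regression-like picture.

```python
def train_composed(steps=100, lr=0.05, init=0.01):
    """Gradient descent on L = (1 - a*b)^2, where a and b only help together."""
    a = b = init
    losses = []
    for _ in range(steps):
        err = 1 - a * b
        losses.append(err ** 2)
        # dL/da = -2*err*b and dL/db = -2*err*a: each gradient scales with the OTHER part.
        a, b = a + lr * 2 * err * b, b + lr * 2 * err * a
    return losses

def train_single(steps=100, lr=0.05, init=0.0001):
    """Gradient descent on L = (1 - c)^2, a single weight that helps on its own."""
    c = init
    losses = []
    for _ in range(steps):
        err = 1 - c
        losses.append(err ** 2)
        c += lr * 2 * err  # dL/dc = -2*err does not depend on any other component
    return losses

composed, single = train_composed(), train_single()
for step in (0, 20, 40, 60, 99):
    print(step, round(composed[step], 4), round(single[step], 4))
```

In this toy the composed loss stays close to its initial value for the first few dozen steps and then falls quickly, while the single-weight loss starts falling from the first step. That is the qualitative shape the argument predicts, though it says nothing about whether real circuits form this way.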
An Intuitive Explanation of Grokking
With this explanation, we can now try to answer the question of why grokking occurs! To recap the problem setting, we are training our model on a problem with two possible solutions—the memorising algorithm and the generalising algorithm. We apply regularisation and choose a limited amount of data, such that the generalising solution is marginally simpler than the memorising solution[9] and so our training setup marginally prefers the generalising solution over the memorising solution. Naively, we expect the model to learn the generalising solution.
But we are training our model on a problem whose solution involves multiple components interacting to form a complete circuit. So, early in training, the gradients incentivising each component of the generalising solution are weak, because they need the parts to all be formed and lined up properly. Memorisation, however, does not require several components to be lined up in a careful and coordinated way[10], so it does not have artificially weak gradients at the start. Thus, at the start, memorisation is incentivised more strongly than generalisation and the model memorises.
So, why does the model shift from memorisation to generalisation? Eventually the training loss plateaus post-memorisation: loss is falling and total weights are rising, so eventually the gradients towards lower loss (ie to memorise better) balance with the gradients towards lower weights (ie to be simpler) and they cancel out. But they don’t perfectly cancel out. If there’s a direction in model space that allows it to memorise more efficiently[11], then both gradients will encourage this direction. And the natural way to do this is by picking up on regularities in the data, eg you can memorise modular addition twice as efficiently by recognising that x+y=y+x. This is the same process that leads the model to generalise in the infinite data case: it wants to pick up on patterns in the data.[12]
So the model is still incentivised to reach the generalising solution, just as in the infinite data case. But rather than moving from the randomly initialised model to the generalising model (as in the infinite data case), it interpolates between the memorising solution and the generalising solution. Throughout this process test loss remains high—even a partial memorising solution still performs extremely badly on unseen data! But once the generalising solution gets good enough, the incentive to simplify by deleting the remnants of the memorising solution dominates, and the model clears up the memorising solution, finally resulting in good test performance. Further, as in the infinite data case, the closer we get to the generalising solution the more the rate of change of the loss accelerates. So this final shift happens extremely abruptly—manifesting as the grokking phase change!
As a final bit of evidence, once the model has fully transitioned to the generalising solution it is now inherently simpler, and the point where the incentive to improve loss balances with the incentive to be simpler is marginally lower—we can observe in the graphs here that the model experiences a notable drop in train loss post grokking.
I later walk through this narrative and what it corresponds to in the modular addition case.
Speculation: Phase Changes are Everywhere
My explanation above asserted that phase changes are a natural thing to expect with the formation of specific circuits in models. If we buy the hypothesis that most things models do are built up out of many interpretable circuits, then shouldn’t we expect to see phase changes everywhere whenever we train models, rather than smooth and convex curves?
My prediction is that yes, we should, and that in fact we do. But that larger models are made up of many circuits and, though each circuit may form in a phase change, the overall loss is made up out of the combination of many different capabilities (and thus many different circuits). And circuits of different complexity/importance likely form at different points in training. So the overall loss curve is actually the sum of many tiny phase changes at different points in training, and this overall looks smooth and convex. Regularities in loss curves like scaling laws may be downstream of statistical properties of the distribution of circuits, which become apparent at large scales. We directly observe the phase change-y ness from the loss curves in these simple toy problems because the problems are easy enough that only one/a few circuits are needed.
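A quick numerical illustration of the averaging argument, under strong simplifying assumptions: each circuit's loss is modelled as a logistic reverse-S curve, and 200 hypothetical circuits have phase changes spread evenly through training. The curve shape, the sharpness and the spacing are all made up for the sketch, so it illustrates the claim without being evidence for it.

```python
import math

def circuit_loss(t, midpoint, sharpness=0.2):
    """Reverse-S loss curve for one circuit: near 1 before `midpoint`, near 0 after."""
    return 1 / (1 + math.exp(sharpness * (t - midpoint)))

def largest_step_drop(curve):
    """Largest single-step decrease, as a fraction of the total decrease."""
    total = curve[0] - curve[-1]
    return max(curve[i] - curve[i + 1] for i in range(len(curve) - 1)) / total

steps = range(1000)
single = [circuit_loss(t, midpoint=500) for t in steps]
# 200 hypothetical circuits whose phase changes are spread evenly through training.
midpoints = [100 + 4 * i for i in range(200)]
average = [sum(circuit_loss(t, m) for m in midpoints) / len(midpoints) for t in steps]

print("single circuit:", largest_step_drop(single))
print("average of 200:", largest_step_drop(average))
```

The single curve loses a much larger share of its loss in its steepest step than the average does, which is the sense in which many sharp transitions can add up to something that looks smooth. A different distribution of midpoints would change the overall shape of the averaged curve.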
Some evidence for this hypothesis:
5 digit addition (toy problem) - We can decompose the loss into the sum of 6[13] components—the loss on each of the 6 digits in the sum. When we do this, we observe separate phase changes in each digit, resulting in the many small non-convexities in the overall loss curve.
The ordering of phase changes is not stable between runs, though token 0 and token 1 tend to be first[14]
This isn’t specific to 5 digit, eg 15 digit addition shows 16 separate phase changes
Skip Trigrams (toy problem) - The model learns 10 different skip trigrams, 10 .. END 10, 11 .. END 11, etc. Each skip trigram shows a separate phase change
Notably, each phase change happens at approximately the same time, so the overall curve looks less bumpy than 5 digit addition. This makes sense, because each skip trigram is “as complex” as the others, while learning to add some digits is much harder than others.
Induction Heads - Induction heads are the best studied example of a specific circuit through training, and there we see a clear phase change in LLMs up to a 13B transformer. Each head should be “as complex” as the others, so it makes sense that they all occur at approximately, but not exactly, the same time.[15]
Emergent phenomena—Emergent phenomena are a widely seen thing as language models scale up—they abruptly go from incapable to capable on a task such as addition. As argued by Jacob Steinhardt, More is Different in AI. It seems natural to extend this to a particular model during training, as we know that smaller models trained on more data can outcompete larger models on less data
AlphaZero Capabilities: One finding in DeepMind’s AlphaZero Interpretability paper was that there is a phase change in the model’s capabilities, where it learns to represent a lot of chess concepts around step 32,000.
Summary
I think this is highly suggestive evidence that there is a deep relationship between grokking and phase changes, and that grokking occurs when models with a phase change are trained with regularisation and limited data. I present some compelling (to me) explanations of what might be behind the phase change behaviour, and show how this model explains grokking and predicts several specific empirical observations about grokking. I don’t claim to fully understand phase changes or grokking, but I do claim to have substantially reduced my confusion about grokking to my confusion about phase changes.
Modular Addition
Epistemic status: I feel pretty confident that I have fully reverse engineered this network, and have enough different lines of evidence that I am confident in how it works. My explanation of how and why it develops during training is shakier.
Key Takeaways
A 1L Transformer learns a clear and interpretable algorithm to do modular addition.
This algorithm operates via using trig identities and Discrete Fourier Transforms to map x,y→cos(w(x+y)),sin(w(x+y)), and then extracting x+y (mod p)
This algorithm can be clearly read off from the weights. If we apply a Discrete Fourier Transform to the input space and apply the Transformer Circuits framework, the structure of the network and its resulting algorithm becomes clear.
The model naturally forms several sub-networks that calculate cos(w(x+y−z)) for different frequencies w and add these to form the logits. This can be seen by a clear clustering of the neurons
Within a cluster, individual neurons clearly represent interpretable features for a single frequency.
To emphasise, this algorithm was discovered purely via gradient descent, not by me. I didn’t think of this algorithm until I reverse engineered it from the weights!
The evolution of this algorithm can be clearly seen during training, and systematic progress towards the generalising circuit can be seen well before the grokking point
We can see each phase of my narrative of grokking manifest during training
With excluded loss, we can see the model interpolate between memorisation and generalisation. Train loss performance depends substantially on the generalising circuit well before we see a significant change in test loss.
Model Details
In this section I dive deeply into one specific and well-checkpointed model trained to do modular addition. See model training code for more details, but here are the key points:
The model is trained to map x,y to z≡x+y (mod 113) (henceforth 113 is referred to as p)
1L Transformer
Learned positional embeddings
Width 128
No LayerNorm
ReLU activations
Input format is x|y|=, where x,y are one-hot encoded inputs, and = is an extra token.
Trained with AdamW, with weight decay 1 and learning rate 10^−3
Full batch training, trained on 30% of the data (ie the 113^2 pairs of inputs) for 40,000 epochs[16]
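For concreteness, here is one way the dataset and the 30% split could be built. It is a sketch under assumptions: the token id used for =, the shuffling seed and the helper name are all hypothetical, and the model training code linked above is the authoritative version.

```python
import random

P = 113
EQUALS_TOKEN = P  # assumption: "=" gets the first id after the numbers 0..P-1

def make_modular_addition_split(train_frac=0.3, seed=0):
    """Build every (x, y) pair mod P and split it into train and test sets."""
    pairs = [(x, y) for x in range(P) for y in range(P)]
    random.Random(seed).shuffle(pairs)
    n_train = int(train_frac * len(pairs))
    # Each example is (input tokens for "x|y|=", label z = x + y mod P).
    examples = [([x, y, EQUALS_TOKEN], (x + y) % P) for x, y in pairs]
    return examples[:n_train], examples[n_train:]

train, test = make_modular_addition_split()
print(len(train), len(test))
```

With these settings the split is 3830 training pairs and 8939 held-out pairs out of 113^2 = 12769.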
This is a 1L transformer, with no LayerNorm and learned positional embeddings, trained with AdamW with weight decay 1, and full batch training on 30% of the data (the data is the 113^2 pairs of numbers mod p). The
Overview of the Inferred Algorithm
The key feature of the algorithm is calculating cos(w(x+y)),sin(w(x+y)) with w=2πk/p - this is a function of x+y and can be mapped to x+y, and because cos(wx) has period p/k we get the (mod p) part for free.
More concretely:
Inputs x,y are given as one-hot encoded vectors in R^p
Calculates cos(wx),cos(wy),sin(wx),sin(wy) via a Discrete Fourier Transform (This sounds complex but is just a change of basis on the inputs, and so is just a linear map)
w=2πk/p, k is arbitrary, we just need period dividing p
Calculates cos(wx)cos(wy),cos(wx)sin(wy),sin(wx)cos(wy),sin(wx)sin(wy) by multiplying pairs of waves in x and in y
Calculates cos(w(x+y))=cos(wx)cos(wy)−sin(wx)sin(wy) and sin(w(x+y))=sin(wx)cos(wy)+cos(wx)sin(wy) by rearranging and taking differences
Calculates cos(w(x+y−z))=cos(w(x+y))cos(wz)+sin(w(x+y))sin(wz) via a linear map to the output logits z
This has an argmax at z≡x+y (mod p), so post softmax we’re done!
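The last step leans on a fact that is worth spelling out for a single frequency. A short derivation, assuming p is prime (113 is) and k is not a multiple of p, with w=2πk/p as above:

```latex
\begin{aligned}
\cos\big(w(x+y-z)\big) = 1
&\iff \tfrac{2\pi k}{p}\,(x+y-z) \in 2\pi\mathbb{Z} \\
&\iff p \mid k\,(x+y-z) \\
&\iff p \mid (x+y-z) \qquad \text{(since $p$ is prime and $p \nmid k$)} \\
&\iff z \equiv x+y \pmod p .
\end{aligned}
```

Every other z gives a cosine strictly below 1, so the logit at z≡x+y (mod p) is the unique maximum.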
There are a few adjustments to implement this algorithm in a neural network:
The model’s activations at any point are vectors(/tensors). To represent several variables, such as cos(wx),sin(wx), these are stored as different directions in activation space. When the vector is projected onto those dimensions, the coefficient is the relevant variable (eg cos(wx))
The model runs the algorithm in parallel for several different frequencies[17] (different frequencies correspond to different clusters of neurons, different subspaces of the residual stream, and sometimes different attention heads)
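One arithmetic consequence of adding several frequencies can be checked directly: the gap between the correct logit and the best wrong logit grows. The sketch below compares a single frequency with an illustrative set of five; the particular frequencies are arbitrary choices for the demo, and this is not a claim about why the trained model uses several.

```python
import math

P = 113

def logit(d, freqs):
    """Mean over frequencies k of cos(2*pi*k*d/P), where d = x + y - z."""
    return sum(math.cos(2 * math.pi * k * d / P) for k in freqs) / len(freqs)

def margin(freqs):
    """Gap between the correct logit (d = 0) and the best wrong one (d != 0 mod P)."""
    return logit(0, freqs) - max(logit(d, freqs) for d in range(1, P))

single = margin([14])
several = margin([14, 35, 41, 42, 52])  # illustrative frequencies, chosen arbitrarily
print("one frequency:", single)
print("five frequencies:", several)
```

With one frequency the best wrong answer is off by only 1 − cos(2π/113), because some wrong z always lands one step away from the peak of the wave. With several frequencies no single wrong z is near the peak of all the waves at once, so the averaged margin is larger.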
Background on Discrete Fourier Transforms
A key technique in all that follows is Discrete Fourier Transforms (DFT). I give a more in-depth explainer in the colab, but here’s a rough outline—I expect this requires familiarity with linear algebra to really get your head around. The key motivating observation is that most activations inside the network are periodic and so techniques designed to represent periodic functions nicely are key. Eg the attention patterns:
In R^p, we have a standard basis of the p unit vectors. But we can also take a basis of p cosine and sine waves, F∈R^(p×p), where F_0=(1,1,...,1) is the constant vector, and F_(2k−1)=cos(2πkx/p) and F_(2k)=sin(2πkx/p) are the cosine and sine wave of frequency w=2πk/p (henceforth referred to as frequency w=k and represented as cos kx for brevity) for k=1,...,(p−1)/2. Every pair of waves has dot product zero, unless they’re the same wave (ie it’s an orthogonal basis). If we normalise these rows, we get an orthonormal basis of cosine and sine waves (so F^−1=F^T). We refer to these normalised waves as Fourier Components and this overall basis as the 1D Fourier Basis.
We can apply a change of basis to the 1D input space R^p to F, and this turns out to be a much more natural way to represent the input space for this problem, as the network learns to operate in terms of sine and cosine waves. Eg, the fourth column of W_E F^T is the direction corresponding to sin 2x in the embedding for W_E. If we apply this change of basis to both the input space for x and for y we apply a 2D DFT, and can represent any function as the linear combination of terms of the form sin(w_1 x)cos(w_2 y) (or cos(w_1 x)cos(w_2 y), Const∗cos(w_2 y), etc). This is just a change of basis on R^(p×p)=R^(p^2), and terms of the form sin(w_1 x)cos(w_2 y) (ie the outer product of each pair of rows in F) form an orthogonal basis of p^2 vectors (henceforth referred to as the 2D Fourier Basis).
Importantly, this is just a change of basis! If we choose any single activation in the network, this is a real valued function on pairs of inputs x,y∈R^(p×p), and so is equivalent to specifying a p^2 dimensional vector. And we can apply an arbitrary change of basis to this vector. So we can always write it as a linear combination of terms in the 2D Fourier Basis. And any vector of activations is a linear combination of 2D Fourier terms times fixed vectors in activation space. If a function is periodic, this means that it is sparse in the 1D or 2D Fourier Basis, and this is what tells us about the structure of the algorithm and weights of the network.
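The construction above can be checked numerically in a few lines of plain Python. This sketch builds the normalised basis F for a small prime, confirms it is orthonormal, and then shows that a periodic function of x+y has only two non-zero coefficients in the 2D Fourier Basis. The helper names, the small p=11 and the frequency k=3 are choices made to keep the demo fast, not values from the post.

```python
import math

def fourier_basis(p):
    """Rows: constant, cos(1x), sin(1x), cos(2x), sin(2x), ..., each scaled to length 1."""
    rows = [[1.0] * p]
    for k in range(1, (p - 1) // 2 + 1):
        w = 2 * math.pi * k / p
        rows.append([math.cos(w * x) for x in range(p)])
        rows.append([math.sin(w * x) for x in range(p)])
    return [[v / math.sqrt(sum(u * u for u in row)) for v in row] for row in rows]

def matmul(a, b):
    """Plain-Python matrix product of two lists of rows."""
    cols = list(zip(*b))
    return [[sum(u * v for u, v in zip(row, col)) for col in cols] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

p, k = 11, 3  # small prime and an arbitrary frequency, to keep the demo fast
F = fourier_basis(p)

# Orthonormality: F @ F^T should be the identity.
gram = matmul(F, transpose(F))

# A periodic "activation" on all input pairs: f(x, y) = cos(w(x + y)).
w = 2 * math.pi * k / p
f = [[math.cos(w * (x + y)) for y in range(p)] for x in range(p)]
# 2D change of basis: coeffs[i][j] multiplies (row i of F in x) * (row j of F in y).
coeffs = matmul(matmul(F, f), transpose(F))
nonzero = [(i, j) for i in range(p) for j in range(p) if abs(coeffs[i][j]) > 1e-9]
print(nonzero)
```

Only the cos·cos and sin·sin terms at frequency k survive, with opposite signs, which is the identity cos(w(x+y))=cos(wx)cos(wy)−sin(wx)sin(wy) read off in the 2D Fourier Basis.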
Reverse Engineering the Algorithm
Here, I present a case for how I was able to reverse engineer the algorithm from the weights. See the Colab and appendices (attention and neuron) for full details, my goal in this section is to roughly sketch what’s going on and why I’m confident that this is what’s going on.
Calculating waves cos(wx),sin(wx),cos(wy),sin(wy)
Relevant notebook section
Theory: Naively, this seems like the hard part, but is actually extremely easy. The key is that we just need to learn the discretised wave on x∈[0,1,...,p−1], not for arbitrary x∈R. x is input into the network as a one-hot encoded vector, and then multiplied by a learned matrix W_E. We can in fact learn any function f:[0,1,...,p−1]→R[18]
Conveniently, F, the matrix of waves cos(wx),sin(wx), is an orthonormal basis. So W_E F^T will recover the direction of the embedding corresponding to each wave Const,cosx,sinx,cos2x,... - in other words, extracting cos(wx),sin(wx) is just a rotation of the input space.
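The reason any function on [0,1,...,p−1] can be learned is that a one-hot vector times a matrix is a table lookup. A minimal sketch, where `wave_table` plays the role of a two-column slice of the embedding (stored with one row per input, which is a layout chosen for readability and may be transposed relative to W_E above), and the frequency is an arbitrary example:

```python
import math

p, k = 113, 42  # k is an arbitrary example frequency
w = 2 * math.pi * k / p

# One row per input value x, storing (cos(wx), sin(wx)).
wave_table = [[math.cos(w * x), math.sin(w * x)] for x in range(p)]

def embed(x):
    """Multiply the one-hot vector for x by the table, written out as an explicit sum."""
    one_hot = [1.0 if i == x else 0.0 for i in range(p)]
    return [sum(one_hot[i] * wave_table[i][j] for i in range(p)) for j in range(2)]

print(embed(5), wave_table[5])
```

The product just selects row x, so the matrix can hold the exact discretised wave values without the network ever evaluating a cosine.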
Evidence: We can use the norm of the embedding of each wave to get an indicator of how much the network “cares” about each wave[19], and when we do this we see that the plot is extremely sparse. The model has decided to throw away all but a few frequencies[20]. This is very strong evidence that the model is working in the Fourier Basis—we would expect to see a basically uniform plot if this was not a privileged basis.
Calculating 2D products of waves cos(wx)cos(wy),cos(wx)sin(wy),sin(wx)cos(wy),sin(wx)sin(wy)
Relevant notebook section
Theory: A good mental model for neural networks is that they are really good at matrix multiplication and addition, and anything else takes a lot of effort[21]. And so here! As we saw above, creating cos(wx),sin(wx) is just a rotation, and the later rearranging and map to the logits is another linear map, but multiplying the terms together is hard and non-linear.
There are three non-linear operations in a 1L transformer—the attention softmax, the element-wise product of attention and the value vectors, and the ReLU activations in the MLP layer. Here, the model uses both ReLU activations and element-wise products with attention to multiply terms[22].
The neurons form 5[23] distinct clusters for each frequency, and each neuron in the cluster for frequency w has its activation as a linear combination of 1,cos(wx),sin(wx),cos(wy),sin(wy),cos(wx)cos(wy),cos(wx)sin(wy),sin(wx)cos(wy),sin(wx)sin(wy).[24] Note that, as explained above, the neuron activation in any network can be represented as a linear combination of products of Fourier terms in x and Fourier terms in y (because they form a basis of R^(p×p)). The surprising fact is that this representation is sparse! This can be visually seen as neuron activations being periodic:
Evidence: The details of how the terms are multiplied together are highly convoluted[25], and I leave them to the Colab notebook appendices. But the neurons do in fact have the structure I described, and this can be directly observed by looking at their values. And thus, by this point in the network it has computed the product terms.
For example, the activations for neuron 0 (as plotted above) are approximately 109 − 39(cos(42x) + cos(42y)) − 76(sin(42x) + sin(42y)) + 36(cos(42x)sin(42y) + sin(42x)cos(42y)) − 10cos(42x)cos(42y) + 38sin(42x)sin(42y) (these coefficients can be calculated by mapping the neuron activation into the 2D Fourier Basis). This approximation explains >90% of the variance in this neuron[26]. We can plot this visually with the following heatmap:
Zooming out, we can apply a 2D DFT to all neuron activations, ie writing all of the neuron activations as linear combinations of terms of the form cos(42x)cos(42y) times vectors, and plotting the norm of each vector of coefficients. Heuristically, this is telling us what terms are represented by the network at the output of the neurons. We see that the non-trivial terms are in the top row (of the form cos(wx), sin(wx)) or the left column (of the form cos(wy), sin(wy)) or in a 2x2 block of cells along the diagonal (of the form cos(wx)cos(wy), cos(wx)sin(wy), sin(wx)cos(wy), sin(wx)sin(wy) - notably, a product term where both terms have the same frequency).
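As a rough illustration of "mapping a neuron activation into the 2D Fourier basis", the sketch below projects a function of (x, y) onto the nine single-frequency terms. It assumes that cos(42x) is shorthand for cos(2π·42·x/113), and the helper names are hypothetical. The first function is built directly from the approximate neuron 0 coefficients quoted above, so the projection should simply hand those coefficients back. The second is a toy ReLU(0.5 + cos + cos) neuron, included only to show that a ReLU applied to a sum of waves produces a non-zero product term.

```python
import math

p, w = 113, 42
# Assumption: "cos(42x)" means cos(2*pi*42*x/113)
C = [math.cos(2 * math.pi * w * t / p) for t in range(p)]
S = [math.sin(2 * math.pi * w * t / p) for t in range(p)]

TERMS = {
    "1": lambda x, y: 1.0,
    "cos(wx)": lambda x, y: C[x],
    "sin(wx)": lambda x, y: S[x],
    "cos(wy)": lambda x, y: C[y],
    "sin(wy)": lambda x, y: S[y],
    "cos(wx)cos(wy)": lambda x, y: C[x] * C[y],
    "cos(wx)sin(wy)": lambda x, y: C[x] * S[y],
    "sin(wx)cos(wy)": lambda x, y: S[x] * C[y],
    "sin(wx)sin(wy)": lambda x, y: S[x] * S[y],
}

def fourier_coeffs(f):
    """Project f(x, y) onto each term; the terms are mutually orthogonal on the p x p grid."""
    coeffs = {}
    for name, term in TERMS.items():
        num = sum(f(x, y) * term(x, y) for x in range(p) for y in range(p))
        den = sum(term(x, y) ** 2 for x in range(p) for y in range(p))
        coeffs[name] = num / den
    return coeffs

def neuron0(x, y):
    """Built from the approximate coefficients quoted in the text, not from real activations."""
    return (109 - 39 * (C[x] + C[y]) - 76 * (S[x] + S[y])
            + 36 * (C[x] * S[y] + S[x] * C[y])
            - 10 * C[x] * C[y] + 38 * S[x] * S[y])

def relu_neuron(x, y):
    """Toy neuron: a ReLU applied to a sum of waves in x and y."""
    return max(0.0, 0.5 + C[x] + C[y])

recovered = fourier_coeffs(neuron0)
relu_coeffs = fourier_coeffs(relu_neuron)
for name in TERMS:
    print(f"{name:>16}: neuron0 {recovered[name]:8.2f}   relu toy {relu_coeffs[name]:6.3f}")
```

For real activations you would replace neuron0 with a lookup into the recorded activation grid. The fraction of variance explained is then the variance of the nine-term reconstruction divided by the variance of the activations.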
Calculating cos(w(x+y)),sin(w(x+y)) and calculating logits
Relevant notebook section
Theory: The operations mapping cos(w(x+y)) = cos(wx)cos(wy) − sin(wx)sin(wy) and sin(w(x+y)) = cos(wx)sin(wy) + sin(wx)cos(wy) are linear, and the operations mapping this to cos(w(x+y−z)) = cos(w(x+y))cos(wz) + sin(w(x+y))sin(wz) are also linear. So their composition is linear, and can be represented by a single matrix multiplication. The neurons are mapped to the logits by L = W_U W_out N, and so the effective weight matrix W_logit = W_U W_out must represent both of these operations (if my hypothesis is correct). Note that W_logit is a p × d_mlp matrix, mapping from MLP-space to the output space.
Evidence: We draw upon several different lines of evidence here.
We show that the terms cos(w(x+y)), sin(w(x+y)) are computed as follows: Repeating the above analysis to find the terms that the neurons contribute to the logits, we find that the terms in the top row and left column cancel out. This leaves just diagonal terms, corresponding to products of waves of the same frequency in x and y, exactly the terms we need. We also see that the 2x2 blocks are uniform, showing that cos(w(x+y)) and sin(w(x+y)) have the same coefficient. Further analysis shows that everything other than cos(w(x+y)), sin(w(x+y)) for these 5 frequencies is essentially zero.
We now show that the output logits produce cos(w(x+y−z)) = cos(w(x+y))cos(wz) + sin(w(x+y))sin(wz) for each of the 5 represented frequencies (where z is the term corresponding to the output logits). The neurons form clusters for each frequency, and when we plot the columns of W_logit corresponding to those frequencies, and apply a 1D DFT to the output space of W_logit, we see that the only non-trivial terms are cos(wz), sin(wz) - ie the output logits coming from these neuron clusters are a linear combination of cos(wz), sin(wz).
We can more directly verify this by approximating the output logits as a sum ∑_{w∈[14,35,41,42,52]} A_w cos(w(x+y−z)) and fitting the coefficients A_w. When we do this, the resulting approximated logits explain 95% of the variance in the original logits. If we evaluate loss on this approximation to the logits, we actually see a significant drop in loss, from 2×10^−7 to 4.7×10^−8.
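The final step of the algorithm can be sketched in a few lines of plain Python. This is an idealised version under stated assumptions: all coefficients A_w are set to 1 (the real ones are fitted), frequencies are read as w·2π/p, and the inputs x = 57, y = 99 are arbitrary. It checks the angle-subtraction identity used above, and shows that summing cos(w(x+y−z)) over the five frequencies puts the largest logit at z = x+y (mod p) with a far bigger margin than any single frequency gives.

```python
import math

p = 113
freqs = [14, 35, 41, 42, 52]

def logits(x, y, ws):
    """Logit for each candidate z: sum over w of cos(2*pi*w*(x+y-z)/p), with every A_w = 1."""
    return [sum(math.cos(2 * math.pi * w * (x + y - z) / p) for w in ws)
            for z in range(p)]

def margin(vals):
    """Gap between the largest and second-largest logit."""
    top, second = sorted(vals, reverse=True)[:2]
    return top - second

x, y = 57, 99  # arbitrary example inputs

# cos(w(x+y-z)) = cos(w(x+y))cos(wz) + sin(w(x+y))sin(wz), checked numerically
identity_err = 0.0
for w in freqs:
    a = 2 * math.pi * w * (x + y) / p
    for z in range(p):
        b = 2 * math.pi * w * z / p
        lhs = math.cos(a - b)
        rhs = math.cos(a) * math.cos(b) + math.sin(a) * math.sin(b)
        identity_err = max(identity_err, abs(lhs - rhs))

single = logits(x, y, [14])
combined = logits(x, y, freqs)
pred = max(range(p), key=lambda z: combined[z])

print("predicted z:", pred, "expected:", (x + y) % p)
print("margin with one frequency:  ", margin(single))
print("margin with five frequencies:", margin(combined))
```

The single-frequency margin is tiny because the runner-up is just one step around the circle from the correct answer. With several frequencies the runner-up has to be close to the peak of every wave at once, which cannot happen away from z = x+y (mod p).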
Evolution of Circuits During Training
Note: For this section in particular, I recommend referring to the Colab! That contains a bunch of interactive graphics that I can’t include here, where we can observe the development of circuits during training.
Now that we understand what the model is doing on a mechanistic level, we can directly observe the development of these circuits during training. The key observation is that the circuits develop smoothly, and make clear and systematic progress towards the generalising circuit well before there is notable test performance. For example, the graph of the norms of the embedding of different waves becomes sparse. In particular, this disproves the natural hypothesis that grokking is the result of a random walk on the manifold of models that perform optimally.
I will now give and explain several graphs of metrics during training, and use them to support my above narrative of what happens during grokking in this specific case:
Sum of squared weights: the main regularisation is weight decay, so paying attention to the change in sum of squared weights is key to understanding the effect of the regulariser
This uses the transformer circuits notation - W_in, W_out are the MLP weights, W_U is the unembed, W_E is the embed, W_Q, W_K, W_O, W_V are the attention weights
Excluded loss: Intuitively, this metric tracks how much the model’s performance on the training data comes from the generalising algorithm vs the memorisation algorithm, by ablating the model’s ability to use a specific frequency w. This serves as a progress measure, quantitatively tracking the model’s progress towards the correct algorithm.
Formally, we calculate the logits, use a 2D DFT to write them as a linear combination of waves in x and waves in y, and set the terms corresponding to cos(w(x+y)),sin(w(x+y)) to zero and keep the other terms the same. Then measure the train loss.
If the model is using the addition identities (ie, with the generalising algorithm), this will heavily damage performance. If it isn’t (ie, the memorising algorithm), then this is just deleting 2 directions among 113^2, and so should have no effect.
Trig loss: Another metric is to take the output logits, and only allow them to use the directions corresponding to cos(w(x+y)),sin(w(x+y)) for the 5 relevant frequencies. Intuitively, this represents the model’s performance with the generalising algorithm preserved but the memorising algorithm removed.
Here we evaluate performance on both the train and test set—trig loss for the train and test set is almost equal (the lines are identical), showing that it’s purely tracking the generalising algorithm
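The two ablations behind excluded loss and trig loss can be sketched as projections in plain Python. This is a toy under loud assumptions and not the notebook's implementation: the modulus is shrunk to p = 11 so pure Python stays fast, the "learned" frequencies [1, 3] are made up, the logits are synthetic (a trig part plus Gaussian noise standing in for memorisation), and loss is averaged over all pairs rather than a train split. For each output z, the logit is treated as a vector over all (x, y) pairs, and the component along cos(w(x+y)), sin(w(x+y)) is either kept on its own (trig) or removed (excluded).

```python
import math
import random

p = 11                 # small modulus so pure Python stays fast (the post uses 113)
key_freqs = [1, 3]     # hypothetical "learned" frequencies for this toy
random.seed(0)
pairs = [(x, y) for x in range(p) for y in range(p)]

def direction(fn, w):
    """Unit vector over all (x, y) pairs for cos(w(x+y)) or sin(w(x+y))."""
    v = [fn(2 * math.pi * w * (x + y) / p) for x, y in pairs]
    n = math.sqrt(sum(a * a for a in v))
    return [a / n for a in v]

# These directions are mutually orthogonal, so projections can be summed term by term
dirs = [direction(fn, w) for w in key_freqs for fn in (math.cos, math.sin)]

# Synthetic logits[i][z]: trig ("generalising") part plus noise standing in for memorisation
logits = [[sum(4 * math.cos(2 * math.pi * w * (x + y - z) / p) for w in key_freqs)
           + random.gauss(0, 0.5) for z in range(p)] for x, y in pairs]

def project(col):
    """Component of a length-p^2 vector lying in the span of the trig directions."""
    out = [0.0] * len(col)
    for d in dirs:
        c = sum(a * b for a, b in zip(col, d))
        out = [o + c * b for o, b in zip(out, d)]
    return out

def split(logits):
    """Return (trig_only, excluded) logits by projecting each output column z."""
    trig = [[0.0] * p for _ in pairs]
    for z in range(p):
        proj = project([row[z] for row in logits])
        for i, v in enumerate(proj):
            trig[i][z] = v
    excluded = [[l - t for l, t in zip(lr, tr)] for lr, tr in zip(logits, trig)]
    return trig, excluded

def loss(logits):
    """Mean cross-entropy against the label (x + y) mod p."""
    total = 0.0
    for (x, y), row in zip(pairs, logits):
        m = max(row)
        lse = m + math.log(sum(math.exp(v - m) for v in row))
        total += lse - row[(x + y) % p]
    return total / len(pairs)

trig, excluded = split(logits)
full_loss, trig_loss, excluded_loss = loss(logits), loss(trig), loss(excluded)
print("full:", full_loss, "trig only:", trig_loss, "excluded:", excluded_loss)
```

In this toy all of the performance comes from the trig directions, so removing them pushes the loss up to roughly that of random guessing. A purely memorising model would show the opposite pattern: excluded loss close to the ordinary train loss.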
We can now use these metrics to analyse the separate stages of training:
Memorisation (Epoch 0-1K): The model memorises. We see a dramatic drop in train loss, and a rise in test loss.
We see that the incentive to memorise is far stronger than the incentive to have low weights, as total weights rise a lot
We see that it’s memorising rather than generalising as excluded loss is very close to train loss
Interpolation (Epoch 1K-9K): The model is interpolating between memorisation and generalisation.
We see this in the smooth divergence of excluded loss from train loss—this shows that the model is deriving more and more of its training performance from the generalising algorithm.
Intuitively, the model is trying to memorise “more efficiently” by picking up on patterns in the data. Some directions in neuron space are functions of x+y (mod p), namely cos(w(x+y)), sin(w(x+y)) terms, so it learns to use these directions to represent more and more information.
Train loss mostly plateaus in this period—this shows that although the model is changing its underlying algorithm, it does this by interpolating between the purely memorising circuit and the purely generalising circuit
It surprises me that the model can interpolate between two circuits while retaining good performance!
This stage begins when the incentive for low train loss and the incentive for low weights balance, and train loss abruptly levels out.
This is particularly clear from epoch 1K-3K, when the weights for the MLP layer W_in, W_out drop in norm and meet the other weight matrices. Given a matrix product AB, we can always scale up A→kA and scale down B→(1/k)B and keep the product AB the same. The sum of squared weights is minimised when |A|=|B|, so all weight matrices between the embed and unembed should have equal norm; this minimises total weights and keeps loss the same. The fact that this does not happen in the first stage shows that the incentive for small weights is too weak and has little effect.
Further, this is suggestive evidence that most of the memorisation comes from changes in the MLP—this matches the hypothesis that memorisation involves “simpler” circuits that don’t involve complex composition, and so doesn’t need a phase change.
Generalisation (Epoch 8K-12K): We see a phase change in trig loss. Trig loss represents “model performance if we delete the memorisation circuit”, so this shows that the generalising circuit rapidly develops in a phase change.
This matches my (admittedly post-hoc) hypothesis that all circuits see phase changes—as each component of the circuit forms, it reinforces the gradients on all other components.
We further see a rapid acceleration in the decrease in total weights, supporting the idea that there is some cascade as the circuit forms.
Cleaning Up Noise (Epoch 9K-13K): Somewhat in parallel with the above phase change in trig loss, there is a slightly delayed phase change in test loss. The slight delay indicates that good test loss performance requires both a strong generalising circuit, and the removal of the memorising circuit.
The rapid removal of the memorisation circuit presumably comes from weight decay—once the memorisation circuit’s contribution to loss is low enough, the total weight benefits of removing it outweigh the contribution to loss.
Stability (Epoch 13K-end)[27]: The model has now transitioned to using the generalising circuit and all our metrics settle down
The model still has some remnants of the memorisation circuit, as train loss is still less than test loss. This must come from directions other than cos(w(x+y)),sin(w(x+y)) in logit space, as trig loss is identical on the train and test data.
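The norm-balancing argument from the Interpolation stage (rescaling A→kA, B→(1/k)B leaves AB fixed, and the sum of squared weights is smallest when |A|=|B|) can be checked numerically. The matrices below are arbitrary examples, and |·| is taken to be the Frobenius norm, which is my assumption about the intended norm.

```python
def sq_norm(M):
    """Sum of squared entries (squared Frobenius norm)."""
    return sum(v * v for row in M for v in row)

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def scale(M, k):
    return [[k * v for v in row] for row in M]

A = [[3.0, 1.0], [0.0, 2.0]]   # arbitrary example with a large norm
B = [[0.1, 0.0], [0.2, 0.1]]   # arbitrary example with a small norm

# f(k) = k^2 |A|^2 + |B|^2 / k^2 has its minimum at k^4 = |B|^2 / |A|^2
k_star = (sq_norm(B) / sq_norm(A)) ** 0.25
A_bal, B_bal = scale(A, k_star), scale(B, 1 / k_star)

before = sq_norm(A) + sq_norm(B)
after = sq_norm(A_bal) + sq_norm(B_bal)
product_err = max(abs(u - v)
                  for r1, r2 in zip(matmul(A, B), matmul(A_bal, B_bal))
                  for u, v in zip(r1, r2))

print("sum of squared weights before:", before, "after:", after)
print("squared norms after balancing:", sq_norm(A_bal), sq_norm(B_bal))
```

Weight decay only ever sees the sum of squared weights, so with the product held fixed it pushes towards the balanced point.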
Discussion
Alignment Relevance
Ultimately, my goal is to do research that is relevant to aligning future powerful AI systems. Here I’ve studied tiny models
(Note—the true reason I did this project was that I was on holiday, and got extremely nerd-sniped, so this is all post-hoc reasoning!)
Emergent capabilities:
Is grokking relevant to large models? Probably not—models like GPT-3 and PaLM are trained on functionally infinite data, and rarely see a data point twice. In fact, there’s a strong incentive to avoid this, as repeated data can severely damage performance.
The main way I think this is directly relevant to large models is the indication that phase changes are everywhere. If most circuits emerge as phase changes during training, this suggests we should expect emergent capabilities to be a frequent occurrence. In particular, this may occur for particularly dangerous capabilities, such as deception or situational awareness.
This would make alignment a harder problem, as it is less likely that we’ll see warning shots. It also limits our ability to, eg, continuously test a system for misalignment or deception throughout training. This strategy might work if the system will first try to be deceptive badly before it is superhuman at lying, but will fail if the system rapidly jumps from not even trying to be deceptive to being superhuman.
However, this is highly speculative—even if this hypothesis holds, these capabilities may be implemented by many different circuits, and so still see gradual development
A laboratory for training dynamics + generalisation: There are many open questions about exactly why neural networks work, why they learn to generalise to unseen data, what kinds of generalising algorithms they learn, and how this can be influenced. These are extremely hard questions to study, and I don’t feel that I’ve understood them as a result of this work. But I think the modular addition transformer may be a good laboratory to examine these questions, as it’s a concrete, small model that can be easily trained and whose circuits can be clearly inspected. Further, answering these questions is key for alignment: aligning systems will likely require predicting when they would learn to be deceptive vs aligned and how to influence them.
A particularly important direction may be directly training systems on interpretability-based metrics. Naively, if we could create an automated tool to detect deception and put this in the loss function, our system may never learn to be deceptive. But by default, I would expect this kind of strategy to fail badly, because gradient descent will just find a way to Goodhart the tool. I expect a good (though still highly limited) in-road into this question would be to train the modular addition system on interpretability-based metrics, and see the effects on the training dynamics and the resulting circuits.
The ability to understand training dynamics rather than just a final system is also notable here. Possibly, if we ever end up with a system that is deceptive and aware of its current situation as a neural network in training, it will be able to influence the training process to always keep it misaligned, and evade any interpretability techniques trying to detect this. The only way to avoid this will be to influence the training dynamics such that the system never becomes deceptive and situationally aware in the first place.
Importance of interpretability: A high-level motivation behind this work is a belief that mechanistic interpretability is key to understanding how neural networks work. They seem to learn interpretable circuits, and these determine both the model’s capabilities, and how those capabilities are implemented. It seems very difficult to get traction on what’s really going on in models without engaging with these circuits, analogous to doing chemistry without discovering the periodic table. I take this work as a proof of concept of that: I took the weird and confusing phenomenon of grokking, and feel that I have significantly de-mystified it by fully reverse engineering the underlying circuits.
This is an admittedly very toy problem, but feels like evidence supporting my belief that interpretability is extremely useful for getting alignment right.
This is particularly notable with grokking, as here we studied the competition between two circuits with comparable loss. The standard way to predict what a model will learn is by expecting it to have low loss, but this breaks down if multiple mechanistically different solutions have comparable performance. This is extremely relevant for alignment as, for example, a deceptive system will be incentivised to output exactly what we’d expect from an aligned solution, and so get the same loss.
This work also reinforces my belief that the circuits agenda is possible at all: neural networks seem to want to be understood, if we can just find the right way to reverse engineer them! In particular, I discovered an unexpected and new[28] algorithm for doing modular addition via reverse engineering.
Limitations
There are several limitations to this work and how confidently I can generalise from it to a general understanding of deep learning/grokking, especially the larger models that will be alignment relevant. In particular:
The modular addition transformer is a toy model, which only needs a few circuits to solve the problem fully (unlike eg an LLM or an image model, which needs many circuits and may lead to different behaviour)
The model is over-parametrised, as is shown by eg it learning many redundant neurons doing roughly the same task. It has many more parameters than it needs for near-perfect performance on its task, unlike real models.
In particular, models like language models could always learn more circuits to achieve marginally better task performance, and so need to use their parameters efficiently. This may act as a form of implicit regularisation.
I only study weight decay as the form of regularisation.
My explanations of what’s going on rely heavily on fairly fuzzy and intuitive notions that I have not explicitly defined: crisp circuits, simplicity, phase changes, what it means to memorise vs generalise, model capacity, etc.
Of my four tasks exhibiting phase changes, I have only fully interpreted modular addition.[29]
Future Directions
I see a lot of future directions building on this work. If you’d be interested in working on any of these, please reach out! I’d love to chat and possibly advise or collaborate on a project. I’m also currently looking to hire an intern/research assistant to help me work on mechanistic interpretability projects (including but not limited to these ideas), and if you feel excited about any of these directions you may be a good fit.
Investigating the ‘phase changes are everywhere’ hypothesis.
Focusing on larger models on complex tasks
Training a larger model that we know how to interpret, making many checkpoints and inspecting circuits during training
Training a small SoLU transformer (1L is particularly easy to interpret)
Training Inception and looking for curve circuits
Looking for emergent capabilities in open source well-checkpointed training runs (Eg Mistral and GPT-J)
Performance on benchmarks or specific questions from a benchmark
Simple algorithmic tasks—eg few shot learning, addition, sorting words into alphabetical order, completing a rhyming couplet, matching open and close brackets
Soft induction heads, eg translation
Look at attention heads on various texts and see if any have recognisable attention patterns (eg start of word, adjective describing current word, syntactic features of code like indents or variable definitions, most recent open bracket, etc).
Wei et al is a good source for other ideas
More elaborate toy problems that involve several different tasks
Eg, training a model with separate heads for modular addition, subtraction, multiplication, etc. Do we see separate phase changes?
Training models for the other tasks in the grokking paper with more data, and seeing if we see phase changes but no grokking.
Investigating the ‘phase changes are inherent to composition’ hypothesis
If we directly look at the gradients on different components of the circuit, do we see this mutual reinforcement and cascading?
Can we show this analytically in some toy setting? What are the minimal properties we need to see phase changes?
I predict the most tractable will be skip trigrams, given how simple they are. We can take even simpler versions, by making the only parameters a QK and OV matrix, removing the attention softmax, etc.
Are phase changes specific to weirdnesses of the Adam optimizer? Do we get them for SGD or Momentum?
Note: Once a model has memorised, loss can be around 1e-7. SGD gradients scale with loss while Adam is normalised, so you’ll need to increase your learning rate a lot.
Barak et al may be another good setting to analyse—looking for phase changes in the k-bit parity problem
Loose threads in this project
Interpreting my other tasks with phase changes, and examining the circuits
Preliminary investigation shows that 5 digit addition is using a variant of the trig based algorithm
Interpreting the memorisation circuit in the modular addition transformer—understanding memorisation in transformers seems exciting to me!
This is likely easiest on random labels with no structure.
Why is there another phase change at epoch 43K?! What’s the algorithm after this?
There’s some circuit competition between eg frequency 25 and 31 from epoch 12K to 14K, or between multiplying terms with the attention patterns or the ReLUs. What’s up with that?
Why do the given frequencies form, and not others? Why 5 or 6 frequencies, and not more or less?
Train the model for a bunch of different random seeds, and see how consistent these observations are.[30]
Training dynamics. Note that a 2L MLP can also grok modular addition, and may be easier to study.
What happens when we train on interpretability-inspired metrics like excluded loss[31]? Can we accelerate the grokking point? Can we disincentivise it? Can we incentivise or disincentivise specific frequencies?
How do forms of regularisation other than weight decay affect the model and its inductive biases?
How and why does the model learn to form separate clusters of neurons? Can we predict neuron clustering in advance? Can we manually shift the initialisation to change the clustering?
Can we replicate a lottery ticket-style approach here? What subnetworks are useful for solving this problem? And if only some subnetworks are useful, why do all neurons ultimately end up being used?
How general is this algorithm? Can we find Fourier-style algorithms in LLMs? Especially when they learn to do addition
Note, I expect this to be much easier to detect in models that tokenize each digit separately; the GPT-3 tokenizer completely breaks place value: 28|79|23|598|23|45|987|249|234|0000|23|47|89|03|00000|700|9|02|10. Sadly, the only tokenizer I know of like this is PaLM’s.
The GPT-3 tokenizer mostly creates 2 digit tokens, so I’d start by trying to use a linear probe to extract cos(wz), sin(wz) in base 100 from the model for each frequency, where z is the correct output digit (or possibly the output if you incorrectly carry a 1)
Getting into the field: I think this problem is notable for being fairly simple and involving small models that it’s tractable to completely understand. If you’re interested in getting into mechanistic interpretability and are new to the field, I think this may be a great place to start! As a concrete first project, I recommend training your own transformer for modular subtraction, figuring out the analogous algorithm it should be using, and seeing if you can replicate my analysis to reverse engineer that.
Meta
Notable Observations
Some scattered thoughts that feel exciting to me, but didn’t naturally fit into this write-up:
Slingshot mechanism: (see colab) The Slingshot Mechanism was a fascinating paper from Thilak et al, which found that grokking seemed tied to ‘slingshots’, dramatic loss spikes that were hypothesised to fling the model to different regions in the loss landscape.
I found that these slingshots were not necessary to achieve grokking[32], and further that the slingshots are the result of a float32 precision error in log_softmax. This can be fixed by casting your logits to float64 before taking log_softmax. If I keep my training code exactly the same but remove this casting, my loss curve instead looks like this monstrosity:
Explanation: Concretely, the smallest non-zero value that log_softmax can output in float32 is 1.19e-7. So once the model has memorised sufficiently well that loss approaches the scale of 1e-7, some data points will have loss rounding to zero. This also rounds the gradients of their loss to zero, equivalent to them not being in the training set. When a model memorises data, it wants to repurpose all other parameters and set unseen data points to extreme values, and this precision error means this includes some training data points. This introduces a dodgy gradient, and the EWMA of Adam ensures it persists for long enough for a significant loss spike.
Note that this is an inherent problem of float32, not a flaw in log_softmax. For small values of x, log_softmax(x) is equal to log(1+x), and this is approximately x. 1.19e-7 is the smallest value such that 1+x != 1 in float32.
Float32 isn’t the whole story. With minibatch training and float64, we also see loss spikes, though rarer and less severe.
Weight Decay: Naively, for a fixed training dataset size, the model should only grok for sufficiently high weight decay. But this is not actually true! Lower weight decay just delays grokking rather than stopping it.
Intuition: The model will have an incentive to switch to the generalising circuit so long as it has lower total weight than the memorising circuit. Importantly, this holds for any non-zero weight decay. Smaller weight decay just leads to a smaller incentive
Explaining circle plots of W_E: A notable finding of Liu et al and the original grokking paper is that the embedding matrix of modular addition can be plotted as a circle, by the first two principal components or t-SNE respectively. Now that we understand the underlying circuits, we can see that this clearly follows from the fact that the embedding matrix has orthogonal rows for sin(wx), cos(wx) - these dimensionality reduction techniques are just finding two directions of the same frequency, and observing that a scatter plot of sin(wx), cos(wx) is a circle.
2L MLPs can grok modular addition: A surprising (to me) result from Liu et al is that a 2L MLP can grok modular addition[33]. I replicated this result, and found that it learned the same algorithm (here, purely using the ReLU/GeLU activations to multiply the waves in x and y). The neurons formed a cluster for every single frequency, with 3-7 neurons per cluster (width 256), and the input and output weights were sparse in the 1D Fourier basis.
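The float32 rounding effect described in the slingshot note above can be reproduced without any deep learning library. The sketch below is a simplified model and not a call into a real log_softmax: it assumes the loss on a memorised data point has the form log(1 + excess), where excess is the (tiny) sum of exp(other logit − correct logit), and it emulates float32 by round-tripping through the standard struct module.

```python
import math
import struct

def f32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

eps32 = 2.0 ** -23  # float32 machine epsilon, about 1.19e-7

def loss_from_excess(excess, to_dtype):
    """Simplified cross-entropy log(1 + excess), with 1 + excess rounded to the given precision."""
    return math.log(to_dtype(1.0 + excess))

tiny = 1e-8                              # excess once the model has memorised very well
loss32 = loss_from_excess(tiny, f32)     # 1 + 1e-8 rounds to exactly 1 in float32
loss64 = loss_from_excess(tiny, float)   # float64 keeps the tiny loss

print("float32 epsilon:", eps32)
print("1 + eps32 distinguishable from 1 in float32:", f32(1.0 + eps32) != 1.0)
print("loss in float32:", loss32, " loss in float64:", loss64)
```

A loss that rounds to exactly zero also has zero gradient, which is the sense in which such a data point silently drops out of the training set.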
Acknowledgements
This work was done as independent research after leaving Anthropic, but benefitted greatly from my work with the Anthropic interpretability team and skills gained, especially from an extremely generous amount of mentorship from Chris Olah. It relies heavily on the interpretability techniques and framework developed by the team in A Mathematical Framework for Transformer Circuits, and a lot of the analysis was inspired by the induction head bump result.
Thanks to Noa Nabeshima for pair programming with me on an initial replication of grokking, and to Vlad Mikulik for pair programming with me on the grokked induction heads experiment.
Thanks to Jacob Hilton, Alex Ray, Xander Davies, Lauro Langosco, Kevin Wang, Nicholas Turner, Rohin Shah, Vlad Mikulik, Janos Kramar, Johannes Treutlein, Arthur Conmy, Noa Nabeshima, Eric Michaud, Tao Lin, John Wentworth, Jeff Wu, David Bau, Martin Wattenberg, Nick Cammarata, Sid Black, Michela Paganini, David Lindner, Zac Kenton, Vikrant Varma, Evan Hubinger (and likely many others) for helpful clarifying discussions about this research, feedback, and helping me identify my many ill-formed explanations!
Generously supported by the FTX regrantor program
Author Contributions
Tom Lieberum significantly contributed to the early stages of this project—replicating grokking for modular subtraction, discovering it was possible in a 1L transformer with no LayerNorm, and observing the strongly periodic behaviour of activations and weights
Neel Nanda led the rest of this project—fully reverse engineering the modular addition transformer, analysing it during training, and discovering and analysing the link between grokking and phase changes. He wrote this write-up, and it is written from his perspective.
Feedback
I’d love to hear feedback on this work—parts you find compelling, parts you’re skeptical of, parts I explained poorly, cases for why Colab notebooks are a terrible format for papers, etc. You can comment here, or reach me at neelnanda27@gmail.com.
Citation Info
Please cite this work as:
In the ideal case this looks flat, then drops abruptly, then is flat again. In practice, I’m happy to include any striking non-convexities in a curve.
Ie a divergence in train and test loss
Where possible, I train the models on infinite data (ie, the data is randomly generated each time, and there are enough possible data points that nothing gets repeated) - in this case I do not distinguish between train and test loss, as train loss is always on unseen data.
This is even more mysterious when you consider how each head has both an attention (QK) circuit and a value (OV) circuit. The previous token head’s QK circuit needs to compose with the positional embeddings and its OV circuit needs to compose with the embedding’s output. The induction head’s QK circuit needs to compose with the previous token head’s output on the K-side, and the embedding on the Q-side, and the OV circuit needs to compose with the embedding and unembedding to be the identity. This is a lot of synchronised and non-trivial interaction! But we know that real networks form circuits with as many moving parts as this, or more, eg curve detecting circuits.
Eric Michaud gives some evidence for this by finding that, if we project the model’s initialised embedding onto the principal components of the final embedding, we still see a circle-like thing—suggesting that the model reinforced directions in embedding space that luckily had good properties.
I have not verified that the model actually forms an induction circuit of the form described in this particular task. But the structure of attention means it must be doing some kind of composition, because a single attention head cannot attend to a previous token based on its neighbours
I later show in the case of modular addition that it is definitively not a random walk, as we can see clear progress towards the true answer before grokking
Though I’m sure I’m missing other explanations!
This is a more complex statement than it looks, since each solution can always become more complex to result in lower loss. Eg doubling W_U will always decrease loss if there’s 100% accuracy. I operationalise this as ‘which solution is simpler if we set them both to achieve the same level of loss’
I am assuming this fact about memorisation as part of this story. I don’t understand how memorisation is implemented so I cannot be confident, but this seems plausible given that, eg a logistic/linear regression model with enough parameters can memorise random data
ie, more simply/with lower total weights
In some sense, single epoch training is the limiting case of this. The model tries to memorise every data point it sees, but there must eventually be a cap on how much data it can memorise, so it must pick up on patterns. We never run the model on the same data point twice, so we don’t notice this attempt at memorisation. Plausibly, models that never learn to generalise keep trying to memorise recent training data and forget about old data.
The 0th digit is the leading digit that’s 0 or 1
This is weird! Token 5 (the final digit) should be easiest because it never requires you to carry a 1 - it’s just addition mod 10
Under this hypothesis, I would predict that soft induction heads like translation appear after the phase change, as they seem likely more complex
The training run was up to 50K epochs, but there’s a mysterious bonus tiny phase change at epoch 43K that I don’t yet understand, and so I snap the model before then, as I don’t think that shift is central to understand the first grok.
This is useful because, though z = x+y (mod p) is an argmax of cos(w(x+y−z)), the difference between it and the second largest value is tiny. While if we take the sum of cos(w(x+y−z)) for several different values of w, the waves constructively interfere at z = x+y (mod p), and destructively interfere everywhere else, significantly increasing the difference.
By having a column of W_E be (f(0), f(1), ..., f(p−1)) - the dot product (ie the first coordinate of the embedding) is f(x). In fact, so long as W_E is invertible, we can recover any function f(x) by projecting onto the direction f(x) W_E^{-1}
Remember that we’re using high weight decay, so the model will set any weights it’s not using to (almost) zero
The frequencies here are 14,31,35,41,42,52 - as far as I can tell, the specific frequencies and the number of frequencies are arbitrary, and vary between random seeds
But is also extremely necessary! Eg how activation functions take an MLP from linear regression to being able to represent any function.
The attention pattern multiplication is the obvious way to do it—I found the use of ReLUs very surprising! But it turns out that ReLU(0.5+cos(kx)+cos(ky)) is approx 0.71+0.695cos(kx)+0.695cos(ky)+0.43cos(kx)cos(ky) (more details)
The model calculates 6 frequencies with the embedding, but only 5 with the neurons—I’m not sure why.
There is actually a sixth cluster not aligned with any particular frequency—neurons in this cluster always fire, and thus ReLU is the identity. I’m not entirely sure why this is useful to the network
In particular, understanding ReLU in the Fourier basis takes it from a nice basis aligned operation to the operation “rotate out of the Fourier basis into the standard basis, take the elementwise ReLU, and then rotate back into the Fourier basis”
We can validate that the remaining 10% is just noise by ablating—replacing each neuron by this approximation (ie removing all terms of all other frequencies) leaves loss unchanged.
This isn’t quite true: at epoch 43K, there is a tiny phase change which can be seen if you zoom in on the trig loss graph (trig loss and train loss fall, test loss rises). I haven’t investigated this properly, but the model goes from never attending to the = token to significantly attending to it. Presumably there’s some further circuit development here?
Well, new to me - anecdotally, a friend of mine came up with the algorithm independently when exploring grokking
Though skip trigrams are extremely easy to interpret, and mostly match what I’d predict, I just haven’t written it up.
I’ve looked at several models, but not exhaustively
Note that excluded loss specifically is weird, because we fit the coefficients of cos(w(x+y)) to the logits across all of the data, so it’s technically also a function of the test data. Probably linear regression on just the train logits would suffice though.
Though plausibly helps by adding regularisation—generalising solution may be more stable to slingshots than memorised ones?
Input format is the sum of a one-hot encoding of x and of y in R^113. If x=y then the vector has a single two in it.