Hypothesis: gradient descent prefers general circuits

Summary: I discuss a potential mechanistic explanation for why SGD might prefer general circuits for generating model outputs. I use this preference to explain how models can learn to generalize even after overfitting to near zero training error (i.e., grokking). I also discuss other perspectives on grokking and deep learning generalization.

Additionally, I discuss potential experiments to confirm or reject my hypothesis. I suggest that a tendency to unify many shallow patterns into fewer general patterns is a core feature of effective learning systems, potentially including humans and future AI, and briefly address the implications for AI alignment.

Epistemic status: I think the hypothesis I present makes a lot of sense and is probably true, but I haven’t confirmed it experimentally. Much of my motivation for posting this is to clarify my own thinking and get feedback on the best ways to experimentally validate this perspective on ML generalization.

Context about circuits: This post assumes the reader is familiar with and accepts the circuits perspective on deep learning. See here for a discussion of circuits for CNN vision models and here for a discussion of circuits for transformer NLP models.

Evidence from grokking

The paper “Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets” uses stochastic gradient descent (SGD) to train self-attention-based deep learning models on different modular arithmetic expressions (e.g., $a + b \bmod p$, where $p$ is fixed).

The training data only contain a subset of the function’s possible input/output pairs. Initially, the models overfit to their training data and are unable to generalize to the validation input/output pairs. In fact, the models quickly reach near-perfect accuracy on their training data. However, training the models significantly past the point of overfitting eventually causes them to generalize to the validation data, a phenomenon the authors call “grokking”.
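For concreteness, here’s a minimal sketch of the kind of dataset these experiments use, assuming modular addition as the operation and a random train/validation split (the paper covers many different modular expressions; the function and variable names below are my own illustration, not the paper’s code):

```python
# Minimal sketch of a grokking-style dataset, assuming modular addition
# (the paper studies many different modular binary operations).
import random

def make_modular_addition_data(p=97, train_fraction=0.5, seed=0):
    """Enumerate all (a, b, (a + b) % p) triples and hold out a random validation split."""
    triples = [(a, b, (a + b) % p) for a in range(p) for b in range(p)]
    random.Random(seed).shuffle(triples)
    split = int(train_fraction * len(triples))
    return triples[:split], triples[split:]

train_data, val_data = make_modular_addition_data()
print(len(train_data), len(val_data))  # 4704 training triples, 4705 validation triples
```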

See figure 1a from the paper:

Figure 1a from Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets

Note that the model has near-perfect accuracy on the training data. Thanks to a recent replication of this work, we can also look at the loss curves during grokking (though for a different experiment than the one plotted above):

Figure from this GitHub repository

First, the model reaches near-zero training loss while overfitting on the validation data. However, the validation loss eventually starts decreasing as well, until the model correctly classifies both the training and validation data.

This brings up an interesting question: why did the model learn anything at all after reaching near zero loss on the training data? Why not just stick with memorizing the training data? What would prompt SGD to switch over to general circuitry that solves both training and validation data?

I think the answer is surprisingly straightforward: SGD prefers general circuits because general circuits make predictions on a greater fraction of the training data. Thus, general circuits receive more frequent SGD reinforcement for making correct predictions. Think of each data point as “pushing” the model to form circuits that perform well on that data point. General circuits perform well on many data points, so they receive a greater “push” towards forming.

Shallow circuits are easier to form with SGD, but they aren’t retained as strongly. Thus, as training progresses, general circuits eventually overtake the shallow circuits.
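One rough way to formalize this intuition (my own notation, not anything from the grokking paper): suppose a circuit with parameters $\theta_c$ only receives meaningful gradient signal on the subset $D_c \subseteq D$ of training examples where it contributes to predictions. Over one pass through the data, its accumulated update is roughly

$$\Delta \theta_c \approx -\eta \sum_{x \in D_c} \nabla_{\theta_c} L(x) = -\eta \, |D_c| \, \mathbb{E}_{x \sim D_c}\!\left[\nabla_{\theta_c} L(x)\right],$$

so a circuit that correctly predicts more of the training data gets a proportionally larger and more frequent “push,” while on the remaining $|D| - |D_c|$ examples it only experiences interference and (if used) weight decay.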

Toy example of SGD preferring general circuits

Let’s consider a toy example of memorizing two input/output pairs. Suppose:

  • input 1 = $x_1$

  • output 1 = $y_1$

  • input 2 = $x_2$

  • output 2 = $y_2$

One way the model might memorize these data points is to use two independent, shallow circuits, one for each data point. I show a diagram of how this might be implemented using two different self-attention heads:

Shallow memorization circuits.

(Suppose for simplicity that these attention heads ONLY implement the circuit shown)

$QK_1$ and $QK_2$ represent the query-key[1] circuits associated with their respective attention heads. $QK_1$ and $QK_2$ respectively search for $x_1$ and $x_2$ appearing in the input, and trigger their respective output-value[1] circuits (represented by $OV_1$ and $OV_2$) when they find their target inputs.

Another way to memorize these data points is to use a more general, combined circuit implemented with a single attention head:

More general memorization circuit.

Here, $QK_{1,2}$ represents a single query-key circuit that looks for either $x_1$ or $x_2$ in the input and triggers the combined output-value circuit to produce either $y_1$ or $y_2$, depending on which input triggered it.

I think SGD will prefer the single general circuit to the shallow circuits because the general circuit produces correct predictions on a greater fraction of the input examples. SGD only reinforces one of the shallow circuits when the model processes the specific input associated with that circuit. In contrast, SGD reinforces the general circuit whenever the model processes either of the inputs for which the general circuit produces correct predictions.

To clarify: both circuit configurations described here memorize the data, and both would likely fail completely on validation data. The single circuit that memorizes both datapoints is more “general” in the sense that it generates more total correct predictions. I use memorizing circuits as my examples here to highlight the fact that I use “generality” to refer specifically to the number of training datapoints for which a circuit generates correct predictions, not, say, the probability that a circuit generalizes from the training data to the validation data.
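To make the two configurations concrete, here’s a deliberately simplified sketch (a plain linear associative memory in numpy, not a real attention head; the embeddings and matrices are my own stand-ins for the diagrams above):

```python
# Two shallow "circuits" each store one input->output association, while one
# general "circuit" stores both associations in a single weight matrix.
import numpy as np

d = 8
rng = np.random.default_rng(0)
x1, x2 = np.eye(d)[0], np.eye(d)[1]              # orthogonal input embeddings
y1, y2 = rng.normal(size=d), rng.normal(size=d)  # arbitrary target outputs

# Shallow configuration: one matrix per memorized pair (one "head" each).
V_shallow_1 = np.outer(y1, x1)
V_shallow_2 = np.outer(y2, x2)

# General configuration: a single matrix memorizes both pairs (a single "head").
V_general = np.outer(y1, x1) + np.outer(y2, x2)

for x, y in [(x1, y1), (x2, y2)]:
    assert np.allclose(V_shallow_1 @ x + V_shallow_2 @ x, y)
    assert np.allclose(V_general @ x, y)
print("both configurations memorize the two pairs exactly")
```

Both configurations produce exactly the same outputs on these two inputs; the question is which one SGD ends up favoring.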

Another way to see why SGD would prefer the general circuit: catastrophic forgetting is the tendency of models initially trained on task A, then trained on task B, to forget task A while learning task B. Consider that, whenever the model isn’t processing inputs containing $x_1$, the individual shallow circuit that produces $y_1$ will experience catastrophic forgetting. Thus, every training example except one degrades that shallow circuit’s performance.

In contrast, the general circuit generates predictions for both $x_1$ and $x_2$. It’s reinforced twice as frequently, so it’s better able to recover from the degradation caused by training on the other examples. Eventually, the general circuit subsumes the functionality of the two shallow circuits.
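Here’s a toy numerical sketch of that dynamic (entirely my own construction, and far simpler than a real network): two scalar “circuit strengths,” one reinforced by a single training example and one reinforced by two, with every other example slightly degrading them.

```python
# Toy simulation: each "circuit" is a scalar strength that is reinforced on the
# examples it predicts and slightly degraded (interference / decay) on all others.
import random

def simulate(num_examples=50, steps=20_000, lr=0.05, decay=0.002, seed=0):
    rng = random.Random(seed)
    shallow = 0.0   # predicts 1 of the 50 training examples
    general = 0.0   # predicts 2 of the 50 training examples
    for _ in range(steps):
        i = rng.randrange(num_examples)  # sample one training example per step
        shallow = shallow + lr * (1 - shallow) if i == 0 else shallow * (1 - decay)
        general = general + lr * (1 - general) if i in (0, 1) else general * (1 - decay)
    return shallow, general

print(simulate())  # the circuit reinforced twice as often ends up noticeably stronger
```

In this toy model, each circuit’s steady-state strength scales with how often it’s reinforced relative to how often it’s degraded, which is the sense in which general circuits receive a greater “push” toward forming.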

From figure 2a of “Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets”, we can see that stochasticity and regularization significantly speed up generalization. Potentially, this occurs because both randomness in the SGD updates and weight decay help to degrade shallow circuits, allowing general circuits to dominate more quickly. I think a large amount of weight decay’s value in other domains comes from it degrading shallow circuits more than general circuits.

Figure 2a of “Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets”

I think the conceptual arguments I’ve provided above strongly imply that SGD has some sort of preference for general circuits over a functionally equivalent collection of shallow circuits. In the next section, I try to flesh out in more detail how this might manifest in the training process. However, I think that section is more speculative than the arguments above.

Significance for the entire model

The process of unifying multiple shallow circuits into fewer, more general circuits happens at multiple levels throughout the training process. Gradually, the shallow circuits combine into slightly more general circuits, which themselves combine further. Eventually, all the shallow memorization circuits combine into a single circuit representing the true modular arithmetic expression.

I think we see artifacts of this combination process in the grokking loss plots, specifically in the spikes:

Figure from this GitHub repository

Note that each loss spike in the training data corresponds with a loss spike in the validation data. I think these spikes represent unification events where the model replaces multiple shallow circuits with a smaller number of more general circuits.

The previous section described general vs. shallow circuits as a binary choice, with the model able to use either one general circuit or a collection of shallow circuits. However, real deep learning models are more complex. They can simultaneously implement multiple instances of both types of circuits, with each circuit contributing part of a single prediction.

Let’s consider a set $X_G$ representing a subset of the training data.

At the start of training, I think each prediction on the elements of $X_G$ is mainly generated by multiple different shallow circuits, with some small fraction of each prediction coming from a single, partially implemented general circuit.

As training progresses, the model gradually refines whatever general circuit contributes to correct predictions on all of $X_G$. Eventually, the model reaches an inflection point where it has a general circuit that can correctly predict all of $X_G$. At this point, I think there’s a relatively quick phase shift in which the general circuit substitutes in for multiple shallow circuits at once. These shifts generate the loss spikes seen in the plot above.

I’m unsure why switching from shallow to general circuits would cause a loss spike. I suspect the network is encountering something like a renormalization issue. The general circuit may be generating predictions for $X_G$, but that doesn’t mean all of the shallow circuits have been removed. If there are elements of $X_G$ where both the general circuit and the original shallow circuit generate predictions, that may cause the network to behave poorly on those data points.

Generalization to validation data starts to happen when the only way for the model to fit more correct predictions into a single circuit is for that circuit to actually start modeling the underlying data generating process. This leads to circuits that are still shallow, but have some limited generalization capability.

To see how a model might implement partially generalizing shallow patterns, imagine a circuit using a linear approximation to $(a + b) \bmod p$, namely $a + b$, to efficiently store many predictions on the training data. Note this approximation is correct so long as $a + b < p$, so it does have some degree of generalization to validation data, even though the model only learned it from training data. Similarly, the network can use $a + b - p$ for any $p \le a + b < 2p$.
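As a concrete sanity check (still using modular addition, which is my simplified stand-in for the paper’s tasks), these band-wise linear rules really do agree with the modular operation on their respective slices of the input space:

```python
# Shallow-but-partially-general shortcuts for modular addition: each linear rule
# agrees with (a + b) % p on the band of inputs it was "learned" for.
p = 97

def band_0(a, b):  # valid whenever a + b < p
    return a + b

def band_1(a, b):  # valid whenever p <= a + b < 2 * p
    return a + b - p

assert all(band_0(a, b) == (a + b) % p for a in range(p) for b in range(p) if a + b < p)
assert all(band_1(a, b) == (a + b) % p for a in range(p) for b in range(p) if a + b >= p)
```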

Midway through grokking, the network probably looks like an ensemble of circuits that each represent the true data distribution in different regions of the input space. Eventually, these partially generalizing shallow patterns combine into a single circuit that correctly represents the true arithmetic expression (e.g., by first computing $a + b$ as a function of $a$ and $b$, then feeding that sum into the $\bmod\ p$ operation).

Other explanations of grokking

The authors of “Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets” don’t offer any particular interpretation of their results. However, Rohin’s summary of the paper proposed the following:

1. Functions that perfectly memorize the data without generalizing (i.e. probability 1 on the true answer and 0 elsewhere) are very complicated, nonlinear, and wonky. The memorizing functions learned by deep learning don’t get all the way there and instead assign a probability of (say) 0.95 to the true answer.

2. The correctly generalizing function is much simpler and for that reason can be easily pushed by deep learning to give a probability of 0.99 to the true answer.

3. Gradient descent quickly gets to a memorizing function, and then moves mostly randomly through the space, but once it hits upon the correctly generalizing function (or something close enough to it), it very quickly becomes confident in it, getting to probability 0.99 and then never moving very much again.

I strongly prefer my account of grokking for several reasons.

  • I think memorization is actually straightforward for a network to do. You just need a key-value system where the key detects the embedding associated with a given input, then activates the value, which produces the correct output for the detected input. Such systems are easy to implement in many ML architectures (feed-forward, convolutional, self-attention, etc.).

  • I think a story of grokking that involves incrementally increasing the generality by iteratively combining multiple shallow circuits does a better job of explaining away the complexity of the resulting model. It requires less in the way of complex structures suddenly emerging from scratch.

  • My account is more similar to biological evolution, which we know incrementally builds up complex structures from simpler predecessors.

  • My account directly predicts that stochasticity and weight decay regularization would help with generalization, and even predicts that weight decay would be one of the most effective interventions to improve generalization.

  • Additionally, it’s not the case that “once it hits upon the correctly generalizing function (or something close enough to it), it very quickly becomes confident in it”. This is an illusion caused by the log scale on the x-axis of the plots. If we look closely at the figure below, the accuracy on the validation data starts to increase at roughly step $10^5$ and doesn’t level off until roughly step $10^6$. That span of hundreds of thousands of steps covers most of the entire training run. If the model stumbled upon a single general circuit that solves the entire problem, then you’d expect it to make the switch very quickly. In contrast, if it has to go through multiple rounds of unifying shallow circuits into more general circuits, then you’d expect that process to take a while and for the model to gradually increase in generality throughout. The plot is more consistent with the latter interpretation.

Figure 1a from Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets
  • Finally, if we look at a loss plot on a log scale, we can see that the validation loss starts decreasing while the floor on the minimum training loss remains fairly constant (or even increases) until slightly after that point. Thus, validation loss starts decreasing thousands of steps before training loss starts decreasing. Whatever is causing the generalization, it’s not doing so to decrease training loss (at least not at first).

Log scale of train/validation loss during grokking. Generated using this GitHub repository.

(I also think it’s interesting how regular the training loss spikes look during the middle stretch of training. There’s a sharp jump, followed by a sharp decrease, then a shallower decrease, then another spike almost immediately after. I have no idea what to make of this, but whatever underlying mechanism drives the behavior should be interesting.)

Rohin’s explanation is the only other attempted explanation of grokking I’ve seen. Please let me know of any others in the comments.

Other explanations of generalization

Prior work has proposed that neural network generalization happens primarily as a result of neural network initializations strongly favoring simpler functions, with relatively little inductive bias coming from the optimization procedure.

E.g., “Is SGD a Bayesian sampler? Well, almost” demonstrated that if you randomly sample neural network initializations until you find one that has low error on a training set, that network will generalize to the test data. Additionally, the test set predictions made by the randomly sampled classifier correlate strongly with the test set predictions made by a classifier learned via SGD.
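To give a flavor of that kind of experiment, here’s a crude toy version (entirely my own construction, far smaller than the paper’s setup, and capturing only the random-sampling half of the comparison):

```python
# Draw many random small MLPs, keep the ones that best fit the training labels,
# and check whether they also beat chance on held-out data. Toy setup only.
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(60, 5))
y = (X @ rng.normal(size=5) > 0).astype(float)   # linearly separable labels
X_tr, y_tr, X_te, y_te = X[:40], y[:40], X[40:], y[40:]

def predict(W1, b1, w2, X):
    return ((np.tanh(X @ W1 + b1) @ w2) > 0).astype(float)

samples = [(rng.normal(size=(5, 16)), rng.normal(size=16), rng.normal(size=16))
           for _ in range(20_000)]
train_acc = np.array([(predict(*s, X_tr) == y_tr).mean() for s in samples])
test_acc = np.array([(predict(*s, X_te) == y_te).mean() for s in samples])

best = train_acc >= np.quantile(train_acc, 0.999)   # the best-fitting random nets
print("test acc of best-fitting random nets:", test_acc[best].mean())
print("test acc of all random nets:         ", test_acc.mean())
```

The best-fitting random networks should do noticeably better than chance on the held-out points, which is the qualitative pattern the paper reports at much larger scale.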

I think such work clearly demonstrates that network initializations are strongly biased towards simple functions. However, I think these results are compatible with SGD having a bias towards general circuits.

For one, common patterns in data may explain common patterns in models fit to that data. The correlation between SGD-learned and randomly sampled classifiers seems to be a specific instance of the tendency for many types of learners to converge on similar behavior when trained on similar data. I.e., both SGD-learned and randomly sampled classifiers seem less likely to fit outlier datapoints and more likely to fit tightly clustered datapoints.

Additionally, generality bias seems similar to simplicity bias. Occam’s razor implies simpler circuits are more likely to generalize. All else equal, general circuits are more likely to be simple. Potentially, the only difference between the “simplicity bias from initialization” vs. “simplicity bias from initialization and generality bias from SGD” perspectives on generalization is that the latter implies faster learning than the former. Perhaps not coincidentally, SGD is one of the fastest ways to train neural nets.

Experimental investigations

Given a particular output from a neural net, there are methods of determining which neurons are most responsible for generating that output. Such scores are often called the “saliency” of a given neuron for a particular output. Example methods include integrated gradients and Shapley values.

I think we can find experimental evidence for or against the general circuits hypothesis by looking at how the distribution over neuron saliencies evolves during training. When shallow circuits dominate network behavior, each neuron will mostly be salient for generating a small fraction of the outputs. However, as more general circuits form, there should be a small collection of neurons that become highly salient for many different outputs. We should be able to do something like look at a histogram of average neuron saliencies measured across many inputs.

One idea is to record neuron saliencies for each prediction in each training/testing epoch, then compute the median saliency for each neuron in each epoch. After that, I’d generate a histogram of the neurons’ median saliencies for each epoch. I should see the histogram becoming more and more right-skewed as training progresses, because the neurons implementing general circuits are salient for a greater fraction of the inputs.
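A rough sketch of how that measurement could look in PyTorch, using |activation × gradient| as a cheap stand-in for integrated gradients or Shapley values (the `hidden_activations` and `head` methods are hypothetical hooks into whichever layer of the real model we care about):

```python
# Sketch of the histogram idea, using |activation * gradient| as a cheap proxy
# for fancier saliency methods. `hidden_activations` and `head` are hypothetical.
import torch

def neuron_saliencies(model, inputs, labels, loss_fn):
    """Return a (num_neurons,) tensor of median |activation * grad| over the batch."""
    hidden = model.hidden_activations(inputs)      # hypothetical, shape (batch, neurons)
    hidden.retain_grad()
    loss = loss_fn(model.head(hidden), labels)     # hypothetical output head
    loss.backward()                                # zero grads afterwards if mid-training
    saliency = (hidden * hidden.grad).abs().detach()
    return saliency.median(dim=0).values

# Each epoch: histogram the per-neuron median saliencies, e.g.
#   counts, edges = torch.histogram(neuron_saliencies(model, x, y, loss_fn), bins=50)
# The hypothesis predicts an increasingly heavy right tail as training progresses.
```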

Another idea would be to find the neurons with the highest saliency at each epoch, then test what happens if we delete them. As training progresses, individual circuits will become responsible for a greater fraction of the model’s predictions. We should find that this deletion operation damages more and more of the model’s predictive capability.
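And a matching sketch of the deletion test, reusing the hypothetical hooks from above:

```python
# Sketch of the deletion experiment: zero out the k most salient neurons and see
# how much predictive accuracy drops.
import torch

def accuracy_after_ablation(model, inputs, labels, saliencies, k):
    top_neurons = saliencies.topk(k).indices
    with torch.no_grad():
        hidden = model.hidden_activations(inputs)  # hypothetical hook, as above
        hidden[:, top_neurons] = 0.0               # "delete" the most salient neurons
        preds = model.head(hidden).argmax(dim=-1)
    return (preds == labels).float().mean().item()

# The hypothesis predicts this ablation damages a growing fraction of the model's
# predictions as training progresses and general circuits take over.
```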

Both of these ideas would provide evidence for general circuits replacing shallow circuits over time. However, they wouldn’t show that the replacement happens specifically because general circuits make more predictions and so are favored by SGD. I’m unsure how to investigate this specific hypothesis. All I can think of is to identify some set of shallow circuits and a single general circuit that makes the same predictions as that set of shallow circuits, then record the predictions and gradients made by the shallow and general circuits and hope to find a clear, interpretable pattern of the general circuit receiving more frequent/stronger updates and gradually replacing the shallow circuits.
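The most direct version of that I can picture is just logging per-circuit gradient norms during training, assuming we’ve already identified which parameters implement which circuit (a big assumption; the parameter groups below are placeholders):

```python
# Sketch of tracking gradient updates received by a hand-identified "shallow" vs.
# "general" circuit. The parameter groups would come from prior interpretability work.

def grad_norms(param_groups):
    """param_groups: dict mapping circuit name -> list of parameter tensors."""
    norms = {}
    for name, params in param_groups.items():
        total = sum(float(p.grad.norm()) ** 2 for p in params if p.grad is not None)
        norms[name] = total ** 0.5
    return norms

# Inside the training loop, after loss.backward():
#   history.append(grad_norms({"shallow": shallow_params, "general": general_params}))
# The hypothesis predicts larger and more frequent updates for the "general" group.
```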

(If anyone can think about other experimental investigations or has thoughts on this proposal, please share in the comments!)

Implications for learning in general

My guess is that many effective learning systems will have heuristics that cause them to favor circuits that make lots of predictions.

For example, many tendencies of the brain seem to promote general circuitry. Both memories and skills decay over time unless they’re periodically refreshed. When senses are lost, the brain regions corresponding to those senses are repurposed towards processing the remaining sense data.

In addition to using low-level processes that promote general circuitry, highly capable learning systems may develop a high-level tendency towards generalization because such a tendency is adaptive for many problems. In other words, they may learn to “mesa-generalize”[2].

I think humans show evidence of a mesa-generalization instinct. Consider that religions, ideologies, philosophical frameworks and conspiracy theories often try to explain a large fraction of the world through a single lens. Many such grand narratives make frequent predictions about hard-to-verify things. Without being able to easily verify those predictions, our mesa-generalization instincts may favor narratives that make many predictions.

Potentially, ML systems will have a similar mesa-generalization instinct. This could be a good thing. Human philosophers have put quite a bit of effort into mesa-generalizing a universal theory of ethics. If ML systems are naturally inclined to do something similar, maybe we can try to point this process in the right direction?

Mesa-generalization from ML systems could also be dangerous, for much the same reason mesa-optimization is dangerous. We don’t know what sort of generalization instinct the system might adopt, and it could influence the system’s behaviors in ways that are hard to predict from the training data.

This seems related to the natural abstractions hypothesis. Mesa-generalization suggests an ML system should prefer frequently used abstractions. At a human level of capabilities, these should coincide reasonably well with human abstractions. However, more capable systems are presumably able to form superhuman abstractions that are used for a greater fraction of situations. This suggests we might have initially encouraging “alignment by default”-type results, only for the foundation of that approach to collapse as we reach superhuman capabilities.

  1. ^

    Essentially, the query-key circuit determines which input tokens are most important, and the output-value circuit determines which outputs to generate for each attended token. See “A Mathematical Framework for Transformer Circuits” for more details on the query-key / output-value formulation of self-attention.

  2. ^

    So named after “mesa-optimization”, the potential for learning systems to implement an optimization procedure as an adaptive element of their cognition. See Risks from Learned Optimization in Advanced Machine Learning Systems.