Visible loss landscape basins don’t correspond to distinct algorithms
Thanks to Justis, Arthur Conmy, Neel Nanda, Joseph Miller, and Tilman Räuker for their feedback on a draft.
I feel like many people haven’t noticed an important result of the mechanistic interpretability analysis of grokking, and so haven’t updated how they think about loss landscapes and the algorithms that neural networks end up implementing. I think this has implications for alignment research.
When thinking about grokking, people often imagine something like this: the neural network implements Algorithm 1 (e.g., memorizes the training data), achieves approximately the lowest loss available via memorization, then moves around the bottom of the Algorithm 1 basin until, after a while, it stumbles across a path to Algorithm 2 (e.g., the general algorithm for modular addition).
But the mechanistic interpretability analysis of grokking has shown that this is not what happens!
From approximately the start of training, Algorithm 1 makes up most of what the circuits are doing and almost entirely determines the neural network’s output. But at the same time, throughout the period in which the parameters visibly move down the wider basin, they don’t just get better at memorization: they increasingly implement both the circuits for Algorithm 1 and the circuits for Algorithm 2, in superposition.
(Neel Nanda et al. have shown that the circuits that end up implementing the general algorithm for modular addition start forming at approximately the start of training: the gradient was mostly an arrow towards memorization, but also, right from the initialization of the weights, partly an arrow pointing towards the general algorithm. The circuits were gradually tuned throughout training. The test loss only starts to change noticeably once the circuits are already almost right.)
A path through the loss landscape as visualized in 3D doesn’t reflect how and what the neural network is actually learning. Almost all of the changes in the loss are due to the increasingly good implementation of Algorithm 1, but apparently, the entire time, the gradient also points towards some faraway implementation of Algorithm 2. Somehow, the direction in which Algorithm 2 lies is also visible to the gradient, and moving the parameters in the direction the gradient points means mostly implementing Algorithm 1 better, and also implementing the faraway Algorithm 2 better.
“Grokking”, as visible in the test loss, happens once the parameters already implement Algorithm 2 accurately enough that switching the output from mostly the results of Algorithm 1 to the results of an improving implementation of Algorithm 2 no longer hurts performance. From that point on, the neural network puts more weight on Algorithm 2 and at the same time quickly tunes it to be even more accurate (which gets easier as the output is increasingly determined by the implementation of Algorithm 2).
This is something many people seem to have missed. I did not expect it to be the case, was surprised, and updated how I think about loss landscapes.
Does this generalize?
Maybe. I’m not sure whether it’s correct to generalize from the mechanistic interpretability analysis of grokking to neural networks in general: real LLMs are under-parameterised, while the grokking model is very over-parameterised. Still, I guess it might be reasonable to expect that this is how deep learning generally works.
People seem to think that multi-dimensional loss landscapes of neural networks have basins for specific algorithms, and neural networks get into these depending on how relatively large these basins are, which might be caused by how simple the algorithms are, how path-dependent their implementation might be, etc. I think this makes a wrong prediction for what happens in grokking.
Maybe there are better ways to think about what’s actually going on. The clearly visible basins correspond to implementations of the algorithms that currently influence performance the most. But the neural network might be implementing whatever algorithms output predictions that have mutual information with whatever the gradient communicates, and the algorithms you see are not necessarily the better algorithms that the neural network is already slowly implementing (but does not yet rely on heavily).
It might be helpful to imagine two independent basins for two algorithms, each shaped by how much that algorithm reduces the loss and how well the neural network implements it. If you sum the two basins and look at an area of the loss landscape, you’ll mostly notice only the wider basin. But in a high-dimensional space, gradient descent might be going down both at the same time, and the two might not interfere enough to prevent this, so at the end you might end up with the optimal algorithm, even if for most of training you thought you were looking at only a suboptimal one.
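A minimal numerical sketch of the two-basin picture, assuming a made-up landscape over two hypothetical parameters: `a` (a stand-in for how well Algorithm 1 is implemented) sits in a wide, shallow basin near the starting point, and `b` (a stand-in for Algorithm 2) sits in a deep, narrow basin far away, whose gradient at the start is tiny but nonzero. The functional forms, constants, and learning rate are illustrative assumptions, not taken from the grokking analysis. Plain gradient descent on the sum spends its first steps almost entirely on the wide basin while `b` quietly creeps towards the narrow one, and the narrow basin’s loss only collapses much later.

```python
import numpy as np

# Toy landscape over two parameters (illustrative assumptions):
#   a: stand-in for how well Algorithm 1 is implemented (wide, shallow basin nearby)
#   b: stand-in for how well Algorithm 2 is implemented (deep, narrow basin far away)
DEPTH, CENTRE, WIDTH = 5.0, 3.0, 1.0  # narrow basin: depth, location, width


def loss_parts(a, b):
    wide = (a - 1.0) ** 2
    narrow = DEPTH * (1.0 - np.exp(-((b - CENTRE) ** 2) / (2 * WIDTH**2)))
    return wide, narrow


def grad(a, b):
    d_a = 2.0 * (a - 1.0)
    # tiny but nonzero far from CENTRE; its negative always points towards CENTRE
    d_b = DEPTH * (b - CENTRE) / WIDTH**2 * np.exp(-((b - CENTRE) ** 2) / (2 * WIDTH**2))
    return d_a, d_b


def train(steps, lr=0.1):
    a, b = 0.0, 0.0
    history = [(a, b, *loss_parts(a, b))]
    for _ in range(steps):
        d_a, d_b = grad(a, b)
        a, b = a - lr * d_a, b - lr * d_b
        history.append((a, b, *loss_parts(a, b)))
    return history


history = train(200)
_, _, wide0, narrow0 = history[0]
_, b10, wide10, narrow10 = history[10]
print(f"after 10 steps: wide loss fell by {wide0 - wide10:.3f}, "
      f"narrow loss fell by {narrow0 - narrow10:.3f}, b moved to {b10:.3f}")

# first step at which the narrow basin's loss drops below half its depth
switch = next(t for t, (_, _, _, narrow) in enumerate(history) if narrow < DEPTH / 2)
print("narrow basin's loss halves at step", switch)
```

In this toy the two basins share no parameters, so they cannot interfere at all; in a real network they would, and whether that interference is small enough is exactly the open question.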
Some implications for alignment
If you were imagining that the loss landscape looks like this (figure not reproduced here),
then you might have hoped you could find a way to shape it so that some simple algorithm exhibiting aligned behaviour somehow has a much larger basin, so that you’d likely end up and remain in it even if some less aligned algorithms would achieve a better loss. You might also have hoped to use interpretability tools to understand what’s going on in the neural network and what the algorithm it implements is thinking about.
This might not work. Speculatively, if misaligned algorithms can be implemented by the neural network and would perform better than the aligned algorithms you were hoping for, the neural network might end up implementing them no matter what the visible basins were. Your interpretability tools might not distinguish the gradual implementation of a misaligned algorithm from noise. Something seemingly aligned might be mostly responsible for the outputs, but if an agent with different goals could achieve a lower loss, its implementation might be slowly building up the whole time. I think this adds another difficulty on top of the classical sharp left turn: you need to deal not only with changes in a specific algorithm whose alignment doesn’t necessarily generalise together with its capabilities, but also with the possibility that your neural network is directly implementing a generally capable algorithm that never had alignment properties in the first place. You might not notice that 2% of the activations of a neuron you thought distinguishes cats from dogs are devoted to planning to kill you.
Further research
It might be valuable to explore this more. Do similar dynamics generally occur during training, especially in models that aren’t over-parameterised? If you reverse-engineer and understand the final algorithm, can you tell when and how the neural network started implementing it?
Thanks for this post. It failed to dislodge any dogmas in me because I didn’t subscribe to the ones you attacked. So here are my dogmas; maybe they are similar under the surface and you can attack them too?
- Randomly initialized neural networks of size N are basically a big grab bag of random subnetworks of size <N.
- Training tends to simultaneously modify all the subnetworks at once, in a sort of evolutionary process: subnetworks that contributed to success get strengthened and tweaked, and subnetworks that contribute to failure get weakened.
- Eventually you have a network that performs very well in training, which probably means that it has at least one and possibly several subnetworks that perform well in training. This explains why you can usually prune away neurons without much performance degradation, and also why neural nets are so robust to small amounts of damage.
- Networks that perform well in training tend to also be capable in other nearby environments (“generalization”) because (a) the gods love simplicity, and made our universe such that simple patterns are ubiquitous, and (b) simpler algorithms occupy more phase space in a neural net (there are more possible settings of the parameters that implement simpler algorithms). So (conclusion) a trained neural network tends to do well in training by virtue of subnetworks that implement simple algorithms matching the patterns inherent in the training environment; these patterns are often also found outside the training environment (in ‘nearby’ or ‘similar’ environments), so trained neural networks often “generalize.”
So why grokking? Well, sometimes there is a simple algorithm (e.g. correct modular arithmetic) that requires getting a lot of fiddly details right, and a more complex algorithm (e.g. memorizing a look-up table) that is modular and very easy to build up piece by piece. In these cases, both algorithms get randomly initialized in various subnetworks, but the simple ones are ‘broken’ and need ‘repair’, so to speak, and that takes a while because for the whole to work all the parts need to be just so, whereas the complex but modular ones can very quickly be hammered into shape, since some parts being out of shape reduces performance only partially. Thus, after training long enough, the simple algorithm’s subnetworks finally get into shape and then come to dominate behavior (because they are simpler and therefore more numerous).
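A toy way to see the claim that more parameter settings implement simpler functions is to enumerate every parameter setting of a tiny network and count how many land on each function. The sketch below assumes a 2-input network with 2 hidden threshold units and every weight and bias restricted to {-1, 0, 1} (all illustrative choices), and counts how many of the 3^9 settings compute each of the 16 Boolean functions. The constant-zero function picks up thousands of settings, while XOR, which needs both hidden units to be just so, picks up far fewer. This only illustrates that the parameter-function map can be lopsided; it is not evidence about real networks.

```python
import itertools
from collections import Counter

# Inputs in a fixed order, so a truth table is a 4-tuple of outputs.
INPUTS = [(0, 0), (0, 1), (1, 0), (1, 1)]
VALUES = (-1, 0, 1)  # assumption: every weight and bias is one of these


def fires(z):
    # threshold unit: outputs 1 only for strictly positive input
    return 1 if z > 0 else 0


def truth_table(params):
    w11, w12, b1, w21, w22, b2, v1, v2, c = params
    table = []
    for x1, x2 in INPUTS:
        h1 = fires(w11 * x1 + w12 * x2 + b1)
        h2 = fires(w21 * x1 + w22 * x2 + b2)
        table.append(fires(v1 * h1 + v2 * h2 + c))
    return tuple(table)


# Every one of the 3**9 parameter settings, grouped by the function it computes.
counts = Counter(truth_table(p) for p in itertools.product(VALUES, repeat=9))

CONST_ZERO, XOR = (0, 0, 0, 0), (0, 1, 1, 0)
for table, n in counts.most_common():
    print(table, n)
print("constant zero:", counts[CONST_ZERO], " XOR:", counts[XOR])
```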
What do you make of the Mechanistic Mode Connectivity and Linear Connectivity papers, then?
I didn’t read super carefully, but it seems like the former paper is saying that, for some definition of “mechanistically similar”:
Mechanistically dissimilar algorithms can be “mode connected”, meaning that they are approximate local minima joined by a path of local minima (the paper proves this for their definition of “mechanistically similar”).
If two models aren’t linearly mode connected, then they’re dissimilar (note that this is a conjecture, but I guess they probably find evidence for it).
I don’t think this is in much tension with the post?
My reading of the post is that two algorithms with different generalization and structural properties can lie in the same basin, and it uses evidence from our knowledge of the mechanisms behind grokking on synthetic data to make this point. But the above papers show empirically that, in more realistic settings, two models lie in the same basin (up to permutation symmetries) if and only if they have similar generalization and structural properties.
I think they only check whether models lie in linearly connected parts of the same basin when they have similar generalization properties? E.g. Figure 4 of Mechanistic Mode Connectivity is titled “Non-Linear Mode Connectivity of Mechanistically Dissimilar Models” and the subtitle states that “quadratic paths can be easily identified to mode connect mechanistically dissimilar models[, and] linear paths cannot be identified, even after permutation”. Linear Connectivity Reveals Generalization Strategies seems to be focused on linear mode connectivity, rather than more general mode connectivity.
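To make the linear versus non-linear distinction concrete, here is a toy two-parameter loss, an assumption chosen purely for illustration, whose minima form a ring. Two opposite points on the ring are both global minima, yet the straight line between them crosses a loss barrier of height 1, whereas a quadratic Bezier curve bent through a hand-picked control point stays close to the ring and has a barrier no higher than 1/16. The barrier is measured here as the highest loss on the path minus the worse endpoint loss; the ring loss and the control point are hypothetical and not taken from either paper.

```python
import numpy as np


def ring_loss(theta):
    # Toy loss whose minima form a circle of radius 1: every point on the
    # circle is a global minimum, but the straight line between opposite
    # points passes through the high-loss centre.
    return (np.sum(theta**2, axis=-1) - 1.0) ** 2


def linear_path(p0, p1, ts):
    return (1 - ts)[:, None] * p0 + ts[:, None] * p1


def quadratic_path(p0, p1, control, ts):
    # quadratic Bezier curve from p0 to p1, bent towards `control`
    ts = ts[:, None]
    return (1 - ts) ** 2 * p0 + 2 * ts * (1 - ts) * control + ts**2 * p1


def barrier(path_losses):
    # height of the highest point on the path above the worse endpoint
    return path_losses.max() - max(path_losses[0], path_losses[-1])


ts = np.linspace(0.0, 1.0, 101)
p0, p1 = np.array([1.0, 0.0]), np.array([-1.0, 0.0])
control = np.array([0.0, 2.0])  # hypothetical control point, picked by hand

linear_barrier = barrier(ring_loss(linear_path(p0, p1, ts)))
quadratic_barrier = barrier(ring_loss(quadratic_path(p0, p1, control, ts)))
print(f"linear barrier: {linear_barrier:.3f}, quadratic barrier: {quadratic_barrier:.3f}")
```

Here the control point is chosen by hand rather than optimized; the point is only that "no linear path" and "no path" are different claims.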
Mea culpa: AFAICT, the ‘proof’ in Mechanistic Mode Connectivity fails. It basically goes:
Prior work has shown that under overparametrization, all global loss minimizers are mode connected.
Therefore, mechanistically distinct global loss minimizers are also mode connected.
The problem is that prior work made the assumption that for a net of the right size, there’s only one loss minimizer up to permutation—aka there are no mechanistically distinct loss minimizers.
[EDIT: the proof also cites Nguyen (2019) in support of its arguments. I haven’t checked the proof in Nguyen (2019), but if it holds up, it does substantiate the claim in Mechanistic Mode Connectivity, although if I’m reading it correctly you need so much overparameterization that the neural net has a layer with as many hidden neurons as there are training data points.]
Update: I currently think that Nguyen (2019) proves the claim, but it actually requires a layer to have two hidden neurons per training example.
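For a sense of scale, here is the back-of-the-envelope arithmetic for the "two hidden neurons per training example" reading above. The modular-addition numbers (p = 113, with 30% of all (a, b) pairs used for training) are illustrative assumptions, and `required_width` is a hypothetical helper that just multiplies; the CIFAR-10 figure uses its 50,000 training images.

```python
def required_width(n_train: int, neurons_per_example: int = 2) -> int:
    # hypothetical helper: hidden neurons needed under the
    # "two hidden neurons per training example" reading
    return neurons_per_example * n_train


# Assumed modular-addition setup: all pairs (a, b) mod p, 30% used for training.
p = 113
n_train_modadd = p * p * 3 // 10  # 30% of 12,769 pairs, rounded down
n_train_cifar = 50_000            # CIFAR-10 training images

for name, n in [("modular addition (assumed)", n_train_modadd), ("CIFAR-10", n_train_cifar)]:
    print(f"{name}: {n} examples -> a hidden layer of at least {required_width(n)} neurons")
```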
The second paper is just about linear connectivity, and does seem to suggest that linearly connected models run similar algorithms. But I guess I don’t expect neural net training to go in straight lines? (Although I suppose momentum helps with this?)
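A small illustration of why gradient descent need not travel in straight lines, assuming a made-up ill-conditioned quadratic loss `L(x) = x_0**2 + 10 * x_1**2` and plain gradient descent without momentum: the steep direction is resolved within a few steps while the shallow one lags, so the trajectory bends away from the straight line through its start and end points by more than half a unit. Whether real training trajectories bend like this, and how much momentum straightens them, is an empirical question this toy does not answer.

```python
import numpy as np


def gd_trajectory(start, curvatures, lr, steps):
    # Plain gradient descent on L(x) = sum_i curvatures[i] * x_i**2
    x = np.array(start, dtype=float)
    path = [x.copy()]
    for _ in range(steps):
        x = x - lr * 2 * curvatures * x
        path.append(x.copy())
    return np.array(path)


def max_distance_from_chord(path):
    # Largest perpendicular distance between the trajectory and the straight
    # line through its first and last points (2D only).
    start, end = path[0], path[-1]
    direction = (end - start) / np.linalg.norm(end - start)
    offsets = path - start
    cross = offsets[:, 0] * direction[1] - offsets[:, 1] * direction[0]
    return np.abs(cross).max()


path = gd_trajectory(start=[1.0, 1.0], curvatures=np.array([1.0, 10.0]), lr=0.04, steps=200)
print(f"max deviation from straight line: {max_distance_from_chord(path):.3f}")
```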
I don’t see how the mechanistic interpretability analysis of grokking is evidence against this.
At the start of training, the modular addition network is quickly evolving to get increasingly better training loss by overfitting on the training data. Every time it gets an answer in the training set right that it didn’t before, it has to have moved from one behavioural manifold in the loss landscape to another. It’s evolved a new tiny piece of circuitry, making it no longer the same algorithm it was a couple of batch updates ago.
Eventually, it reaches the zero loss manifold. This is a mostly fully connected subset of parameter space. I currently like to visualise it like a canyon landscape, though in truth it is much more high dimensional. It is made of many basins, some broad (high dimensional), some narrow (low dimensional), connected by paths, some straight, some winding.
In the broad basin picture, there aren’t just two algorithms here, but many. Every time the neural network constructs a new internal elementary piece of circuitry, that corresponds to moving from one basin in this canyon landscape to another. Between the point where the loss flatlines and the point where grokking happens, the network is moving through dozens of different basins or more. Eventually, it arrives at the largest, most high dimensional basin in the landscape, and there it stays.
I think this might be the source of confusion here. Until grokking finishes, the network isn’t even in that basin yet. You can’t be in multiple basins simultaneously.
At the time the network is learning the pieces of what you refer to as algorithm 2, it is not yet in the basin of algorithm 2. Likewise, if you went into the finished network sitting in the basin of algorithm 2 and added some additional internal piece of circuitry into it by changing the parameters, that would take it out of the basin of algorithm 2 and into a different, narrower one. Because it’s not the same algorithm any more. It’s got a higher effective parameter count now, a bigger Real Log Canonical Threshold.
Points in the same basin correspond to the same algorithm. But it really does have to be the same algorithm. The definition is quite strict here. What you refer to as superpositions of algorithm 1 and algorithm 2 are all various different basins in parameter space. Basins are regions where every point maps to the same algorithm, and all of those superpositions are different algorithms.
Doesn’t the top-left panel of Figure 7 in the arXiv paper provide evidence against the “network is moving through dozens of different basins or more” picture?
… No?
I don’t see what part of the graphs would lead to that conclusion. As the paper says, there’s a memorization, circuit formation and cleanup phase. Everywhere along these lines in the three phases, the network is building up or removing pieces of internal circuitry. Every time an elementary piece of circuitry is added or removed, that corresponds to moving into a different basin (convex subset?).
Points in the same basin are related by internal symmetries. They correspond to the same algorithm, not just in the sense of having the same input-output behavior on the training data (all points on the zero loss manifold have that in common), but also in sharing common intermediate representations. If one solution has a piece of circuitry another doesn’t, they can’t be part of the same basin, because you can’t transform them into each other through internal symmetries.
So the network is moving through different basins all along those graphs.
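For concreteness, here is a sketch of two standard internal symmetries of a one-hidden-layer ReLU network, the kind of parameter change that leaves the input-output behaviour unchanged and the intermediate representations intact up to relabelling and rescaling: permuting the hidden units, and scaling a unit’s incoming weights and bias by α > 0 while dividing its outgoing weights by α. The network sizes and random weights are arbitrary assumptions, and whether these symmetries exhaust what the comment means by a basin is not something this sketch settles.

```python
import numpy as np

rng = np.random.default_rng(0)


def mlp(x, W1, b1, W2, b2):
    # one hidden ReLU layer
    return np.maximum(x @ W1.T + b1, 0.0) @ W2.T + b2


d_in, d_hidden, d_out = 4, 8, 3
W1 = rng.normal(size=(d_hidden, d_in))
b1 = rng.normal(size=d_hidden)
W2 = rng.normal(size=(d_out, d_hidden))
b2 = rng.normal(size=d_out)
x = rng.normal(size=(16, d_in))

# Symmetry 1: permute hidden units (rows of W1 and b1, columns of W2, together).
perm = rng.permutation(d_hidden)
permuted = (W1[perm], b1[perm], W2[:, perm], b2)

# Symmetry 2: ReLU is positively homogeneous, so scaling a unit's incoming
# weights and bias by alpha > 0 and its outgoing weights by 1 / alpha is a no-op.
alpha = rng.uniform(0.5, 2.0, size=d_hidden)
rescaled = (W1 * alpha[:, None], b1 * alpha, W2 / alpha[None, :], b2)

original_out = mlp(x, W1, b1, W2, b2)
print("permuted function matches:", np.allclose(original_out, mlp(x, *permuted)))
print("rescaled function matches:", np.allclose(original_out, mlp(x, *rescaled)))
```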
An even better mental picture to push back on might be an “annotated” version of the singular learning theory toy picture, where by “annotated” I mean that you’ve put mental labels on the low-dimensional minima as “algorithm 1”, and the high-dimensional minima as “algorithm 2”. The minima don’t have to correspond to a single algorithm as we conceive it.
So that’s likely why it works at all, and why larger and deeper networks are required to discover good generalities. You would think a larger and deeper network would have more weights to just memorize the answers, but apparently you need it to explore multiple hypotheses in parallel.
At a high level, though, this creates many redundant circuits that use the same strategy. I wonder if there is a way to identify these and randomize the duplicates, causing the network to explore a more diverse set of hypotheses.
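A rough sketch of that idea, assuming that "using the same strategy" can be approximated by two hidden units having almost perfectly correlated activations over a batch. `find_duplicate_units` and `reinitialize` are hypothetical helpers, and the 0.99 threshold and re-initialization scale are arbitrary choices; this only catches near-identical single units, not redundant circuits spread across many units.

```python
import numpy as np


def find_duplicate_units(activations, threshold=0.99):
    # activations: (n_samples, n_units). Returns pairs (i, j), i < j, whose
    # activations are almost perfectly correlated across the batch.
    centred = activations - activations.mean(axis=0)
    norms = np.linalg.norm(centred, axis=0)
    alive = np.flatnonzero(norms > 1e-8)  # skip units whose activation never varies
    n_units = activations.shape[1]
    corr = np.zeros((n_units, n_units))
    unit_vectors = centred[:, alive] / norms[alive]
    corr[np.ix_(alive, alive)] = unit_vectors.T @ unit_vectors
    i, j = np.triu_indices(n_units, k=1)
    keep = corr[i, j] > threshold
    return list(zip(i[keep].tolist(), j[keep].tolist()))


def reinitialize(W_in, b_in, units, rng, scale=0.1):
    # Give each listed unit fresh random incoming weights and a zero bias.
    W_in, b_in = W_in.copy(), b_in.copy()
    for u in units:
        W_in[u] = rng.normal(scale=scale, size=W_in.shape[1])
        b_in[u] = 0.0
    return W_in, b_in


rng = np.random.default_rng(0)
W = rng.normal(size=(6, 4))
b = rng.normal(size=6)
W[5], b[5] = W[2], b[2]  # plant an exact copy of hidden unit 2 as unit 5
x = rng.normal(size=(256, 4))
acts = np.maximum(x @ W.T + b, 0.0)  # ReLU hidden activations

pairs = find_duplicate_units(acts)
print("near-duplicate pairs:", pairs)
W_new, b_new = reinitialize(W, b, sorted({j for _, j in pairs}), rng)
```

Re-initializing only the later unit of each pair keeps one copy of the strategy intact while freeing the other to drift elsewhere; whether that actually leads to more diverse hypotheses in a real network is untested here.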