One of the authors of the paper here. Glad you found it interesting! In case people want to mess around with some of our results themselves, here are Colab notebooks for reproducing a couple of results:
Delaying generalization (inducing grokking) on MNIST: https://colab.research.google.com/drive/1wLkyHadyWiZSwaR0skJ7NypiYKCiM7CR?usp=sharing
Almost eliminating grokking (bringing train and test curves together) in transformers trained on modular addition: https://colab.research.google.com/drive/1NsoM0gao97jqt0gN64KCsomsPoqNlAi4?usp=sharing
Some miscellaneous comments:
"On some level 'just fix your weight norm and the model generalizes' sounds too simple to be true for all tasks"
I agree. I’d be pretty surprised if our result on speeding up generalization on modular arithmetic by constraining weight norm had much relevance to training large language models, for instance. But I haven’t thought much about this yet!

In terms of relevance to AI safety, I view this work broadly as contributing to a scientific understanding of emergence in ML, cf. “More is Different for AI”. It seems useful for us to understand mechanistically how and why surprising capabilities emerge with increasing model scale or training time (as is the case for grokking), so that we can better reason about and anticipate the potential capabilities and risks of future AI systems. Another AI safety angle could lie in trying to unify our observations with Nanda and Lieberum’s circuits-based perspective on grokking. My understanding of that work is that networks learn both memorizing and generalizing circuits, and that generalization corresponds to the network eventually “cleaning up” the memorizing circuit, leaving the generalizing circuits. By constraining weight norm, are we just preventing the memorizing circuits from forming? If so, can we learn something about circuits, or auto-discover them, by looking at properties of the loss landscape? In our setup, does switching to polar coordinates factor the parameter space into things which generalize and things which memorize, with the radial direction corresponding to memorization and the angular directions corresponding to generalization? Maybe there are general lessons here.
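To make the polar-coordinate picture a bit more concrete, here is a minimal sketch (my own illustration, not code from the paper) of splitting a gradient at a flattened parameter vector into a radial component, which only changes the overall weight norm, and a tangential component, which moves along the constant-norm sphere. The function name and shapes are placeholders.

```python
import torch

# Minimal illustration (not code from the paper): split a gradient g at a
# flattened parameter vector w into a radial part, which only changes ||w||,
# and a tangential part, which moves along the sphere of constant ||w||.
def radial_tangential_split(w: torch.Tensor, g: torch.Tensor):
    w_hat = w / w.norm()              # unit vector in the radial direction
    g_radial = (g @ w_hat) * w_hat    # component along w: changes the norm
    g_tangent = g - g_radial          # component tangent to the sphere
    return g_radial, g_tangent

w = torch.randn(1000)
g = torch.randn(1000)
g_r, g_t = radial_tangential_split(w, g)
assert torch.allclose(g_r + g_t, g, atol=1e-5)                            # the two parts sum to g
assert torch.isclose(g_t @ (w / w.norm()), torch.tensor(0.0), atol=1e-4)  # tangential part is orthogonal to w
```

To first order, constraining optimization to the constant-norm sphere amounts to only ever following the tangential component; the speculation above is that this discards the "memorizing" direction.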
Razied’s comment makes a good point about weight L2 norm being a bizarre metric for generalization, since you can take a ReLU network which generalizes and arbitrarily increase its weight norm by multiplying any neuron’s in-weights by β and its out-weights by 1/β (for β > 0) without changing the function implemented by the network. The relationship between weight norm and generalization is an imperfect one. What we find empirically is simply this: when we initialize networks in a standard way, multiply all the parameters by α, and then constrain optimization to lie on that constant-norm sphere in parameter space, there is often an α-dependent gap between train and test performance for the solutions that optimizers find. For large α, optimization finds a solution on the sphere which fits the training data but doesn’t generalize. For α in the right range, optimization finds a solution on the sphere which does generalize. So maybe the right statement about generalization and weight norm is more about the density of generalizing vs. non-generalizing solutions in different regions of parameter space, rather than their existence. I’ll also point out that this gap between train and test performance as a function of α is often only present when we reduce the size of the training dataset. I don’t yet understand mechanistically why this last part is true.
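For concreteness, here is a rough sketch of the procedure described above, under the assumption that the norm constraint is enforced by projecting back onto the sphere after each optimizer step (the Colab notebooks are the authoritative reference for what we actually do, and may implement this differently): take a standard initialization, multiply every parameter by α, and then keep the total L2 norm of the parameter vector fixed throughout training. The model, function names, and hyperparameters here are placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def total_norm(model: nn.Module) -> torch.Tensor:
    """L2 norm of the full (flattened) parameter vector."""
    return torch.sqrt(sum((p ** 2).sum() for p in model.parameters()))

def rescale_parameters(model: nn.Module, alpha: float) -> float:
    """Multiply every parameter by alpha; return the resulting total norm."""
    with torch.no_grad():
        for p in model.parameters():
            p.mul_(alpha)
        return total_norm(model).item()

def project_to_sphere(model: nn.Module, target_norm: float) -> None:
    """Rescale all parameters so the full parameter vector has norm target_norm."""
    with torch.no_grad():
        scale = target_norm / total_norm(model)
        for p in model.parameters():
            p.mul_(scale)

# Placeholder model and hyperparameters, just to show the shape of the loop.
model = nn.Sequential(nn.Flatten(), nn.Linear(784, 200), nn.ReLU(), nn.Linear(200, 10))
alpha = 3.0                                # illustrative; large alpha tends to delay generalization
target_norm = rescale_parameters(model, alpha)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

def train_step(x: torch.Tensor, y: torch.Tensor) -> float:
    opt.zero_grad()
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    opt.step()
    project_to_sphere(model, target_norm)  # stay on the constant-norm sphere
    return loss.item()
```

This projection step is just one simple way to enforce the constraint; other choices (e.g. a reparameterization of the weights) would also keep optimization on the sphere.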
Thanks for all the clarifications and the notebook. I’ll definitely play around with this :)