[Edit Jan 19, 2023: I no longer think the below is accurate. My argument rests on an unstated assumption: that when weight decay kicks in, the counter-pressure against it is stronger for the 101st weight (the “bias/generalizer”) than for the other weights (the “memorizers”), since the gradient is stronger in that direction. In fact, this mostly isn’t true, for the same reason Adam(W) moved towards the (1/2)M + (1/2)G solution to begin with before weight decay strongly kicked in: each dimension of the gradient is normalized relative to its typical magnitudes in the past. Hence the counter-pressure is the same on all coordinates.
A caveat: this assumes we’re doing full batch AdamW, as opposed to randomized minibatches. In the latter case, the increased noise in the “memorizer” weights will in fact cause Adam to be less confident about those weights, and thus assign less magnitude to them. But this happens essentially right from the start, so it doesn’t really explain grokking.
Here’s an example of this, taking randomized minibatches of size 10 (out of 100 total) on each step, optimizing with AdamW (learning rate = 0.001, weight decay = 0.01). I show the first three “memorizer” weights (out of 100 total) plus the bias:
As you can see, it does place less magnitude on the memorizers due to the increased noise, but this happens right from the get-go; it never “groks”.
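(A minimal sketch of this kind of setup, assuming a standard PyTorch AdamW loop; the data layout below, where sample i activates its own one-hot “memorizer” feature plus the shared “generalizer” feature with target 1, is one concrete choice consistent with the toy model described further down, not the exact script behind the plot:)

```python
# Sketch of the minibatch experiment: 100 samples, 101 weights, AdamW with
# lr=0.001, weight_decay=0.01, minibatches of 10. The data layout (one-hot
# memorizer feature per sample plus a shared generalizer feature, target 1)
# is an assumption, chosen so that both M and G fit the training data.
import torch

torch.manual_seed(0)
n, d = 100, 101
X = torch.zeros(n, d)
X[torch.arange(n), torch.arange(n)] = 1.0  # memorizer features
X[:, -1] = 1.0                             # shared generalizer/bias feature
y = torch.ones(n)

w = torch.zeros(d, requires_grad=True)
opt = torch.optim.AdamW([w], lr=1e-3, weight_decay=0.01)

for step in range(20_000):
    idx = torch.randperm(n)[:10]           # minibatch of 10 (use all 100 for full batch)
    loss = ((X[idx] @ w - y[idx]) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    if step % 2_000 == 0:
        # first three memorizer weights plus the generalizer/bias weight
        print(step, [round(v, 3) for v in w[:3].tolist()], round(w[-1].item(), 3))
```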
If we do full batch AdamW, then the bias is treated indistinguishably from the memorizer weights:
For small weight decay settings and zero gradient noise, AdamW is doing something like finding the minimum-norm solution, but in L∞ space, not L2 space.]
Here’s a straightforward argument that phase changes are an artifact of AdamW and won’t be seen with SGD (or SGD with momentum).
Suppose we have 101 weights all initialized to 0 in a linear model, and two possible ways to fit the training data:
The first is M=[1,1,…,1,1,0]. (It sets the first 100 weights to 1, and the last one to 0.)
The second is G=[0,0,…,0,0,1]. (It sets the first 100 weights to 0, and the last one to 1.)
(Any combination aM+bG with a+b=1 will also fit the training data.)
Intuitively, the first solution M memorizes the training data: we can imagine that each of the first 100 weights corresponds to storing the value of one of the 100 samples in our training set. The second solution G is a simple, generalized algorithm for solving all instances from the underlying data distribution, whether in the training set or not.
M has an L2 norm ten times as large as G’s. SGD, which follows the raw gradient, will mostly move toward G, since the descent is much steeper in that direction. It will ultimately converge on the minimum-norm solution (1/101)M + (100/101)G. (Momentum won’t change this picture much, since it’s just smearing out each SGD step over multiple updates, and each individual SGD step goes in the direction of steepest descent.)
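(For concreteness, the minimum-norm coefficients come from minimizing the squared L2 norm over the one-parameter family aM + (1−a)G:)

```latex
\min_a \|aM + (1-a)G\|_2^2
= \min_a \left[\,100a^2 + (1-a)^2\,\right]
\;\Longrightarrow\; 202a - 2 = 0
\;\Longrightarrow\; a = \tfrac{1}{101},\quad b = 1-a = \tfrac{100}{101}.
```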
AdamW, on the other hand, is basically the same as Adam at first, since L2 weight decay doesn’t do much when the weights are small. Since Adam is a scale-invariant, coordinate-wise adaptive learning rate algorithm, it will move at the same speed for each of the 101 coordinates in the direction which reduces loss, moving towards the solution (1/2)M + (1/2)G, i.e. with heavy weight on the memorization solution. Weight decay will start to kick in a bit before this point, and over time AdamW will converge to (close to) the same minimum-norm solution (1/101)M + (100/101)G as SGD. This is the phase transition from memorization to generalization.
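A quick way to test this prediction numerically (a sketch under the same assumed data layout as in the edit note above, tracking the mean memorizer weight against the generalizer weight for full-batch SGD vs. full-batch AdamW):

```python
# Full-batch comparison on the toy problem. The argument above predicts SGD
# heads for the min-norm mix (memorizers ~ 1/101, generalizer ~ 100/101),
# while AdamW first approaches ~(0.5, 0.5) and only later decays the
# memorizers; the Jan 2023 edit note disputes that later decay happens.
import torch

def run(make_opt, steps=20_000):
    n, d = 100, 101
    X = torch.zeros(n, d)
    X[torch.arange(n), torch.arange(n)] = 1.0  # memorizer features
    X[:, -1] = 1.0                             # generalizer feature
    y = torch.ones(n)
    w = torch.zeros(d, requires_grad=True)
    opt = make_opt([w])
    for step in range(steps):
        loss = ((X @ w - y) ** 2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        if step % 2_000 == 0:
            print(step, "mean memorizer:", round(float(w[:-1].mean()), 3),
                  "generalizer:", round(float(w[-1]), 3))
    return w

print("SGD:")
run(lambda params: torch.optim.SGD(params, lr=0.1))
print("AdamW:")
run(lambda params: torch.optim.AdamW(params, lr=1e-3, weight_decay=0.01))
```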
It would be straightforward to test this idea/argument!
I think you may be correct that regular SGD has an implicit L2 simplicity prior which Adam-family optimizers trade away for more important robustness and speed (regular SGD is just completely wrong and fails hard for many compute graphs that don’t fit the restrictive zero-mean, unit-variance gradient statistics assumptions). And AdamW, with its weight-decay complexity prior, behaves differently enough that the phase transitions could be specific to AdamW.
But maybe not? Even with pure SGD, you can get nonlinear dynamics through increasing depth (trivial example, consider a deep network of stacked repeated layers—the gradient updates become increasingly multiplicative/lognormal with depth rather than additive/normal).
I agree with you that that’s possible under some conditions and that my argument is not a proof. I note, however, that for the large neural networks used today, there’s theory and evidence (at least under SGD) supporting the idea that they’re effectively approximately linear models (that is, linear in their weights, not linear in their inputs), because the higher-order (multiplicative) effects you describe don’t matter since total weight updates throughout training are small. (Actually, I suspect the proof doesn’t hold for Adam.)
Even in the case of a nonlinear model, SGD has, as you say, an implicit prior for small movements as measured by the L2 norm, whereas Adam has an implicit prior for small movements as measured by the L∞ norm, but then as weight decay becomes more significant throughout training, you start to get this competition from the L2 norm, and once you interpolate the training data, that becomes all you care about.
It’s worth noting further that the volume of an L∞ ball of a given radius grows much faster with dimension than the volume of an L2 ball of the same radius. This implies that the effective capacity of a high-parameter-count network trained by Adam for N steps is much larger than that of the same network trained by SGD for N steps, assuming a comparable or smaller learning rate and gradients that aren’t too large (assumptions which typically hold). Thus, the “small L∞ movement” prior of Adam is barely a simplicity prior at all, in relative terms.
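(Concretely: the L∞ ball of radius r in d dimensions has volume (2r)^d, while the L2 ball has volume π^{d/2} r^d / Γ(d/2 + 1), so the ratio blows up super-exponentially with d:)

```latex
\frac{V_\infty(r,d)}{V_2(r,d)}
= \frac{(2r)^d}{\pi^{d/2}\, r^d \,/\, \Gamma(\tfrac{d}{2}+1)}
= \frac{2^d\,\Gamma(\tfrac{d}{2}+1)}{\pi^{d/2}}
\;\approx\; \sqrt{\pi d}\,\left(\frac{2d}{\pi e}\right)^{d/2}
\quad\text{(by Stirling)}
```

which is already astronomically large for d in the millions of parameters.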
The “large ANNs are effectively linear” view predicts that the brain would be even more linear, as it is far wider. And indeed there is some interesting indirect evidence supporting that—at the infinite-width limit, full backprop becomes unnecessary and equivalent to simpler feedback alignment models. This ties in nicely with the research track trying to explain how the brain learns using SGD-like optimization sans backprop.
But on the other hand this linear-limit neural tangent model only applies to simple networks, which are pretty limited. I’ll define complex networks as those which have multiplicative interactions in the forward pass. Attention/routing, short/medium-term memory, gating, etc. are all computationally equivalent in their general forms, and all require multiplicative interactions in the forward pass (i.e. the transpose multiply enabling sequence memory/attention in transformers).
Multiplicative non-linearity seems necessary for high capability/efficiency but also makes training dynamics non-linear—small changes in hidden state (and thus weights) can amplify to very large effects in outputs/gradients etc. (With sparse memory operations being the extreme example: maximally efficient and maximally nonlinear.)
Hmm, I think I understand what you’re pointing at, but it’s not obvious to me that the conclusion is correct. If I wear my “infinite hidden width extremist” hat, I’d say that the network after training has hidden activations on an input x that are extremely similar to those of the network before training. It’s just that the hidden activations have moved in a coordinated way so as to make the output layer come out very differently.
So yeah, the nonlinearities are all there, but they’re fixed nonlinearities of hidden features, and the network’s job is to learn the right linear combination of those fixed nonlinear features.
I’m not confident that this will hold in transformer networks, but I’m not confident it won’t either. Keep in mind that MLPs can learn to multiply, but (if sufficiently wide) they’re still effectively linear models. So the mere existence of nonlinear, multiplicative interactions as a function of the input doesn’t guarantee nonlinearity in the weights.
MLPs learning to multiply binary digits is mostly unrelated. The difference I am talking about (simple linear vs. complex non-linear) is perhaps better illustrated by considering networks with exponential activation functions rather than ReLU. ReLU is, by design, about as close to linear as one can get while still being useful. With exp activation functions the outputs/gradients are rather obviously non-linear in the weights.
Another example is extreme activation sparsity via k-max hidden layers with low k.
MLPs can learn to multiply general real numbers, not just binary digits, so long as the inputs are bounded. I’m actually not clear on why that example is mostly unrelated. It illustrates that you can have an arbitrary nonlinear circuit in part of the network while still being effectively linear in terms of weights, due to the weights staying in a small neighborhood of initialization. It’s actually not at all obvious to me that exponential activation functions would ruin this property. In fact I suspect they don’t in the infinite width limit, although that infinite width limit might be a worse approximation in practice.
Note that the question is ultimately not whether the network is truly linear in weights, but whether it’s effectively linear in weights over the range they move in. A nonlinear smooth function can be usefully treated as linear if we constrain ourselves to a small enough neighborhood. What’s not obvious to me is whether this approximation works for transformers. I wouldn’t be surprised either way.
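As a rough way to make “effectively linear in weights over the range they move in” concrete (a sketch of one possible check, not something from this thread): train a wide MLP, then compare its outputs with the first-order Taylor expansion around the initial weights; if the relative error is small, the trained network is effectively a linear model in its weights over that trajectory.

```python
# Rough sketch: is a trained wide MLP "effectively linear in its weights"?
# Compare trained outputs with the first-order Taylor expansion around the
# initial weights, f(x; θ) ≈ f(x; θ0) + ∇θ f(x; θ0) · (θ − θ0).
import torch

torch.manual_seed(0)
width = 4096  # effective linearity is a wide-network phenomenon

def f(w1, w2, x):
    return torch.relu(x @ w1) @ w2

w1_0 = torch.randn(10, width) / 10 ** 0.5   # rough NTK-style init scaling
w2_0 = torch.randn(width, 1) / width ** 0.5

x = torch.randn(256, 10)                    # small synthetic regression task
y = torch.sin(x.sum(dim=1, keepdim=True))

w1 = w1_0.clone().requires_grad_()
w2 = w2_0.clone().requires_grad_()
opt = torch.optim.SGD([w1, w2], lr=0.05)
for _ in range(1_000):
    loss = ((f(w1, w2, x) - y) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

# Linearized prediction around initialization via a Jacobian-vector product.
x_test = torch.randn(64, 10)
out0, tangent = torch.autograd.functional.jvp(
    lambda a, b: f(a, b, x_test),
    (w1_0, w2_0),
    (w1.detach() - w1_0, w2.detach() - w2_0),
)
linearized = out0 + tangent
trained = f(w1.detach(), w2.detach(), x_test)
# small relative error => effectively linear in the weights over this run
print((trained - linearized).norm() / trained.norm())
```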