Hmm, I haven’t read the paper yet, but thinking about it: the easiest way to change the weight norm is just to multiply all the weights by some factor α. In a network with ReLU activations and L layers (ignoring biases), this is completely equivalent to multiplying the output of the network by α^L. In the overfitting regime, where the network produces probability distributions whose mode equals the answer for every training token, the easiest way to decrease loss is just to multiply the output by a constant factor, essentially driving the entropy of the distribution down forever; but this strategy fails at test time, because there the mode is not equal to the answer. So keeping the weight norm at some specified level might just be a way to prevent the network from taking the easy route to lower training loss, forcing it to find ways to decrease loss at constant weight norm, which generalize better to the test set.
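To make that concrete, here is a rough NumPy sketch (my own toy example with a bias-free ReLU MLP, nothing from the paper): scaling every weight matrix by α scales the logits by α^L, and scaling the logits up keeps lowering cross-entropy on labels the network already gets right while blowing it up on labels it gets wrong.

```python
# Toy check, assuming a bias-free ReLU MLP (my assumption, not the paper's setup).
import numpy as np

rng = np.random.default_rng(0)
L = 3
Ws = [rng.standard_normal((8, 8)) for _ in range(L)]

def forward(Ws, x):
    h = x
    for W in Ws[:-1]:
        h = np.maximum(0.0, h @ W)   # ReLU hidden layers
    return h @ Ws[-1]                # linear output (the logits)

x = rng.standard_normal(8)
alpha = 1.7
logits = forward(Ws, x)
print(np.allclose(forward([alpha * W for W in Ws], x), alpha**L * logits))  # True

def cross_entropy(logits, y):
    z = logits - logits.max()
    return -(z[y] - np.log(np.exp(z).sum()))

y_fit = int(np.argmax(logits))    # a "training" label the network already gets right
y_miss = int(np.argmin(logits))   # a "test" label it gets wrong
for c in (1.0, 2.0, 4.0, 8.0):
    print(c, cross_entropy(c * logits, y_fit), cross_entropy(c * logits, y_miss))
# Loss on the fitted label keeps shrinking as c grows; loss on the missed label grows.
```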
That makes a lot of sense.
However, in that case it would be enough to just keep the weight norm at any level. But they claim that there is an optimal level. So it can’t be the entire story they have in mind.
The existence of an optimal L2 norm makes no sense at all to me. The L2 norm is an extremely unnatural metric for neural networks. For instance, in ReLU networks, if you multiply all the weights in one layer by β and all the weights in the layer above by 1/β, the output of the network doesn’t change at all, yet the L2 norm will have changed (the norm for those two layers becomes β^2|w_1|^2 + (1/β^2)|w_2|^2). In fact, you can get any value for the L2 norm (above some minimum) you damn well please just by scaling the layers. An optimal average entropy of the output distribution over the course of training would make a hell of a lot more sense, if this is somehow changing training dynamics.
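A quick sanity check of the rescaling argument, again just a bias-free two-layer toy in NumPy of my own (not anything from the paper):

```python
# Layer-wise rescaling leaves the function unchanged but moves the L2 norm.
import numpy as np

rng = np.random.default_rng(0)
W1 = rng.standard_normal((8, 16))
W2 = rng.standard_normal((16, 4))

def f(W1, W2, x):
    return np.maximum(0.0, x @ W1) @ W2   # one ReLU hidden layer, linear output

x = rng.standard_normal(8)
beta = 3.0
print(np.allclose(f(W1, W2, x), f(beta * W1, W2 / beta, x)))      # True: same function
print(np.sum(W1**2) + np.sum(W2**2))                              # original norm
print(beta**2 * np.sum(W1**2) + np.sum(W2**2) / beta**2)          # rescaled norm
# Cranking beta up (or down) makes the norm as large as you like, so only some
# minimum value is pinned down by the function the network computes.
```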
It doesn’t matter that there are multiple networks with the same performance but different L2 norms. Instead, it suffices that the optimal network differs for different L2 norms, or that the gradient updates during training point in different directions when the network’s L2 norm is constrained. Both are indeed true.
It also makes a lot of sense if you think about it in terms of ordinary statistical learning theory. Assuming for a second that we’re randomly sampling neural networks that achieve a certain train loss at a certain weight norm, there’s some amount of regularization (i.e., some small weight norm) that leads to the lowest test loss.
If the optimal norm is below the minimum you can achieve just by rescaling, you are trading off training-set accuracy for weights with a smaller norm within each layer. It’s not that weird that the best known way of making this trade-off is by constrained optimization.
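For concreteness, one simple way to do such constrained optimization, and this is my own sketch rather than whatever procedure the paper actually uses, is to take an ordinary gradient step and then rescale the parameters back onto a sphere of fixed L2 norm. The resulting update generally points in a different direction than the unconstrained one, which is the point made a couple of comments up.

```python
# Sketch of norm-constrained SGD via projection onto a fixed-norm sphere
# (a toy of mine, not the paper's method).
import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal(10)
target_norm = np.linalg.norm(w)          # hold the initial norm fixed
lr = 0.1

def grad(w):
    # toy quadratic loss 0.5*||w - 1||^2 standing in for the training loss
    return w - np.ones_like(w)

def constrained_step(w):
    w_free = w - lr * grad(w)                                 # ordinary gradient step
    return w_free * (target_norm / np.linalg.norm(w_free))    # rescale back to the sphere

# The constrained update points in a different direction than the unconstrained one:
free_step = -lr * grad(w)
proj_step = constrained_step(w) - w
print(np.dot(free_step, proj_step) /
      (np.linalg.norm(free_step) * np.linalg.norm(proj_step)))   # cosine < 1

for _ in range(200):
    w = constrained_step(w)
print(np.linalg.norm(w), target_norm)    # the norm stays pinned at the target
```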
I think this is very similar to the hypothesis they have as well. But I’m not sure I understood it correctly; I think some parts of the paper are not as clear as they could be.
I think this theory is probably part of the story, but it fails to explain Figure 2(b), where grokking happens in the presence of weight decay even when the weight norm is kept constant.