First, I’d like to note that I don’t see why faster convergence after changing the learning rate supports either story. After initial memorization, the loss decreases by ~3 OOM. Regardless of what’s going on inside the network, it wouldn’t be surprising if raising the learning rate sped up convergence.
Also, I think what’s actually going on here is weirder than either of our interpretations. I ran experiments where I kept the learning rate the same for the first 1000 steps, then increased it by 10x (in one run) and 50x (in another) for the rest of training.
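For concreteness, here’s roughly the kind of schedule I mean (a minimal sketch assuming a toy PyTorch setup; the actual grokking model, data, and optimizer settings aren’t shown here, so treat every name below as a placeholder):

```python
# Sketch: keep the default LR for the first 1000 steps, then multiply it by a
# fixed factor for the rest of training. The model and data are stand-ins.
import torch
from torch import nn

BUMP_STEP, BUMP_FACTOR = 1000, 10          # 10x here; 50x in the other run
model = nn.Linear(8, 8)                    # placeholder for the real network
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for step in range(5000):
    if step == BUMP_STEP:                  # raise the LR after "memorization"
        for group in opt.param_groups:
            group["lr"] *= BUMP_FACTOR
    x = torch.randn(32, 8)                 # placeholder batch
    y = torch.randint(0, 8, (32,))
    opt.zero_grad()
    loss_fn(model(x), y).backward()
    opt.step()
```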
Here is the accuracy curve with the default learning rate:
Here is the curve with 10x learning rate:
And here is the curve with 50x learning rate:
Note that increasing the learning rate doesn’t consistently speed up validation convergence. The 50x run does reach convergence faster, but the 10x run never reaches it at all.
In fact, increasing the learning rate causes the training accuracy to fall to the validation accuracy, after which they begin to increase together (at least for a while). For the 10x increase, the training accuracy soon diverges from the validation accuracy again. In the 50x run, the training and validation accuracies move in tandem throughout the run.
Frederik’s results are broadly similar. If you mouse over the accuracy and loss graphs, you’ll see that:
Training performance drops significantly immediately after the learning rate increases.
The losses and accuracies of the “5x” and “10x” runs correlate pretty well between training and validation. In contrast, the training and validation losses and accuracies of the “default” run don’t correlate strongly.
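To make “correlate” concrete, this is roughly the comparison I have in mind (a sketch only, not the code behind the plots above; `train_acc` and `val_acc` are assumed to be equal-length logged curves):

```python
import numpy as np

def train_val_correlation(train_acc, val_acc):
    """Pearson correlation between logged training and validation accuracy curves."""
    return float(np.corrcoef(train_acc, val_acc)[0, 1])
```

A run where the two curves rise together should give a value near 1, while a memorize-then-grok run should give a much weaker correlation.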
I think that increasing the learning rate after memorization causes some sort of “mode shift” in the training process. It goes from:
First, learn shallow patterns that strongly overfit to the training data, then learn general patterns.
to:
Immediately learn general patterns that perform about equally well on the training and validation data.
In the case of my 10x run, I think it actually has two mode transitions: first from “shallow first” to “immediately general”, then another transition back to “shallow first”, which is why you see the training accuracy diverge from the validation accuracy again.
I think results like these make a certain amount of sense, given that higher learning rates are associated with better generalization in more standard settings.
Regardless of what’s going on inside the network, it wouldn’t be surprising if raising the learning rate sped up convergence.
I’m kinda confused by your perspective on learning rates. I usually think of learning rates as being set to the maximum possible value such that training is still stable, so it would in fact be surprising if you could just 10x them to speed up convergence. (So an additional aspect of my prediction would be that you can’t 10x the learning rate at the beginning of training; if you could, then it seems like the hyperparameters were chosen poorly, and that should be fixed first.)
Indeed, in your experiments, at the moment you 10x the learning rate, accuracy does in fact plummet! I’m a bit surprised it manages to recover, but you can see that the recovery is not nearly as stable as the original training before the learning rate increase (this is even more obvious in the 50x case), and notably even the recovery of the training accuracy looks like it takes longer (1000-2000 steps) than the original rise in training accuracy (~400 steps).
I do think this suggests that you can’t in fact “just 10x the learning rate” once grokking starts, which seems like a hit to my story.