Small point/question, Quintin—when you say that you “can fully avoid grokking on modular arithmetic”, in the colab notebook you linked to in that paragraph it looks like you just trained for 3e4 steps. Without explicit regularization, I wouldn’t have expected your network to generalize in that time (it might take 1e6 or 1e7 steps for networks to fully generalize). What point were you trying to make there? By “avoid grokking”, do you mean (1) avoid generalization or (2) eliminate the time delay between memorization and generalization? I’d be pretty interested if you achieved (2) while not using explicit regularization.
I mean (1). You can see as much in the figure displayed in the linked notebook:
Note the lack of decrease in the val loss.
I only train for 3e4 steps because that’s sufficient to reach generalization with implicit regularization. E.g., here’s the loss graph I get if I set the batch size down to 50:
Setting the learning rate to 7e-2 also allows for generalization within 3e4 steps (though not as stably):
Generalizing via the slingshot effect does take longer than 3e4 steps, though:
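To make the two knobs concrete, here is a minimal sketch of a modular-addition training loop of the kind these experiments use. It's an illustrative assumption rather than the actual notebook (the small-MLP architecture, modulus, and train fraction are all guesses), but it shows where the batch size and learning rate enter, and that weight decay is zero, so any regularization is implicit:

```python
# Illustrative sketch only: NOT Quintin's notebook. Architecture and most
# hyperparameters are assumptions; the point is where the two knobs enter.
import torch
import torch.nn as nn

P = 113            # modulus (assumed; a common choice in grokking papers)
FRAC_TRAIN = 0.3   # fraction of the P*P pairs used for training (assumed)
BATCH_SIZE = 50    # the "batch size down to 50" run; the default run presumably uses a larger batch
LR = 1e-3          # the other run instead bumps this to 7e-2 with the default batch size
STEPS = 30_000     # 3e4 steps, as in the discussion above

# Enumerate all (a, b) pairs and split into train / validation sets.
pairs = torch.cartesian_prod(torch.arange(P), torch.arange(P))
labels = (pairs[:, 0] + pairs[:, 1]) % P
perm = torch.randperm(len(pairs))
n_train = int(FRAC_TRAIN * len(pairs))
train_idx, val_idx = perm[:n_train], perm[n_train:]

# A small MLP over learned token embeddings stands in for the transformer.
embed = nn.Embedding(P, 128)
model = nn.Sequential(nn.Linear(256, 512), nn.ReLU(), nn.Linear(512, P))
opt = torch.optim.AdamW(
    list(embed.parameters()) + list(model.parameters()),
    lr=LR,
    weight_decay=0.0,  # no explicit regularization
)
loss_fn = nn.CrossEntropyLoss()

for step in range(STEPS):
    batch = train_idx[torch.randint(len(train_idx), (BATCH_SIZE,))]
    x = embed(pairs[batch]).flatten(1)  # concatenate the two token embeddings
    loss = loss_fn(model(x), labels[batch])
    opt.zero_grad()
    loss.backward()
    opt.step()
    if step % 1000 == 0:
        with torch.no_grad():
            val_x = embed(pairs[val_idx]).flatten(1)
            val_acc = (model(val_x).argmax(-1) == labels[val_idx]).float().mean()
        print(f"step {step}: train loss {loss.item():.3f}, val acc {val_acc:.3f}")
```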
Huh those batch size and learning rate experiments are pretty interesting!
Honestly I’d be surprised if you could achieve (2) even with explicit regularization, specifically on the modular addition task.
(You can achieve it by initializing the token embeddings to those of a grokked network so that the representations are appropriately structured; I’m not allowing things like that.)
EDIT: Actually, Omnigrok does this by constraining the parameter norm. I suspect this mostly works by making it very difficult for the network to strongly memorize the data—given the weight decay parameter, the network “tries” to learn a high-parameter-norm memorizing solution, but then repeatedly runs into the parameter norm constraint—and so it creates a very strong reason for the network to learn the generalizing algorithm. But that should still count as normal regularization.
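For concreteness, here is one way that kind of parameter-norm constraint could look in code. This is my reading of the Omnigrok-style setup, not their actual implementation, and the target norm value is made up:

```python
# Hedged sketch of a parameter-norm constraint: after every optimizer step,
# rescale all weights so the total L2 norm stays at a fixed target value.
import torch

def project_to_norm(params, target_norm):
    """Rescale the concatenated parameter vector to have L2 norm equal to target_norm."""
    with torch.no_grad():
        total = torch.sqrt(sum((p ** 2).sum() for p in params))
        scale = target_norm / (total + 1e-12)
        for p in params:
            p.mul_(scale)

# Usage inside the training loop from the earlier sketch (target_norm is arbitrary here):
#     opt.step()
#     project_to_norm(list(embed.parameters()) + list(model.parameters()), target_norm=10.0)
# The network keeps being pushed off the high-norm memorizing solution, which is the
# "very strong reason to learn the generalizing algorithm" described above.
```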
If you train on infinite data, I assume you wouldn’t see a delay between train and test accuracy, but you’d still expect a non-monotonic accuracy curve that looks kind of like the test accuracy curve in the finite-data regime? So I assume infinite data is also cheating?
I expect a delay even in the infinite data case, I think?
Although I’m not quite sure what you mean by “infinite data” here—if the argument is that every data point will have been seen during training, then I agree that there won’t be any delay. But yes, training on the test set (even via “we train on everything so there is no possible test set”) counts as cheating for this purpose.
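For what it's worth, the "infinite data" regime for modular addition would just look like drawing every batch fresh from the full input distribution, as in the sketch below (an assumed setup, not anyone's actual code). Since there are only P*P possible inputs, every would-be test example eventually gets trained on, which is why it counts as training on the test set:

```python
# Fresh-sample data loader for modular addition: no held-out set exists, because
# each batch is drawn directly from the full distribution of P*P possible pairs.
import torch

P = 113  # modulus (assumed)

def fresh_batch(batch_size):
    """Sample a fresh batch of (a, b) pairs and their labels (a + b) mod P."""
    a = torch.randint(P, (batch_size,))
    b = torch.randint(P, (batch_size,))
    return torch.stack([a, b], dim=1), (a + b) % P
```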