I updated the report with the training curves. Under default settings, 100% training accuracy is reached after 500 steps.
The train and val curves actually overlap on the way up. This might be an artifact of the simplicity of the task, or of not splitting the dataset properly (e.g. x+y ending up in train and y+x in val). I might rerun it on a harder task to verify.
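One way to rule out the split issue is to split over unordered pairs, so that (x, y) and (y, x) always land on the same side. The sketch below is a hypothetical helper, not the code behind the report; p = 97 and a 50% train fraction are assumptions.

```python
import random

def split_commutative_pairs(p: int, train_frac: float, seed: int = 0):
    """Hypothetical helper: split all (x, y) pairs mod p into train/val so that
    (x, y) and (y, x) always end up in the same split."""
    rng = random.Random(seed)
    # One canonical representative per unordered pair {x, y}.
    unordered = [(x, y) for x in range(p) for y in range(x, p)]
    rng.shuffle(unordered)
    n_train = int(train_frac * len(unordered))
    train, val = [], []
    for i, (x, y) in enumerate(unordered):
        target = train if i < n_train else val
        target.append((x, y))
        if x != y:
            target.append((y, x))  # the swapped order goes to the same split
    return train, val

train, val = split_commutative_pairs(p=97, train_frac=0.5)
train_set = set(train)
leaks = sum((y, x) in train_set for (x, y) in val)
print(len(train) + len(val), leaks)  # 9409 0
```

With this split, a val example can never be answered just by flipping a memorized train example.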
Huh, intriguing. Yeah, it might be worth running with a non-commutative function to see whether it holds up. In the default setting, the validation accuracy seems to reach almost 0.5 once the training accuracy is 1, which is about what you’d get if you understood commutativity but nothing else about the function. So the “grokking” part probably happens after that, at roughly the 1.5k-step mark in the default setting.
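The ~0.5 figure can be checked with a quick simulation: a model that memorizes the train set and only knows f(x, y) = f(y, x) answers a val query correctly exactly when the swapped pair was in train. This sketch assumes a random split with a 50% train fraction and p = 97; the thread does not state either value.

```python
import random

def commutativity_only_accuracy(p: int, train_frac: float, seed: int = 0) -> float:
    """Val accuracy of a hypothetical model that memorizes the train set and
    only uses the rule f(x, y) = f(y, x) to answer val queries."""
    rng = random.Random(seed)
    pairs = [(x, y) for x in range(p) for y in range(p)]
    rng.shuffle(pairs)
    n_train = int(train_frac * len(pairs))
    train = set(pairs[:n_train])
    val = pairs[n_train:]
    # Correct only when the swapped pair was seen during training.
    hits = sum((y, x) in train for (x, y) in val)
    return hits / len(val)

acc = commutativity_only_accuracy(p=97, train_frac=0.5)
print(round(acc, 3))
```

With a train fraction f, this baseline comes out at roughly f, so a plateau near 0.5 is consistent with a 50% split.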
So I ran some experiments for the permutation group S_5 with the task x ∘ y = ? (composition of permutations).
Interestingly, here increasing the learning rate never works. I’m very confused.
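For reference, here is a minimal sketch of how an S_5 composition dataset could be built, and why it is a good non-commutative test. The composition convention (apply b first, then a) and the integer token ids are assumptions, not taken from the experiments above.

```python
from itertools import permutations

# The 120 elements of S_5, each stored as a tuple: perm[i] is the image of i.
S5 = list(permutations(range(5)))
index = {perm: i for i, perm in enumerate(S5)}

def compose(a, b):
    """(a o b)(i) = a(b(i)): apply b first, then a (one common convention)."""
    return tuple(a[b[i]] for i in range(5))

# Every (x, y) -> x o y example, as integer token ids.
dataset = [(index[a], index[b], index[compose(a, b)]) for a in S5 for b in S5]

# How many ordered pairs commute?
commuting = sum(compose(a, b) == compose(b, a) for a in S5 for b in S5)
print(len(dataset), commuting)  # 14400 840
```

Only 840 of the 14400 pairs (about 6%) commute, so knowing commutativity alone buys almost nothing on this task.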
Also interestingly, in the default setting for these new experiments, grokking happens in ~1000 steps while memorization happens in ~1500 steps, so the grokking is already faster than the memorization, in stark contrast to the graphs in the original post.
(This does depend on when you start the counter for grokking, as there’s a long period of slowly increasing validation accuracy. You could reasonably say grokking took ~2500 steps.)
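Since the measured grokking time depends on which event starts the clock, it can help to make that choice explicit. A minimal sketch, assuming logged lists of steps and accuracies; the function names and default thresholds are hypothetical.

```python
from typing import Optional, Sequence

def first_step_reaching(steps: Sequence[int], acc: Sequence[float],
                        threshold: float) -> Optional[int]:
    """First logged step whose accuracy is >= threshold, or None if never reached."""
    for step, a in zip(steps, acc):
        if a >= threshold:
            return step
    return None

def grokking_duration(steps, train_acc, val_acc, start="train",
                      start_threshold=0.99, end_threshold=0.99) -> Optional[int]:
    """Steps from a chosen start event until val accuracy reaches end_threshold.

    start="train": count from when train accuracy first hits start_threshold.
    start="val":   count from when val accuracy first hits start_threshold
                   (e.g. the end of a plateau, before the slow climb).
    """
    curve = train_acc if start == "train" else val_acc
    t0 = first_step_reaching(steps, curve, start_threshold)
    t1 = first_step_reaching(steps, val_acc, end_threshold)
    if t0 is None or t1 is None:
        return None
    return t1 - t0
```

Reporting both variants side by side makes it clear whether "grokking took ~1000 steps" or "~2500 steps" is being claimed.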
Oh, I thought figure 1 was S5, but it’s actually modular division. I’ll give that a go.
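A minimal sketch of a modular division dataset, x / y = x · y⁻¹ mod p with y ≠ 0. The prime p = 97 is an assumption here; the thread does not state the modulus used.

```python
def modular_division_dataset(p: int):
    """All examples (x, y, x / y mod p) for a prime p, skipping y = 0."""
    data = []
    for x in range(p):
        for y in range(1, p):  # division by 0 is undefined
            y_inv = pow(y, -1, p)  # modular inverse (Python 3.8+)
            data.append((x, y, (x * y_inv) % p))
    return data

data = modular_division_dataset(97)
print(len(data))  # 9312
```

Unlike addition, division is not commutative (6 / 3 = 2 but 3 / 6 = 49 mod 97), so the commutativity shortcut is not available here.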
Here are results for modular division. I’m not super sure what to make of them. Small increases in learning rate work, but so does just choosing a larger learning rate from the beginning. In fact, 5x lr from the beginning works super well, but switching to 5x once grokking arguably starts destroys any progress. 10x lr does not work, either from the start or when switching later.
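The two arms being compared (a larger lr from step 0 vs. a one-time switch mid-training) can be written as a single schedule. This is a hypothetical sketch; the base lr of 1e-3 and the switch step are assumptions, not values from the runs.

```python
def lr_at(step: int, base_lr: float, factor: float, switch_step: int = 0) -> float:
    """Learning rate under a one-time multiplicative switch.

    switch_step=0 gives "factor x lr from the beginning"; a later switch_step
    gives "multiply lr once grokking arguably starts".
    """
    return base_lr * (factor if step >= switch_step else 1.0)

# 5x from the start vs. 5x switched in at a hypothetical step 1500.
from_start = [lr_at(s, 1e-3, 5.0) for s in (0, 1000, 2000)]
switched = [lr_at(s, 1e-3, 5.0, switch_step=1500) for s in (0, 1000, 2000)]
print(from_start, switched)
```

In PyTorch the same multiplier could be passed to `torch.optim.lr_scheduler.LambdaLR` as `lr_lambda=lambda step: factor if step >= switch_step else 1.0`, stepping the scheduler once per optimizer step.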
So maybe the initial observation is more a general/global property of the loss landscape for the task and not of the particular region during grokking?
Yeah, that seems right, I think I’m basically at “no, you can’t just 10x the learning rate once grokking starts”.
Increasing regularization (weight decay in this instance) might rescue the ones which don’t work.
I tried increasing weight decay and using larger batch sizes, but so far no real success compared to 5x lr. Not going to investigate this further atm.