I don’t know what the “real story” is, but let me point at some areas where I think we were confused. At the time, we had some sort of hand-wavy result in our appendix saying “something something weight norm ergo generalizing”. Similarly, concurrent work from Ziming Liu and others (Omnigrok) had another claim based on the weight norm of generalizing vs. memorizing solutions, as well as a claim that representation learning is important.
One issue is that our picture doesn’t consider learning dynamics that seem genuinely important here. For example, one mechanism that may explain why weight decay matters so much in the Omnigrok paper is that fixing the weight norm to be large leads to an effectively tiny learning rate under Adam (which normalizes gradients to a fixed scale), especially when there’s a substantial radial component to the gradient (which there is, when the init is too small or too large). This probably explains both why they found high training error when they constrained the weights to be sufficiently large in all their non-toy cases (see e.g. the mod add landscape below) and why we had difficulty using SGD+momentum (which, given our bad initialization, led to gradients that were far too big in some parts of the model, especially since we didn’t sweep the learning rate very hard). [1]
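To see the Adam point concretely, here’s a minimal sketch (a toy with a random weight vector and a stand-in loss, not the Omnigrok setup): Adam’s per-coordinate step is roughly lr regardless of gradient scale, so the relative update per step shrinks as you fix the norm to be larger.

```python
# Minimal sketch (a toy, not the Omnigrok setup): Adam's per-coordinate update
# is ~lr regardless of gradient scale, so the *relative* change per step
# shrinks as you fix the weight norm to be larger.
import torch

def relative_update(weight_norm, lr=1e-3, dim=10_000):
    w = torch.randn(dim)
    w = (w / w.norm() * weight_norm).requires_grad_(True)  # fix ||w|| to weight_norm
    opt = torch.optim.Adam([w], lr=lr)
    loss = (w ** 2).sum()  # stand-in loss whose gradient is purely radial
    loss.backward()
    before = w.detach().clone()
    opt.step()
    return ((w.detach() - before).norm() / before.norm()).item()

for weight_norm in [1.0, 10.0, 100.0, 1000.0]:
    print(f"||w|| = {weight_norm:6.0f}   relative update per step ≈ {relative_update(weight_norm):.1e}")
# Falls off like 1/||w||: at a large fixed norm you're effectively training
# with a tiny learning rate, even though lr itself hasn't changed.
```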
There are also some theoretical results from SLT-related folks showing that generalizing circuits achieve lower train loss per parameter (i.e. have higher circuit efficiency) than memorizing circuits (at least for large p), which seems to be a part of the puzzle that neither our work nor Omnigrok touched on: why do generalizing solutions have lower norm? IIRC our best explanations were that weight decay “favored more distributed solutions” (somewhat false) and “it sure seems empirically true”, but we didn’t have anything better than that.
There was also the really basic idea of how a ReLU/GELU network can do multiplication (by piecewise-linear approximations of x^2, or by using the quadratic region of the GELU for x^2), which (I think) was first described in late 2022 in Ekin Akyürek’s “Transformers can implement Sherman-Morrison for closed-form ridge regression” paper? (That’s not the name, just the headline result.)
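For the GELU version, the whole trick is that gelu(u) + gelu(-u) ≈ sqrt(2/pi) * u^2 near zero, plus the identity x*y = ((x+y)^2 - (x-y)^2)/4. Minimal sketch below; the eps and rescaling constants are my own choices for the demo, not anything from the paper.

```python
# Minimal sketch of the quadratic-region-of-GELU trick: near zero,
# gelu(u) + gelu(-u) ≈ sqrt(2/pi) * u**2, and x*y = ((x+y)**2 - (x-y)**2) / 4.
# The eps and rescaling constants here are my own choices for the demo.
import math
import torch
import torch.nn.functional as F

def gelu_square(u, eps=1e-2):
    # u^2 read off the curvature of the (exact, erf-based) GELU near zero
    return (F.gelu(eps * u) + F.gelu(-eps * u)) * math.sqrt(math.pi / 2) / eps**2

def gelu_multiply(x, y):
    return (gelu_square(x + y) - gelu_square(x - y)) / 4

x, y = torch.tensor(1.7), torch.tensor(-0.4)
print(gelu_multiply(x, y).item(), (x * y).item())  # both ≈ -0.68
```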
Part of the story for grokking in general may also be related to the Tensor Programs results, which claim that with standard init the gradient on the embedding is too small relative to the gradient on other parts of the model. (Also, the embed at init is too small relative to the unembed.) Because the embed is both too small at init and barely updated, there’s no representation learning going on, just random-feature regression (which overfits in the same way that regression on random features overfits absent regularization).
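If you want to check whether this is happening in your own setup, a quick diagnostic is to compare the relative gradient ||grad|| / ||param|| of the embed, a hidden layer, and the unembed at standard init. Sketch below on a made-up toy model (not the Tensor Programs setup, and it only measures the quantity rather than proving the width-scaling claim):

```python
# Quick diagnostic on a made-up toy model: compare ||dL/dW|| / ||W|| for the
# embed, a hidden layer, and the unembed at standard init, to see how slowly
# the embedding would move relative to everything else.
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab, d_model, d_hidden = 113, 128, 512
embed = nn.Embedding(vocab, d_model)      # default init: N(0, 1)
hidden = nn.Linear(d_model, d_hidden)     # default init: Kaiming-uniform
unembed = nn.Linear(d_hidden, vocab, bias=False)

tokens = torch.randint(0, vocab, (256,))
targets = torch.randint(0, vocab, (256,))
logits = unembed(torch.relu(hidden(embed(tokens))))
F.cross_entropy(logits, targets).backward()

for name, p in [("embed", embed.weight), ("hidden", hidden.weight), ("unembed", unembed.weight)]:
    print(f"{name:8s} ||grad|| / ||param|| = {(p.grad.norm() / p.norm()).item():.1e}")
```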
In our case, this turns out not to be true (because our network is tiny? because our weight decay is set aggressively at lambda=1?): the weights that directly contribute to the logits (W_E, W_U, W_O, W_V, W_in, W_out) all quickly converge to the same size (weight decay encourages spreading weight norm across things that get multiplied together), while the weights that don’t directly contribute to the logits converge to zero.
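The “spreading out weight norm” point is easy to see in a scalar toy: for a fixed product a*b, the L2 penalty a^2 + b^2 is minimized at |a| = |b|, and gradient descent with weight decay equalizes the factors. Minimal sketch with made-up numbers:

```python
# Scalar toy of weight decay spreading norm across multiplied factors: fit a*b
# to a target with an L2 penalty, starting from a deliberately lopsided init,
# and watch |a| and |b| equalize. Numbers here are made up for the demo.
import torch

target, lam = 4.0, 0.1
a = torch.tensor(8.0, requires_grad=True)   # lopsided: a starts much bigger than b
b = torch.tensor(0.5, requires_grad=True)
opt = torch.optim.SGD([a, b], lr=5e-3)

for _ in range(5000):
    opt.zero_grad()
    loss = (a * b - target) ** 2 + lam * (a ** 2 + b ** 2)
    loss.backward()
    opt.step()

print(a.item(), b.item())  # both ≈ 1.97: the product's norm gets split evenly
```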
Bringing it back to the topic at hand: there are often a lot of “small” confusions that remain, even after doing good toy-models work. It’s not clear how much any of these confusions matter (and do any of the grokking results found by our paper, Ziming Liu et al., or the GDM grokking paper matter?).
[1] Haven’t checked, might do this later this week.