Which of these theories [...] can predict the same “four novel predictions about grokking” yours did? The relative likelihoods are what matters for updates after all.
I disagree with the implicit view on how science works. When you are a computationally bounded reasoner, you work with partial hypotheses, i.e. hypotheses that only make predictions on a small subset of possible questions, and just shrug at other questions. This is mostly what happens with the other theories:
Difficulty of representation learning: Shrugs at our prediction about Cmem / Cgen efficiencies, anti-predicts ungrokking (since in that case the representation has already been learned), shrugs at semi-grokking.
Scale of parameters at initialisation: Shrugs at all of our predictions. If you interpret it as making a strong claim that scale of parameters at initialisation is the crucial thing (i.e. other things mostly don’t matter) then it anti-predicts semi-grokking.
Spikes in loss / slingshots: Shrugs at all of our predictions.
Random walks among optimal solutions: Shrugs at our prediction about Cmem / Cgen efficiencies. I’m not sure what this theory says about what happens after you hit the generalising solution—can you then randomly walk away from the generalising solution? If yes, then it predicts that if you train for a long enough time without changing the dataset, a grokked network will ungrok (false in our experiments, and we often trained for much longer than time to grok); if no then it anti-predicts ungrokking and semi-grokking.
Simplicity of the generalising solution: This is our explanation. Our paper is basically an elaboration, formalization, and confirmation of Nanda et al.'s theory, as we allude to in the next sentence after the one you quoted.
How does this theory explain other grokking-related phenomena, e.g. Omni-Grok?
My speculation for Omni-Grok in particular is that in settings like MNIST you already have two of the ingredients for grokking (that there are both memorising and generalising solutions, and that the generalising solution is more efficient), and then having large parameter norms at initialisation provides the third ingredient (generalising solutions are learned more slowly), though I still don't know why that is.
Happy to speculate on other grokking phenomena as well (though I don’t think there are many others?)
And how do things change as you increase parameter count?
We haven’t investigated this, but I’d pretty strongly predict that there mostly aren’t major qualitative changes. (The one exception is semi-grokking; there’s a theoretical reason to expect it may sometimes not occur, and also in practice it can be quite hard to elicit.)
I expect there would be quantitative changes (e.g. maybe the value of Dcrit changes, maybe the time taken to learn Cgen changes). Sufficiently big changes in Dcrit might mean you don’t see the phenomena on modular addition any more, but I’d still expect to see them in more complicated tasks that exhibit grokking.
I’d be interested in investigations that got into these quantitative questions (in addition to the above, there’s also things like “quantitatively, how does the strength of weight decay affect the time for Cgen to be learned?”, and many more).
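To make the weight-decay question concrete, here is a minimal sketch of the kind of experiment I mean (not our codebase; the architecture, hyperparameters, and accuracy threshold are arbitrary placeholders): sweep weight decay on modular addition and record the first step at which validation accuracy crosses a threshold, as a crude proxy for the time for Cgen to be learned.

```python
# Sketch only: sweep weight decay on modular addition and record the first
# step at which validation accuracy crosses a threshold (a crude proxy for
# the time at which Cgen is learned). All settings here are placeholders.
import torch
import torch.nn as nn
import torch.nn.functional as F

P = 113                # modulus for the a + b (mod P) task
TRAIN_FRAC = 0.4       # fraction of the P*P pairs used for training
STEPS = 30_000         # full-batch training steps

def make_data(seed=0):
    g = torch.Generator().manual_seed(seed)
    pairs = torch.cartesian_prod(torch.arange(P), torch.arange(P))
    labels = (pairs[:, 0] + pairs[:, 1]) % P
    perm = torch.randperm(len(pairs), generator=g)
    n_train = int(TRAIN_FRAC * len(pairs))
    return pairs[perm[:n_train]], labels[perm[:n_train]], pairs[perm[n_train:]], labels[perm[n_train:]]

class MLP(nn.Module):
    def __init__(self, d=128):
        super().__init__()
        self.embed = nn.Embedding(P, d)
        self.net = nn.Sequential(nn.Linear(2 * d, d), nn.ReLU(), nn.Linear(d, P))

    def forward(self, x):                      # x: (N, 2) integer pairs
        return self.net(self.embed(x).flatten(1))

def time_to_generalise(weight_decay, acc_threshold=0.95):
    xtr, ytr, xva, yva = make_data()
    model = MLP()
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=weight_decay)
    for step in range(STEPS):
        opt.zero_grad()
        F.cross_entropy(model(xtr), ytr).backward()
        opt.step()
        if step % 100 == 0:
            with torch.no_grad():
                acc = (model(xva).argmax(-1) == yva).float().mean().item()
            if acc >= acc_threshold:
                return step
    return None  # never generalised within the step budget

for wd in [0.03, 0.1, 0.3, 1.0]:
    print(f"weight_decay={wd}: first step with val acc >= 0.95: {time_to_generalise(wd)}")
```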
I disagree with the implicit view on how science works. When you are a computationally bounded reasoner, you work with partial hypotheses, i.e. hypotheses that only make predictions on a small subset of possible questions, and just shrug at other questions.
Implicitly, I thought that if you have a partial hypothesis of grokking, it should be penalized whenever it shrugs at a grokking-related phenomenon. Unless by “shrugs” you mean the details of what the partial hypothesis says in this particular case are still being worked out. But in that case, confirming that the partial hypothesis doesn’t yet say anything about some phenomenon is still useful info. I’m fairly sure this belief was what generated my question.
This is mostly what happens with the other theories
Thank you for going through the theories and checking what they have to say. That was helpful to me.
I’d be interested in investigations that got into these quantitative questions
Do you have any plans to do this? How much time do you think it would take? And do you have any predictions for what should happen in these cases?
Unless by “shrugs” you mean the details of what the partial hypothesis says in this particular case are still being worked out.
Yes, that’s what I mean.
I do agree that it’s useful to know whether a partial hypothesis says anything or not; overall I think this is good info to know / ask for. I think I came off as disagreeing more strongly than I actually did, sorry about that.
Do you have any plans to do this?
No, we’re moving on to other work: this took longer than we expected, and was less useful for alignment than we hoped (though that part wasn’t that unexpected; from the start we expected “science of deep learning” to be more hits-based, or to require significant progress before it actually became useful for practical proposals).
How much time do you think it would take?
Actually running the experiments should be pretty straightforward, I’d expect we could do them in a week given our codebase, possibly even a day. Others might take some time to set up a good codebase but I’d still be surprised if it took a strong engineer longer than two weeks to get some initial results. This gets you observations like “under the particular settings we chose, D_crit tends to increase / decrease as the number of layers increases”.
The hard part is then interpreting those results and turning them into something more generalizable—including handling confounds. For example, maybe for some reason the principled thing to do is to reduce the learning rate as you increase layers, and once you do that your observation reverses—this is a totally made up example but illustrates the kind of annoying things that come up when doing this sort of research, that prevent you from saying anything general. I don’t know how long it would take if you want to include that; it could be quite a while (e.g. months or years).
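If it helps, here is the shape of that initial experiment as a sketch (again with placeholder architecture and hyperparameters, and assuming generalisation is monotone in dataset size so that a binary search is valid): estimate D_crit for a few depths by finding the smallest training set on which the network still generalises.

```python
# Sketch only: estimate D_crit (the smallest training-set size at which the
# network still generalises) for a few depths, via binary search over the
# number of training examples. Settings are placeholders, not the paper's.
import torch
import torch.nn as nn
import torch.nn.functional as F

P = 113  # modulus for the a + b (mod P) task

def make_model(n_layers, d=128):
    layers = [nn.Embedding(P, d), nn.Flatten(1), nn.Linear(2 * d, d), nn.ReLU()]
    for _ in range(n_layers - 1):
        layers += [nn.Linear(d, d), nn.ReLU()]
    layers.append(nn.Linear(d, P))
    return nn.Sequential(*layers)

def generalises(n_train, n_layers, steps=20_000, seed=0):
    """Train on n_train random pairs; return True if held-out accuracy > 0.9."""
    g = torch.Generator().manual_seed(seed)
    pairs = torch.cartesian_prod(torch.arange(P), torch.arange(P))
    labels = (pairs[:, 0] + pairs[:, 1]) % P
    perm = torch.randperm(len(pairs), generator=g)
    tr, va = perm[:n_train], perm[n_train:]
    model = make_model(n_layers)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.1)
    for _ in range(steps):
        opt.zero_grad()
        F.cross_entropy(model(pairs[tr]), labels[tr]).backward()
        opt.step()
    with torch.no_grad():
        acc = (model(pairs[va]).argmax(-1) == labels[va]).float().mean().item()
    return acc > 0.9

def estimate_d_crit(n_layers, lo=500, hi=8000, tol=250):
    """Binary search for the smallest training-set size that still generalises."""
    while hi - lo > tol:
        mid = (lo + hi) // 2
        lo, hi = (lo, mid) if generalises(mid, n_layers) else (mid, hi)
    return hi

for n_layers in [1, 2, 3]:
    print(f"{n_layers} layer(s): estimated D_crit ~ {estimate_d_crit(n_layers)}")
```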
And do you have any predictions for what should happen in these cases?
Not really. I’ve learned from experience not to try to make quantitative predictions yet. We tried to make some theory-inspired quantitative predictions in the settings we studied, and they fell pretty flat.
For example, in our minimal model in Section 3 we have a hyperparameter κ that determines how param norm and logits scale together—initially, that was our guess of what would happen in practice (i.e. we expected circuit param norm and circuit logits to obey a power-law relationship in actual grokking settings). But basically every piece of evidence we got seemed to falsify that hypothesis (e.g. Figure 3 in the paper).
(I say “seemed to falsify” because it’s still possible that we’re just failing to deal with confounders in some way, or measuring something that isn’t exactly what we want to measure. For example, Figure 3 logits are not of the Mem circuit in actual grokking setups, but rather the logits produced by networks trained on random labels—maybe there’s a relevant difference between these.)
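To spell out what the quantitative version of that hypothesis would have looked like: under a power law, logit scale ~ c * (param norm)^κ, so κ can be read off the slope of a log-log linear fit, and a real power law would show small, unstructured residuals around that line. The sketch below uses synthetic data purely to illustrate the fit; it is not data from the paper.

```python
# Sketch only: how kappa would be estimated if circuit logits really did obey
# a power law in circuit param norm. The "measurements" below are synthetic,
# generated from a made-up kappa, purely to illustrate the log-log fit.
import numpy as np

rng = np.random.default_rng(0)
param_norm = np.linspace(5.0, 50.0, 20)          # placeholder param-norm values
true_kappa, c = 2.0, 0.1                         # made-up ground truth
logit_scale = c * param_norm**true_kappa * np.exp(0.05 * rng.standard_normal(20))

# Under logit_scale = c * param_norm**kappa:
#   log(logit_scale) = kappa * log(param_norm) + log(c)
kappa_hat, log_c_hat = np.polyfit(np.log(param_norm), np.log(logit_scale), deg=1)
print(f"estimated kappa ~ {kappa_hat:.2f}, c ~ {np.exp(log_c_hat):.3f}")

# With real measurements, large systematic deviations from this straight-line
# fit would count as evidence against the power-law hypothesis.
```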
My speculation for Omni-Grok in particular is that in settings like MNIST you already have two of the ingredients for grokking (that there are both memorising and generalising solutions, and that the generalising solution is more efficient), and then having large parameter norms at initialisation provides the third ingredient (generalising solutions are learned more slowly), though I still don't know why that is.
Higher weight norm means lower effective learning rate with Adam, no? In that paper they used a constant learning rate across weight norms, but Adam tries to normalize the gradients to be of size 1 per parameter, regardless of the size of the weights. So the weights change more slowly with larger initializations (especially since they constrain the weights to be of fixed norm by projecting after the Adam step).
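A minimal sketch of that claim (mine, not from the Omni-Grok paper, and using an arbitrary toy loss): Adam's per-parameter updates are roughly of size lr regardless of the weight scale, so the relative change in the weights per step shrinks as the initialization norm grows.

```python
# Sketch only (arbitrary toy loss, not the Omni-Grok setup): with Adam, the
# per-parameter update size is roughly lr regardless of weight scale, so the
# *relative* change per step is smaller when the initialization norm is larger.
import torch

def relative_change(init_scale, lr=1e-3, steps=100):
    w = (torch.randn(1000) * init_scale).requires_grad_(True)
    w0 = w.detach().clone()
    opt = torch.optim.Adam([w], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        (w ** 2).sum().backward()   # toy loss; raw gradient scale grows with w
        opt.step()
    with torch.no_grad():
        return ((w - w0).norm() / w0.norm()).item()

for scale in [0.1, 1.0, 10.0]:
    print(f"init scale {scale:>4}: relative weight change after 100 steps ~ {relative_change(scale):.4f}")
```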
Sounds plausible, but why does this differentially impact the generalizing algorithm over the memorizing algorithm?
Perhaps under normal circumstances both are learned so fast that you just don’t notice that one is slower than the other, and this slows both of them down enough that you can see the difference?