So, just don’t keep training a powerful AI past overfitting, and it won’t grok anything, right? Well, Nanda and Lieberum speculate that grokking was hard to discover not because it’s rare but because it’s omnipresent: smooth loss curves are the result of many new grokkings constantly being built atop the previous ones.
If the grokkings are happening all the time, why do you get double descent? Why wouldn’t the test loss just be a smooth curve?
Maybe the answer is something like:
1. The model is learning generalizable patterns and non-generalizable memorizations all the time.
2. Both the patterns and the memorizations fall along some distribution of discoverability (by gradient descent) based on their complexity.
3. The distribution of patterns is more fat-tailed than the distribution of memorizations: there are a bunch of easily discoverable patterns and many hard-to-discover patterns, while the memorizations are more clumped together in their discoverability. (We might expect this, since all the input-output pairs would have a similar level of complexity in the grand scheme of things.)
4. Therefore, the model learns a bunch of easily discoverable generalizable patterns first, leading to the first test-time loss descent.
5. Then the model reaches a point where it has burned through all the easily discoverable patterns and is mostly learning memorizations. The crossover point, where it starts learning more memorizations than patterns, corresponds to the nadir of the first descent.
6. The model goes on for a while learning memorizations more so than patterns. This is the overfitting regime.
7. Once it has run out of memorizations to learn, and/or the regularization complexity penalty makes it start favoring patterns again, the second descent begins.
Or, put more simply:
The reason you get double descent is that the distribution of complexity / discoverability for generalizable patterns is wide, whereas the distribution of complexity / discoverability of memorizations is more narrow.
(If there weren’t any easy-to-discover patterns, you wouldn’t see the first test-time loss descent. And if there weren’t any patterns that are harder to discover than the memorizations, you wouldn’t see the second descent.)
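To see how those assumptions produce the double-descent shape, here is a minimal toy simulation (mine, not from any paper): pattern discovery times are drawn from a wide lognormal, memorization discovery times from a narrow one centered between the easy and hard patterns, and an exponential decay stands in for the regularization penalty slowly pruning memorized noise. Every distribution and coefficient is an illustrative assumption, not a fit to real training runs.

```python
# Toy sketch: fat-tailed pattern discoverability + narrow memorization
# discoverability => double descent in test loss over training time.
import numpy as np

rng = np.random.default_rng(0)

# Discovery times: patterns are fat-tailed (some trivial, some very hard),
# while memorizations cluster at a characteristic difficulty in between.
pattern_times = rng.lognormal(mean=4.0, sigma=2.5, size=200)        # wide
memorization_times = rng.lognormal(mean=5.0, sigma=0.3, size=2000)  # narrow

steps = np.logspace(0, 5, 200)  # training steps, log-spaced
test_loss = []
for t in steps:
    patterns = (pattern_times <= t).mean()        # fraction discovered so far
    memorized = (memorization_times <= t).mean()
    memorized *= np.exp(-t / 2e4)  # regularization slowly pruning memorized noise
    # Patterns generalize (lower test loss); memorizations only fit noise.
    test_loss.append(1.0 - 0.8 * patterns + 0.4 * memorized)
test_loss = np.array(test_loss)

i1 = test_loss[:100].argmin()  # nadir of the first descent
print(f"first nadir {test_loss[i1]:.2f} -> "
      f"overfitting bump {test_loss[i1:].max():.2f} -> "
      f"final loss {test_loss[-1]:.2f}")
```

In this toy setup the curve dips as the easy patterns are found, rises while the clustered memorizations dominate, and descends again once the hard patterns arrive and regularization strips out the memorized noise; widening the memorizations' sigma to match the patterns' washes the bump out, matching the parenthetical above.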
Does that sound plausible as an explanation?

After reading through the Unifying Grokking and Double Descent paper that LawrenceC linked, it sounds like I’m mostly saying the same thing as what’s in the paper.

(Not too surprising, since I had just read Lawrence’s comment, which summarizes the paper, when I made mine.)
In particular, the paper describes Type 1, Type 2, and Type 3 patterns, which correspond to my easy-to-discover patterns, memorizations, and hard-to-discover patterns:
In our model of grokking and double descent, there are three types of patterns learned at different speeds. Type 1 patterns are fast and generalize well (heuristics). Type 2 patterns are fast, though slower than Type 1, and generalize poorly (overfitting). Type 3 patterns are slow and generalize well.
The one thing I mention above that I don’t see in the paper is an explanation for why the Type 2 patterns would be intermediate in learnability between Type 1 and Type 3 patterns or why there would be a regime where they dominate (resulting in overfitting).
My proposed explanation is that, for any given task, the exact mappings from input to output will tend to have a characteristic complexity, which means they will have a relatively narrow distribution of learnability. And that’s why models will often hit a regime where they’re mostly finding those memorizations, rather than Type 1, easy-to-learn heuristics (which they’ve exhausted) or Type 3, hard-to-learn rules (which they’re not discovering yet).
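For a concrete instance of that characteristic complexity, take the modular-addition task from the grokking literature (p = 113, as in Nanda et al.'s setup; the framing as description length is my own illustration): every memorized input-output pair costs about the same number of bits to store, whereas heuristics and circuits span a wide range of complexities.

```python
# Illustrative only: in a lookup-table task like (a + b) % p, every
# memorized pair has the same description length, so memorizations
# cluster at one characteristic complexity.
import math

p = 113                        # modulus, as in Nanda et al.'s grokking setup
bits_per_pair = math.log2(p)   # identical cost for every (a, b) -> c pair
n_pairs = p * p
print(f"{bits_per_pair:.1f} bits per memorized pair; "
      f"{n_pairs * bits_per_pair / 8 / 1024:.1f} KiB for the full table")
```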
The authors do have an appendix section A.1 in the paper with the heading, “Heuristics, Memorization, and Slow Well-Generalizing”, but with “[TODO]”s in the text. I’ll be curious to see whether they end up saying something similar there to this point about input-output memorizations tending to have a characteristic complexity.