Thanks, this is awesome! Especially cool that you were able to reverse-engineer the learned algorithm! And the theory/analysis seems great too.
[Just thinking aloud my confusion, not important]
—Not sure I understand the difference between the random walk hypothesis and the lottery-ticket-inspired hypothesis. In both cases there’s a phase transition caused by accelerating gradients once weak versions of all the components of the nonmodular circuit are in place.
—The central piece of evidence you bring up is what, exactly… that the effectiveness of the nonmodular circuit rises before the phase transition begins? I guess I’m a bit confused about how this works. If I imagine a circuit consisting of a modular memorized answer part and another part, the generalizing bit, that actually computes the answer… if the second part is really crappy/weak/barely there, I don’t see how making it stronger is incentivised. Making it stronger means more weights so the regularization should push against it, UNLESS you can simultaneously delete or dampen weights from the memorized answer part, right? But because the generalizing part is so weak, if you delete or dampen weights from the memorized answer part, won’t loss go up? Because the generalizing part isn’t ready yet to take over the job of the memorized answers?
Thanks! I agree that they’re pretty hard to distinguish, and the evidence between them is fairly weak—it’s hard to tell apart a winning lottery ticket at initialisation from one stumbled upon within the first 200 steps, say.
My favourite piece of evidence is [this video from Eric Michaud](https://twitter.com/ericjmichaud_/status/1559305105521926144) - we know that the first 2 principal components of the embedding form a circle at the end of training. But if we fix the axes at the end of training and project the embedding at the start of training onto them, it’s pretty circle-y.
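To make that projection concrete, here's a minimal sketch of it (the random arrays below are stand-ins for the initial and final embedding checkpoints, and the shapes are assumptions on my part, not the post's actual data):

```python
# Fix the principal axes at the *end* of training, then project the embedding
# from the *start* of training onto those same axes to check for circularity.
import numpy as np

rng = np.random.default_rng(0)
embed_final = rng.normal(size=(113, 128))   # (vocab, d_model) at end of training
embed_init = rng.normal(size=(113, 128))    # (vocab, d_model) at initialisation

# Top-2 principal axes of the final embedding.
_, _, vt = np.linalg.svd(embed_final - embed_final.mean(axis=0), full_matrices=False)
axes = vt[:2]                               # (2, d_model)

# Project both checkpoints onto the fixed axes; plotting proj_init shows how
# much of the circular structure is already present at the start of training.
proj_final = (embed_final - embed_final.mean(axis=0)) @ axes.T
proj_init = (embed_init - embed_init.mean(axis=0)) @ axes.T
print(proj_final.shape, proj_init.shape)    # (113, 2) each
```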
> Making it stronger means more weights so the regularization should push against it, UNLESS you can simultaneously delete or dampen weights from the memorized answer part, right?
I think this does happen (and is very surprising to me!). If you look at the excluded loss section, I ablate the model’s ability to use one component of the generalising algorithm, in a way that shouldn’t affect the memorising algorithm (much), and see the damage of this ablation rise smoothly over training. I hypothesise that it’s dampening memorisation weights simultaneously, though I haven’t dug deep enough to be confident. Regardless, it clearly seems to be doing some kind of interpolation—I have a lot of progress measures that all (both qualitatively and quantitatively) show clear progress towards generalisation pre-grokking.
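For intuition, here's a rough sketch of that kind of ablation: project a single hypothesised generalising direction out of the logits and measure how much the loss degrades. This is a simplified paraphrase rather than the post's exact excluded-loss metric, and the shapes and frequency below are made up:

```python
# Remove one hypothesised "generalising" component from the logits, then
# measure how much worse the loss gets compared to the unablated model.
import torch
import torch.nn.functional as F

def excluded_loss(logits, labels, direction):
    """Project `direction` out of the logits and return the resulting loss.

    logits:    (batch, vocab) model outputs
    labels:    (batch,) correct answers
    direction: (vocab,) pattern over outputs to ablate, e.g. a cosine wave
               over the answer vocabulary at some key frequency
    """
    direction = direction / direction.norm()
    coeffs = logits @ direction                      # component along `direction`
    ablated = logits - coeffs[:, None] * direction   # logits with it removed
    return F.cross_entropy(ablated, labels)

# Toy usage on random data; in real use you'd run this over training checkpoints
# and track excluded_loss minus the ordinary loss as a progress measure.
logits = torch.randn(32, 113)
labels = torch.randint(0, 113, (32,))
direction = torch.cos(2 * torch.pi * 7 * torch.arange(113) / 113)
print(excluded_loss(logits, labels, direction).item())
```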
A maximally sparse neural net layer (k=1, i.e. only one neuron active) is effectively just a simple input->output key/value map and thus can only memorize. It can at best learn to associate each input pattern with one specific output pattern, no more, no less (and can trivially overfit any dataset of N examples perfectly by using N neurons and N·I + N·O memory, with I and O the input/output dimensions, just like a map/table in CS).
We can get some trivial compression if there are redundant input->output mappings, but potentially much larger gains by slowly relaxing that sparsity constraint and allowing more neurons to be simultaneously active, which provides more opportunities to compress the function. With k=2, for example, and N=D/2, each neuron now responds to exactly two different examples and must share the input->output mapping with one other neuron—for example, by specializing on different common subset patterns.
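As a toy version of the k=1 case (my own construction, with made-up sizes): each of the N neurons keys on exactly one training input and stores that input's output, so the layer is literally a winner-take-all key/value table:

```python
# Toy winner-take-all layer acting as a pure key/value memory (the k=1 case).
import numpy as np

rng = np.random.default_rng(0)
N, I, O = 8, 32, 4                                 # N examples, input dim I, output dim O
X = rng.normal(size=(N, I))
X /= np.linalg.norm(X, axis=1, keepdims=True)      # unit-norm inputs used as keys
Y = rng.normal(size=(N, O))

W_in = X                                           # neuron i keys on example i   (N*I weights)
W_out = Y                                          # neuron i stores example i's output (N*O weights)

def sparse_layer(x, k=1):
    scores = W_in @ x                              # match the input against all N keys
    active = np.argsort(scores)[-k:]               # keep only the k most active neurons
    gate = np.zeros(N)
    gate[active] = 1.0                             # hard winner-take-all gating
    return gate @ W_out                            # read out the stored value(s)

print(np.allclose(sparse_layer(X[3]), Y[3]))       # True: perfect recall of example 3
```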
At the extreme of compression we have a neural circuit which computes some algorithm that fits the data well and is likely more dense. In the continuous circuit space there are always interpolations between those circuits and memorization circuits.
The training path is always continuous, thus it necessarily interpolates smoothly between some overfit memorization solution and the generalizing (nonmodular) circuit solution. But that shouldn’t be too surprising—a big circuit can always be recursively decomposed down to smaller elementary pieces, and each elementary circuit is logically equivalent not to a single unique lookup table, but to an infinite set of overparameterized, equivalent, redundant lookup tables.
So training just has to find one of the many redundant lookup-table (memorization) solutions first, then smoothly remove the redundancy in those lookup tables. The phase transitions likely arise from semi-combinatoric dependencies between layers (and those probably become more pronounced with increasing depth/complexity).
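A tiny illustration of that redundancy point (my example, not from the comment above): a single XOR gate corresponds to one 4-row lookup table, yet it has infinitely many equivalent overparameterized realizations, all computing the same table:

```python
# One elementary circuit (XOR), its lookup table, and a family of redundant
# overparameterized ReLU realizations that all compute the same table.
import numpy as np

inputs = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0)]
lookup_table = {x: int(x[0]) ^ int(x[1]) for x in inputs}

def xor_net(x, split=0.5):
    """XOR via ReLUs; `split` redistributes weight across two redundant hidden
    units without changing the computed function (any split in [0, 1] works)."""
    relu = lambda z: max(z, 0.0)
    s = x[0] + x[1]
    hidden = np.array([relu(s), relu(s), relu(s - 1.0)])
    w_out = np.array([split, 1.0 - split, -2.0])
    return float(hidden @ w_out)

for split in (0.5, 0.1, 0.9):                      # three different parameterizations...
    table = {x: int(round(xor_net(x, split))) for x in inputs}
    print(table == lookup_table)                   # ...all reproduce the same lookup table
```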