I wonder if grokking is evidence for, or against, the Mingard et al view that SGD on big neural nets is basically a faster approximation of rejection sampling. Here’s an argument that it’s evidence against:
--Either the “grokked algorithm circuit” is simpler than the “memorization circuit,” or it isn’t.
--If it’s simpler, then rejection sampling would reach the grokked algorithm circuit prior to reaching the memorization circuit, which is not what we see.
--If it’s not simpler, then rejection sampling might eventually stumble across the grokked algorithm circuit, but it would immediately return to the memorization circuit rather than staying there, which is also not what we see.
OTOH maybe Mingard could reply that indeed, for small neural nets like these, SGD is not merely an approximation of rejection sampling but rather meanders a lot, creating a situation where more complex circuits (the memorization ones) can have broader basins of attraction than simpler circuits (the grokked algorithm). But eventually SGD randomly jumps its way to the simpler circuit and then stays there. idk.
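For concreteness, here’s a toy version of the rejection-sampling picture I have in mind. The circuits and the prior weights below are invented purely for illustration, not measured from any actual network:

```python
import random

# Toy "parameter-function map": each candidate circuit gets a prior probability
# (its share of random-initialization volume) and a flag for whether it fits the
# training set. All numbers here are made up for illustration.
circuits = {
    "grokked_algorithm": {"prior": 0.010, "fits_training_set": True},   # simple, generalizes
    "memorization":      {"prior": 0.001, "fits_training_set": True},   # complex, overfits
    "junk":              {"prior": 0.989, "fits_training_set": False},  # everything else
}

def rejection_sample(circuits, rng):
    """Draw circuits from the prior until one fits the training data, then accept it."""
    names = list(circuits)
    weights = [circuits[n]["prior"] for n in names]
    while True:
        name = rng.choices(names, weights=weights, k=1)[0]
        if circuits[name]["fits_training_set"]:
            return name

rng = random.Random(0)
samples = [rejection_sample(circuits, rng) for _ in range(10_000)]
for name in ("grokked_algorithm", "memorization"):
    print(name, samples.count(name) / len(samples))
```

If the grokked circuit really does have the larger prior, rejection sampling lands on it almost every time and never passes through a memorization phase at all, which is the first horn of the dilemma above.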
I feel like everyone is taking the SGD = rejection sampling view way too seriously. From the Mingard et al paper:
We argue here that the inductive bias found in DNNs trained by SGD or related optimisers, is, to first order, determined by the parameter-function map of an untrained DNN. While on a log scale we find P_SGD(f|S) ≈ P_B(f|S) there are also measurable second order deviations that are sensitive to hyperparameter tuning and optimiser choice.
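For anyone who hasn’t read the paper, here’s my rough gloss of the two quantities in that quote (a paraphrase of the setup as I understand it, not the paper’s exact definitions):

```latex
% P(f): prior probability that a DNN with randomly sampled parameters expresses function f.
% P_B(f|S): that prior conditioned on fitting the training set S (the "rejection sampling" posterior):
P_B(f \mid S) = \frac{P(f)\,\mathbf{1}[f\text{ consistent with }S]}{\sum_{f'} P(f')\,\mathbf{1}[f'\text{ consistent with }S]}
% P_SGD(f|S): probability that SGD, trained on S from a random initialisation, ends up expressing f.
```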
The first order effect is what lets you conclude that when you ask GPT-3 a novel question it has never been trained on, like “how many bonks are in a quoit”, it won’t just start stringing characters together in a random way, but will probably respond with English words.
The second order effects could be what tells you whether or not it is going to respond with “there are three bonks in a quoit” or “that’s a nonsense question”. (Or maybe not! Maybe random sampling has a specific strong posterior there, and SGD does too! But it seems hard to know one way or the other.) Most alignment-relevant properties seem like they are in this class.
Grokking occurs in a weird special case where it seems there’s ~one answer that generalizes well and has a much higher prior, and everything else is orders of magnitude less likely. I don’t really see why you should expect results on MNIST to generalize to this situation.
Thanks! I’m not sure I understand your argument, but I think that’s my fault rather than yours, since tbh I don’t fully understand the Mingard et al paper itself, only its conclusion.