Thanks a lot for writing up this post! This felt much clearer and more compelling to me than the earlier versions I’d heard, and I broadly buy that this is a lot of what was going on with the phase transitions in my grokking work.
The algebra in the rank-1 learning section was pretty dense and not how I would have phrased it, so here’s my attempt to put it in my own language:
We want to fit some fixed rank-1 matrix $C = ab^T$ with two learned vectors $x, y$, forming $Z = xy^T$. Our objective function is $L = |C - Z|^2 = |C|^2 - 2(C \cdot Z) + |Z|^2$. Rank-one matrix facts: $|C|^2 = \sum_{i,j} a_i^2 b_j^2 = |a|^2 |b|^2$ (and likewise $|Z|^2 = |x|^2 |y|^2$), and $C \cdot Z = \sum_{i,j} a_i x_i b_j y_j = \sum_i a_i x_i \sum_j b_j y_j = (a \cdot x)(b \cdot y)$.
So, dropping the constant $|a|^2 |b|^2$, our loss function is now $L = (x \cdot x)(y \cdot y) - 2(a \cdot x)(b \cdot y)$. So what’s the derivative with respect to $x$? This is the same question as “what’s the best linear approximation to how this function changes when $x \to x + \epsilon$?” Here we can just directly read this off as $-\nabla_x L = 2((b \cdot y)a - (y \cdot y)x)$.
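As a sanity check on that gradient, here is a minimal NumPy sketch that compares the closed form against a central finite difference of the full loss $|C - Z|^2$. The dimension, seed, and function names are my own choices for illustration, not anything from the post.

```python
import numpy as np

def loss(x, y, a, b):
    """Squared Frobenius error between C = a b^T and Z = x y^T."""
    return np.sum((np.outer(a, b) - np.outer(x, y)) ** 2)

def neg_grad_x(x, y, a, b):
    """Closed form from the comment: -dL/dx = 2((b.y) a - (y.y) x)."""
    return 2.0 * ((b @ y) * a - (y @ y) * x)

rng = np.random.default_rng(0)
n = 5  # small illustrative dimension
a, b, x, y = (rng.normal(size=n) for _ in range(4))

# Central finite difference of -L along each coordinate of x.
eps = 1e-6
numeric = np.zeros(n)
for i in range(n):
    e = np.zeros(n)
    e[i] = eps
    numeric[i] = -(loss(x + e, y, a, b) - loss(x - e, y, a, b)) / (2 * eps)

max_err = np.max(np.abs(numeric - neg_grad_x(x, y, a, b)))
print("largest disagreement:", max_err)
```

Because the loss is quadratic in $x$ for fixed $y$, the central difference should agree with the closed form up to floating-point round-off.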
The second term, $(y \cdot y)x$, is an exponential decay term, assuming the size of $y$ is constant (in practice this is probably a good enough assumption). The first term, $(b \cdot y)a$, is the actual signal: it moves $x$ along the correct direction, but it is proportional to how well the other vector is doing, which starts bad and then improves. That creates the self-reinforcing dynamic that makes learning start slow and then speed up.
Another rephrasing: $x$ consists of a component in the correct direction ($a$), and the rest of $x$ is irrelevant. Ditto $y$. The components in the correct directions reinforce each other, and all components experience exponential-ish decay, because MSE loss wants everything not actively contributing to be small. At the start, the irrelevant components are way bigger (because they live in the 99-dimensional subspace orthogonal to $a$), and they rapidly decay, while the correct component slowly grows. This is a slight decrease in loss, but mostly a plateau. Then once the irrelevant component is small and the correct component has gotten bigger, the correct signal dominates. Eventually, the exponential decay is strong enough in the correct direction to balance out the incentive for further growth.
Generalising to higher-dimensional subspaces, the “correct” and “incorrect” components correspond to the restriction to the subspace spanned by the $a$ terms and to its complement, respectively. So long as the subspace is low rank, “irrelevant component is bigger, so it initially dominates” still holds.
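The rank-1 story above can be sketched numerically. This is a toy under my own assumptions (dimension 100, unit-norm targets, Gaussian initialisation with norm roughly 1, plain gradient descent with a hand-picked learning rate), not the setup used for the post's figures.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100  # assumed ambient dimension

# Fixed unit-norm targets a, b and randomly initialised learned vectors x, y.
a = rng.normal(size=n)
a /= np.linalg.norm(a)
b = rng.normal(size=n)
b /= np.linalg.norm(b)
x = rng.normal(size=n) / np.sqrt(n)
y = rng.normal(size=n) / np.sqrt(n)

lr, steps = 0.01, 5000  # illustrative hyperparameters
losses, correct, irrelevant = [], [], []
for _ in range(steps):
    losses.append(np.sum((np.outer(a, b) - np.outer(x, y)) ** 2))
    along = a @ x                                     # component of x along a
    correct.append(abs(along))
    irrelevant.append(np.linalg.norm(x - along * a))  # everything else in x
    # Negative gradients from the derivation above (and its mirror image for y).
    step_x = 2.0 * ((b @ y) * a - (y @ y) * x)
    step_y = 2.0 * ((a @ x) * b - (x @ x) * y)
    x, y = x + lr * step_x, y + lr * step_y

for t in (0, 100, 500, 1000, steps - 1):
    print(t, losses[t], correct[t], irrelevant[t])
```

Plotting `losses`, `correct`, and `irrelevant` against the step index is one way to look for the plateau-then-drop shape described above; exactly where the transition lands depends on the seed and the learning rate.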
My remaining questions—I’d love to hear takes:
The rank 2 case feels qualitatively different from the rank 1 case because there’s now a symmetry to break—will the first component of Z match the first or second component of C? Intuitively, breaking symmetries will create another S-shaped vibe, because the signal for getting close to the midpoint is high, while the signal to favour either specific component is lower.
What happens in a cross-entropy loss style setup, rather than MSE loss? IMO cross-entropy loss is a better analogue to real networks. Though I’m confused about the right way to model an internal sub-circuit of the model. I think the exponential decay term just isn’t there?
How does this interact with weight decay? This seems to give an intrinsic exponential decay to everything.
How does this interact with softmax? Intuitively, softmax feels “S-curve-ey”.
How does this interact with Adam? In particular, Adam gets super messy because you can’t just disentangle things. Even worse, how does it interact with AdamW?
I agree with both of your rephrasings and I think both add useful intuition!
Regarding rank 2, I don’t see any difference in behavior from rank 1 other than the “bump” in alignment that Lawrence mentioned. Here’s an example:
This doesn’t happen in all rank-2 cases but is relatively common. I think usually each vector grows primarily towards one or the other target. If two vectors grow towards the same target then you get this bump where one of them has to back off and align more towards a different target [at least that’s my current understanding, see my reply to Lawrence for more detail!].
“What happens in a cross-entropy loss style setup, rather than MSE loss?”
What does a cross-entropy setup look like here? I’m just not sure how to map this toy model onto that loss (or vice-versa).
“How does this interact with weight decay?”
Agreed! I expect weight decay to (1) make the converged solution not actually minimize the original loss (because the weight decay keeps tugging it towards lower norms) and (2) accelerate the initial decay. I don’t think I expect any other changes.
“How does this interact with softmax?”
I’m not sure! Do you have a setup in mind?
“How does this interact with Adam? Even worse, how does it interact with AdamW?”
I agree this breaks my theoretical intuition. Experimentally most of the phenomenology is the same, except that the full-rank (rank 100) case regains a plateau.
Here’s rank 2:
rank 10:
(maybe there’s more ‘bump’ formation here than with SGD?)
rank 100:
It kind of looks like the plateau has returned! And this replicates across every rank 100 example I tried, e.g.
The plateau corresponds to a period with a lot of bump formation. If bumps really are a sign of vectors competing to represent different chunks of subspace then maybe this says that Adam produces more such competition (maybe by making different vectors learn at more similar rates?).
I’d be curious if you have any intuition about this!
I caution against over-interpreting the results of single runs—I think there’s a good chance the number of bumps varies significantly by random seed.
It’s a good caution, but I do see more bumps with Adam than with SGD across a number of random initializations.
(with the caveat that this is still “I tried a few times” and not any quantitative study)
“What happens in a cross-entropy loss style setup, rather than MSE loss?”
There are lots of ways to do this, but the obvious way is to flatten C and Z and treat them as logits.
Something like this?
Well, I’d keep everything in log space and do the whole thing with log_sum_exp for numerical stability, but yeah.
EDIT: e.g. something like:
Erm do C and Z have to be valid normalized probabilities for this to work?
C needs to be probabilities, yeah. Z can be any vector of numbers. (You can convert C into probabilities with softmax)
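For concreteness, here is one way the flatten-and-treat-as-logits idea could look in NumPy. This is my own sketch, not the code either commenter posted: `log_softmax` and `offset_cross_entropy` are hypothetical helper names, everything stays in log space for stability, and the target entropy is subtracted so that the loss is zero when softmax(Z) equals softmax(C).

```python
import numpy as np

def log_softmax(v):
    """Numerically stable log-softmax of a 1-D array (log-sum-exp trick)."""
    shifted = v - np.max(v)
    return shifted - np.log(np.sum(np.exp(shifted)))

def offset_cross_entropy(C, Z):
    """Cross-entropy of softmax(Z) against softmax(C), minus the entropy of the target.

    Both matrices are flattened and treated as logits. Subtracting the target
    entropy gives the KL divergence, which is zero when the distributions match.
    """
    log_p = log_softmax(C.ravel())  # target log-probabilities
    log_q = log_softmax(Z.ravel())  # model log-probabilities
    return float(np.sum(np.exp(log_p) * (log_p - log_q)))

rng = np.random.default_rng(0)
a, b, x, y = (rng.normal(size=4) for _ in range(4))
C, Z = np.outer(a, b), np.outer(x, y)

loss_random = offset_cross_entropy(C, Z)         # generic Z: positive
loss_match = offset_cross_entropy(C, C)          # Z == C: zero
loss_shifted = offset_cross_entropy(C, C + 3.0)  # constant shift of logits: still zero
print(loss_random, loss_match, loss_shifted)
```

One consequence of this choice is that the loss only sees Z up to an additive constant, so Z = C is one of many zero-loss solutions in logit space.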
So indeed with cross-entropy loss I see two plateaus! Here’s rank 2:
(note that I’ve offset the loss so that Z = C gives zero loss)
I have trouble getting rank 10 to find the zero-loss solution:
But the phenomenology at full rank is unchanged:
(Adam Jermyn ninja’ed my rank 2 results as I forgot to refresh, lol)
Weight decay just means the gradient becomes $-\nabla_x L = 2(\langle b, y\rangle a - \langle y, y\rangle x) - \lambda x$, which effectively “extends” the exponential phase. It’s pretty easy to confirm that this is the case:
You can see the other figures from the main post here:
https://imgchest.com/p/9p4nl6vb7nq
(Lighter color shows loss curve for each of 10 random seeds.)
Here’s my code for the weight decay experiments if anyone wants to play with them or check that I didn’t mess something up: https://gist.github.com/Chanlaw/e8c286629e0626f723a20cef027665d1
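Separately from the gist, here is a small self-contained sketch of the weight-decay update from this comment, using the same kind of toy rank-1 setup (dimension 100, unit-norm targets, plain gradient descent). The function name `final_loss`, the hyperparameters, and the value of $\lambda$ are all my own illustrative choices.

```python
import numpy as np

def final_loss(weight_decay, n=100, lr=0.01, steps=5000, seed=0):
    """Train x, y on the rank-1 MSE toy problem and return the final *original* loss."""
    rng = np.random.default_rng(seed)
    a = rng.normal(size=n)
    a /= np.linalg.norm(a)
    b = rng.normal(size=n)
    b /= np.linalg.norm(b)
    x = rng.normal(size=n) / np.sqrt(n)
    y = rng.normal(size=n) / np.sqrt(n)
    for _ in range(steps):
        # -grad of the MSE loss, plus the extra -lambda * (parameter) decay term.
        step_x = 2.0 * ((b @ y) * a - (y @ y) * x) - weight_decay * x
        step_y = 2.0 * ((a @ x) * b - (x @ x) * y) - weight_decay * y
        x, y = x + lr * step_x, y + lr * step_y
    # Report the loss without the weight-decay penalty.
    return float(np.sum((np.outer(a, b) - np.outer(x, y)) ** 2))

lam = 0.1  # illustrative decay strength
loss_plain = final_loss(0.0)
loss_decay = final_loss(lam)
print(loss_plain, loss_decay, lam ** 2 / 4)
```

For unit-norm $a$ and $b$, setting this update to zero at $x = ca$, $y = cb$ gives $c^2 = 1 - \lambda/2$, so the converged original loss is $(1 - c^2)^2 = \lambda^2/4$ rather than zero. That matches the expectation upthread that weight decay keeps the solution from exactly minimizing the original loss.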
Should be trivial to modify my code to use AdamW, just replace SGD with Adam on line 33.
EDIT: ran the experiments for rank 1, they seem a bit different than Adam Jermyn’s results. It looks like AdamW just accelerates things?
Woah, nice! Note that I didn’t check rank 1 with Adam, just rank >= 2.