This is a really cool toy model, and it's also consistent with Neel Nanda's Modular Addition grokking work.
Do you know what’s up with the bump on the Inner Product w/Truth figures? The same bumps occur consistently for many metrics on several toy tasks, including in the Modular Addition grokking work.
EDIT: if anyone wants to play with the results in this paper, here’s a gist I whipped up:
https://gist.github.com/Chanlaw/e8c286629e0626f723a20cef027665d1
I don’t, but here’s my best guess: there’s a sense in which the learned vectors compete over which parts of the target span each of them captures.
As a toy example, suppose there are two learned vectors, a1 and a2, whose closest target vector at initialization is the same vector c. Then both vectors might grow towards c. At some point c is represented well enough in the span that it’s no longer optimal for both of them to play the role of representing c, so it becomes optimal for at least one of them to shift to cover other target vectors more.
For example, from a rank-4 case with a bump, here’s the inner product of two learned vectors with a single target vector:
So both vectors grow towards a single target, and the blue one starts realigning towards a different target as the orange one catches up.
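To make the competition story concrete, here’s a minimal sketch of the kind of dynamics I have in mind. Everything here is my assumption rather than the paper’s actual setup: I take the toy model to be gradient descent on ||AAᵀ − CCᵀ||², where the columns of A are the learned vectors and the columns of C are orthonormal targets, and the dimensions and hyperparameters are made up:

```python
import numpy as np

rng = np.random.default_rng(0)
d, r = 32, 4                                  # ambient dimension, rank (both assumed)

# Assumed setup: r orthonormal target vectors as the columns of C
C = np.linalg.qr(rng.normal(size=(d, r)))[0]  # d x r with orthonormal columns
T = C @ C.T                                   # target ("truth") projector

A = 0.01 * rng.normal(size=(d, r))            # small init; learned vectors are columns
lr, steps = 0.05, 3000                        # assumed hyperparameters
history = []

for _ in range(steps):
    R = A @ A.T - T                           # residual of the rank-r approximation
    A -= lr * (4 * R @ A)                     # gradient of ||A A^T - T||_F^2 w.r.t. A
    A_hat = A / np.linalg.norm(A, axis=0, keepdims=True)
    history.append(np.abs(A_hat.T @ C))       # r x r matrix of |<a_i, c_j>|

history = np.array(history)                   # shape (steps, r, r)
```

Plotting history[:, i, j] over training should, for at least some seeds, show two learned vectors rising against the same target and then one of them dipping and realigning, which is exactly the bump shape above.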
Two more weak pieces of evidence in favor of this story:
We only ever see this bump when the rank is greater than 1.
From visual inspection, bumps are more likely to peak at higher levels of alignment than at lower levels, and they don’t happen at all in the initial norm-decay phase, suggesting the bump is associated with vectors growing (rather than decaying).
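For what it’s worth, here’s a hypothetical way to check the “peaks at higher alignment” observation programmatically rather than by eye. It assumes alignment curves like the history array from the sketch above, and the min_drop threshold is arbitrary:

```python
def bump_peaks(curve, min_drop=0.05):
    """Return (step, height) for local maxima that are later followed by a
    drop of at least min_drop, i.e. candidate bumps rather than monotone growth."""
    peaks = []
    for t in range(1, len(curve) - 1):
        if curve[t] > curve[t - 1] and curve[t] >= curve[t + 1]:
            if curve[t] - min(curve[t:]) > min_drop:
                peaks.append((t, float(curve[t])))
    return peaks

# e.g., scan every (learned vector, target vector) alignment curve:
# for i in range(r):
#     for j in range(r):
#         print(i, j, bump_peaks(history[:, i, j]))
```

The min_drop threshold is just to separate a genuine rise-then-fall from noise or ordinary monotone convergence.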
Oh, huh, that makes a lot of sense! I’ll see if I can reproduce these results.
I’m not sure this explains the grokking bumps from the mod add stuff, since I’m not sure what the “competition” would be given that we see the bumps on every key frequency.
I’d be very excited to see a reproduction :-)