Interesting, but not (I think?) the direction I was headed in.
I was thinking more about the way the model seems to be managing a tradeoff between preserving the representation of token i and producing the representation of token i+1.
The depth-wise continuity imposed by weight decay means late layers are representing something close to the final output—in late layers the model is roughly looking at its own guesses, even if they were wrong, which seems suboptimal.
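To make “looking at its own guesses” concrete, here’s a rough sketch (assuming a Hugging Face GPT-2; the prompt and the layer indices are arbitrary choices of mine) that decodes a few late-layer residual states with the final layer norm and unembedding and prints the top next-token guesses at the last position:

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

ids = tok("The capital of France is", return_tensors="pt").input_ids

with torch.no_grad():
    out = model(ids, output_hidden_states=True)

# out.hidden_states has 13 entries: the embedding output, then the residual
# stream after each of the 12 blocks (the last entry already has ln_f applied).
for layer in (6, 9, 11):
    h = out.hidden_states[layer]                       # [1, seq_len, d_model]
    logits = model.lm_head(model.transformer.ln_f(h))  # decode with the final unembedding
    top = logits[0, -1].topk(3).indices.tolist()
    print(f"after block {layer}: {[tok.decode([t]) for t in top]}")
```

Typically the late-layer decodes are already close to the final output distribution, which is the sense in which a late layer at position i is mostly “about” the guess for i+1 rather than about the token actually sitting at i.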
Consider this scenario:
The model does poorly at position i, assigning very low probability to the true token residing at i+1.
To retain a clear view of the input sequence, the model now needs to “keep around” the true token at i+1, since its own guess is a poor proxy.
But early layers don’t know that: they can’t “look up” and notice the poor prediction. So they just treat i+1 like any other position. (I.e., there’s no way to implement a selective “copy when we got it wrong” mechanism.)
In late layers, position i+1 has been converted into a guess about i+2 by the earlier layers, so we can’t rely on it to tell us what really occupied i+1.
And position i has been converted to a bad guess about position i+1, so if we use it as a proxy for i+1 we’ll do poorly.
My sampling idea was something like “let’s replace (or interpolate) late activations with embeddings of the actual next token, so the model can see what really happened, even when its probability was low.” (This is for sampling specifically because it’d be too slow in training, where you want to process a whole window at once with matrix operations; sampling has to be a loop anyway, so there’s no cost to adding stuff that only works as a loop.)
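Concretely, something like this minimal sketch (assuming a Hugging Face GPT-2; LATE_LAYER, ALPHA, and the prompt are knobs I’m making up, and the loop naively re-runs the full forward pass every step instead of using a KV cache, which is exactly the kind of thing that’s only tolerable at sampling time):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

ALPHA = 0.5      # interpolation weight between the late activation and the true-token embedding
LATE_LAYER = 10  # which block's input to modify

ids = tok("The reactor vented and", return_tensors="pt").input_ids

def blend_in_true_tokens(module, args):
    # args[0] is the residual stream entering this block: [1, seq_len, d_model].
    h = args[0].clone()
    seq = h.shape[1]
    if seq > 1:
        # For every position before the last, the token that actually ended up at the
        # next position is already known (from the prompt or from earlier sampling steps),
        # so mix its embedding into the late activation at this position.
        true_next = model.transformer.wte(ids[0, 1:seq])             # [seq_len - 1, d_model]
        h[0, : seq - 1] = (1 - ALPHA) * h[0, : seq - 1] + ALPHA * true_next
    return (h,) + args[1:]

handle = model.transformer.h[LATE_LAYER].register_forward_pre_hook(blend_in_true_tokens)

with torch.no_grad():
    for _ in range(20):
        logits = model(ids).logits[0, -1]
        nxt = torch.multinomial(torch.softmax(logits, dim=-1), 1)
        ids = torch.cat([ids, nxt.view(1, 1)], dim=1)

handle.remove()
print(tok.decode(ids[0]))
```

One thing in its favor: GPT-2 ties the input embedding matrix to the unembedding, so the rows of wte are at least roughly in the right “vocabulary basis” for late layers; whether raw embeddings are the right thing to mix in (as opposed to rescaling them, or intervening at a different layer) is exactly the sort of detail I’d have to fiddle with.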
But, thinking about it more, the model clearly can perform well in scenarios like the one above, e.g. my plasma example and many other cases that arise naturally in language and that GPT handles well.
I have no idea how it does it—indeed the connection structure feels weirdly adverse to such operations—but apparently it does. So it’s probably premature to assume it can’t do this well, and attempt to “help it out” with extra tricks.