I think this might suggest there is some fundamentally better way to do sampling from GPT models? I’m having trouble writing out the intuition clearly, so I’ll leave it for later posts.
Unroll the sampling process: hook up all the individual GPT instances into a single long model, bypass the discretizing/embedding layers to make it differentiable end-to-end, and do gradient ascent to find the sequence which maximizes likelihood conditional on the fixed input.
Interesting, but not (I think?) the direction I was headed in.
I was thinking more about the way the model seems to be managing a tradeoff between preserving the representation of token i and producing the representation of token i+1.
The depth-wise continuity imposed by weight decay means late layers are representing something close to the final output—in late layers the model is roughly looking at its own guesses, even if they were wrong, which seems suboptimal.
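(To make “looking at its own guesses” concrete: you can decode a late layer’s hidden states through the model’s own final layer norm and unembedding and read off the per-position argmax. A minimal sketch, assuming a HuggingFace GPT-2; the prompt and the choice of layers are arbitrary placeholders:)

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

text = "The sun is a ball of plasma held together by gravity"  # arbitrary example prompt
ids = tok(text, return_tensors="pt").input_ids

with torch.no_grad():
    out = model(ids, output_hidden_states=True)

# out.hidden_states = (embedding output, block 1, ..., block 12)
for layer in (6, 9, 12):                                # arbitrary layer choices
    h = out.hidden_states[layer]                        # (1, T, d_model)
    logits = model.lm_head(model.transformer.ln_f(h))   # decode with the model's own unembedding
    guesses = logits.argmax(-1)[0]                      # top "guess" at each position
    print(layer, tok.decode(guesses))
```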
Consider this scenario:
The model does poorly at position i, assigning very low probability to the true token residing at i+1.
To retain a clear view of the input sequence, the model now needs to “keep around” the true token at i+1, since its own guess is a poor proxy.
But early layers don’t know that: they can’t “look up” and notice the poor prediction. So they just treat i+1 like any other position. (I.e., there’s no way to implement a selective “copy when we got it wrong” mechanism.)
In late layers, position i+1 has been converted into a guess about i+2 by the earlier layers, so we can’t rely on it to tell us what really occupied i+1.
And position i has been converted to a bad guess about position i+1, so if we use it as a proxy for i+1 we’ll do poorly.
My sampling idea was something like “let’s replace (or interpolate) late activations with embeddings of the actual next token, so the model can see what really happened, even when its probability was low.” (This is for sampling specifically because it’d be too slow in training, where you want to process a whole window at once with matrix operations; sampling has to be a loop anyway, so there’s no cost to adding stuff that only works as a loop.)
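A rough sketch of what that might look like in a sampling loop, assuming a HuggingFace GPT-2: a forward hook on one late block interpolates its hidden states, at every position where the true next token is already known, with the embedding of that token. The layer index, interpolation weight, and the decision to skip the KV cache and re-run the full prefix each step are placeholder choices, not a worked-out design:

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

ALPHA = 0.5        # interpolation weight -- arbitrary placeholder
LATE_LAYER = -2    # which block's output to modify -- arbitrary placeholder
wte = model.transformer.wte  # token embedding table

def make_hook(ids):
    # ids: (1, T). At positions 0..T-2 the actual next token is already known
    # (it's ids[:, 1:]), so blend its embedding into the late hidden state there.
    def hook(module, inputs, output):
        hidden = output[0] if isinstance(output, tuple) else output
        hidden = hidden.clone()
        next_emb = wte(ids[:, 1:])                           # (1, T-1, d_model)
        hidden[:, :-1, :] = (1 - ALPHA) * hidden[:, :-1, :] + ALPHA * next_emb
        return (hidden,) + output[1:] if isinstance(output, tuple) else hidden
    return hook

@torch.no_grad()
def sample(prompt, steps=20):
    ids = tok(prompt, return_tensors="pt").input_ids
    for _ in range(steps):
        # no KV cache: re-run the whole prefix each step so the hook can touch every position
        handle = model.transformer.h[LATE_LAYER].register_forward_hook(make_hook(ids))
        logits = model(input_ids=ids).logits[:, -1, :]
        handle.remove()
        next_id = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
        ids = torch.cat([ids, next_id], dim=-1)
    return tok.decode(ids[0])
```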
But, thinking about it more, the model clearly can perform well in scenarios like the above, e.g. my plasma example and also many other cases naturally arising in language which GPT handles well.
I have no idea how it does it—indeed the connection structure feels weirdly adverse to such operations—but apparently it does. So it’s probably premature to assume it can’t do this well, and attempt to “help it out” with extra tricks.
How far away is this from being implementable?
It doesn’t sound hard at all. The things Gwern is describing are the same sort of thing people do for interpretability, e.g. finding an image that maximizes the probability of the network predicting a target class.
Of course, you need access to the model, so only OpenAI could do it for GPT-3 right now.
Doing it with GPT-3 would be quite challenging just in terms of compute requirements (e.g. RAM). You’d want to test this out on GPT-2-117M first, definitely. If the approach works at all, it should work well for the smallest models too.
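For concreteness, here’s a minimal sketch of what Gwern’s “unroll and do gradient ascent” idea might look like on the small GPT-2 via HuggingFace transformers: the continuation is relaxed to a softmax over the vocabulary at each slot, fed in through inputs_embeds so everything stays differentiable, and optimized to maximize the likelihood the model assigns to it given the fixed prompt. The prompt text, number of slots, learning rate, and step count are all arbitrary, and it glosses over how (or whether) the relaxed optimum discretizes back into real tokens:

```python
import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")            # the smallest GPT-2
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
for p in model.parameters():
    p.requires_grad_(False)

prompt_ids = tok("The meaning of life is", return_tensors="pt").input_ids  # arbitrary prompt
prompt_embeds = model.transformer.wte(prompt_ids)      # (1, T, d_model), held fixed
wte = model.transformer.wte.weight                     # (V, d_model)

K = 8                                                  # number of soft continuation slots -- arbitrary
theta = torch.zeros(K, wte.shape[0], requires_grad=True)   # per-slot logits over the vocab
opt = torch.optim.Adam([theta], lr=0.1)                # arbitrary optimizer settings

for step in range(200):
    soft = F.softmax(theta, dim=-1)                    # (K, V) relaxed token choices
    soft_embeds = (soft @ wte).unsqueeze(0)            # (1, K, d) expected embeddings
    inputs = torch.cat([prompt_embeds, soft_embeds], dim=1)
    logits = model(inputs_embeds=inputs).logits        # (1, T+K, V)
    # log-probs the model assigns at the positions that predict each soft slot
    logp = F.log_softmax(logits[0, prompt_ids.shape[1] - 1 : -1], dim=-1)   # (K, V)
    loss = -(soft * logp).sum()                        # maximize expected log-likelihood
    opt.zero_grad()
    loss.backward()
    opt.step()

print(tok.decode(theta.argmax(-1).tolist()))           # crude discretization of the result
```

The argmax decode at the end is the crudest possible discretization; the soft optimum may not sit near any actual token sequence, which is one obvious failure mode of this kind of unrolling.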