The transformer is only trained explicitly on next-token prediction!
I find myself understanding language/multimodal transformer capabilities better when I think of the whole document (up to the context length) as a mini-batch for calculating the gradient in transformer (pre-)training. On that view, the model is minimizing document-global prediction error; it wasn’t trained to optimize the accuracy of just a single next token...
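To make the “document as a mini-batch” picture concrete, here is a minimal sketch (assuming PyTorch and a decoder-only model; the function and variable names are mine, purely for illustration). The loss over one document is an average of per-position next-token cross-entropy terms:

    import torch.nn.functional as F

    def document_loss(logits, tokens):
        # logits: (seq_len, vocab_size) for one document; tokens: (seq_len,)
        # Each position t contributes only its own next-token cross-entropy;
        # the "document-global" loss is just the mean of these local terms.
        per_position = F.cross_entropy(logits[:-1], tokens[1:], reduction="none")
        return per_position.mean()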
There is evidence that transformers are not in fact even implicitly, internally, optimized for reducing global prediction error (except insofar as computational mechanics (comp-mech) says they must be in order to do well on the task they are optimized for).
Do transformers “think ahead” during inference at a given position? It is known transformers prepare information in the hidden states of the forward pass at t that is then used in future forward passes t+τ. We posit two explanations for this phenomenon: pre-caching, in which off-diagonal gradient terms present in training result in the model computing features at t irrelevant to the present inference task but useful for the future, and breadcrumbs, in which features most relevant to time step t are already the same as those that would most benefit inference at time t+τ. We test these hypotheses by training language models without propagating gradients to past timesteps, a scheme we formalize as myopic training. In a synthetic data setting, we find clear evidence for pre-caching. In the autoregressive language modeling setting, our experiments are more suggestive of the breadcrumbs hypothesis.
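For intuition about what “myopic training” looks like mechanically, here is a very rough stop-gradient sketch (my own illustration, not the paper’s actual scheme; a faithful version would keep gradients flowing through the current position’s own key/value, which this coarse version also cuts):

    import torch

    def myopic_causal_attention(q, k, v):
        # q, k, v: (seq_len, d_head). Standard causal attention, except the
        # keys/values are detached, so the loss at position t cannot shape
        # what earlier positions wrote into the residual stream.
        d = q.shape[-1]
        scores = q @ k.detach().transpose(-1, -2) / d ** 0.5
        mask = torch.triu(torch.ones_like(scores, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
        return scores.softmax(dim=-1) @ v.detach()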
I think that paper is some evidence that there’s typically no huge effect from internal activations being optimized for predicting future tokens (on natural language). But I don’t think it’s much (if any) evidence that this doesn’t happen to some small extent or that it couldn’t be a huge effect on certain other natural language tasks.
(In fact, I think the myopia gap is probably the more relevant number than the local myopia bonus, in which case I’d argue the paper actually shows a pretty non-trivial effect, kind of contrary to how the authors interpret it. But I haven’t read the paper super closely.)
Also, sounds like you’re aware of this, but I’d want to highlight more that the paper does demonstrate internal activations being optimized for predicting future tokens on synthetic data where this is necessary. So, arguably, the main question is to what extent natural language data incentivizes this rather than being specifically about what transformers can/tend to do.
In that sense, thinking of transformer internals as “trying to” minimize the loss on an entire document might be exactly the right intuition empirically (and the question is mainly how different that is from being myopic on a given dataset). Given that the internal states are optimized for this, that would also make sense theoretically IMO.
+1 to this comment. Also, I expect the importance of activations being optimized for predicting future tokens to increase considerably with scale. (E.g., GPT-4-level compute maybe just gets you a GPT-3-level model if you enforce no such optimization with a stop-grad.)
I have tried to play with Claude: I would ask it to think of a number, drop a hint, and only then print the number. This should have tested its ability to keep “hidden memory” that’s outside the text.
I expected it to be able to do that, but for the hints to be too obvious. Instead, it actually failed multiple times in a row!
Sharing because I liked the experiment but wasn’t sure whether I executed it properly. There might be a way to do more of this.
P.S. I have also tried “print hash, and then preimage”, but this turned out to be even harder for it.
Post the chat logs?
That’s an interesting framing. From my perspective that is still just local next-token accuracy (cross-entropy, more precisely), but averaged over all subsets of the data up to the context length. That is distinct from, e.g., an objective function that explicitly mentions not just the next token but multiple future tokens in what is needed to minimize the loss. Does that distinction make sense?
One conceptual point I’d like to get across is that even though the equation for the predictive cross-entropy loss only has the next token at a given context window position in it, the states internal to the transformer have the information for predictions into the infinite future.
This is a slightly different issue than how one averages over training data, I think.
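A sketch of the contrast I mean (the names and shapes here are hypothetical, just to make the two objectives explicit): the standard loss at position t only ever scores the distribution over token t+1, while a multi-token objective would also score explicit predictions of t+2, t+3, and so on:

    import torch.nn.functional as F

    def next_token_loss(logits, tokens):
        # logits: (seq_len, vocab). Position t is scored only on token t+1.
        return F.cross_entropy(logits[:-1], tokens[1:])

    def multi_token_loss(logits_per_offset, tokens, horizon=3):
        # logits_per_offset[k]: (seq_len, vocab), the model's explicit guesses
        # for the token k+1 steps ahead. This objective mentions future tokens
        # directly; the standard one above never does.
        losses = []
        for k in range(horizon):
            targets = tokens[1 + k:]
            losses.append(F.cross_entropy(logits_per_offset[k][: len(targets)], targets))
        return sum(losses) / horizon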
To me, as a programmer and not a mathematician, the distinction doesn’t make practical intuitive sense.
If we can create three functions f, g, h so that they “do the same thing”, like f(a, b, c) == g(a)(b)(c) == average(h(a), h(b), h(c)), it seems to me that cross-entropy can “do the same thing” as some particular objective function that would explicitly mention multiple future tokens.
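A toy version of that equivalence, just to make the analogy concrete (the functions themselves are arbitrary; only the pattern matters):

    def h(x):
        return 2 * x

    def f(a, b, c):            # "global" form: one function of the whole input
        return (2 * a + 2 * b + 2 * c) / 3

    def g(a):                  # curried form
        return lambda b: lambda c: f(a, b, c)

    def average(*xs):
        return sum(xs) / len(xs)

    assert f(1, 2, 3) == g(1)(2)(3) == average(h(1), h(2), h(3))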
My intuition is that cross-entropy-powered “local accuracy” can approximate “global accuracy” well enough in practice that I should expect better global reasoning from larger model sizes, faster compute, algorithmic improvements, and better data.
Implications of this intuition might be:
myopia is a quantity, not a quality: a model can be incentivized to be more or less myopic, but I don’t expect it will be proven possible to enforce it “in the limit”
instruct training on longer conversations ought to produce “better” overall conversations if the model simulates that it’s “in the middle” of a conversation, where follow-up questions are better than giving a final answer “when close to the end of this kind of conversation”
What nuance should I consider to understand the distinction better?