Based on my incomplete understanding of transformers:
A transformer does its computation on the entire sequence of tokens at once, and ends up predicting the next token for each token in the sequence.
At each layer, the attention mechanism gives the stream for each token the ability to look at the previous layer’s output for the tokens before it in the sequence (the first sketch after this list shows the causal mask that enforces this).
The stream for each token doesn’t know if it’s the last in the sequence (and thus that its next-token prediction is the “main” prediction), or anything about the tokens that come after it.
So each token’s stream has two tasks in training: predict the next token, and generate the information that later tokens will use to predict their next tokens (the second sketch below shows how a single loss over all positions creates both pressures).
That information could take many different forms, but in some cases it could look like a “plan” (a prediction about the large-scale structure of the piece of writing that continues from the sequence this stream has observed so far, from that token-stream’s point of view).
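To make the first two points concrete, here is a minimal sketch in PyTorch of a single attention step, assuming a toy single-head layer with made-up sizes (nothing here is the architecture of any particular model). The whole sequence is processed in one pass, the causal mask keeps each position from attending to later positions, and the output is one next-token distribution per position.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

vocab_size, d_model, seq_len = 100, 32, 6                    # illustrative sizes

# Embed a batch containing one sequence of token ids.
tokens = torch.randint(0, vocab_size, (1, seq_len))          # (1, T)
embed = torch.nn.Embedding(vocab_size, d_model)
x = embed(tokens)                                            # (1, T, d_model)

# One attention head (projections initialized arbitrarily for the sketch).
Wq = torch.nn.Linear(d_model, d_model, bias=False)
Wk = torch.nn.Linear(d_model, d_model, bias=False)
Wv = torch.nn.Linear(d_model, d_model, bias=False)

q, k, v = Wq(x), Wk(x), Wv(x)
scores = q @ k.transpose(-2, -1) / d_model**0.5              # (1, T, T)

# Causal mask: position i may only attend to positions j <= i,
# so no stream sees anything about the tokens after it.
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
scores = scores.masked_fill(~causal, float("-inf"))
attn = F.softmax(scores, dim=-1)                             # each row sums to 1 over the past only
x = x + attn @ v                                             # residual-stream update

# Unembed: one next-token distribution per position, all computed at once.
unembed = torch.nn.Linear(d_model, vocab_size, bias=False)
logits = unembed(x)                                          # (1, T, vocab_size)
print(logits.shape)                                          # torch.Size([1, 6, 100])
```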
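And a second sketch of the training objective, again with illustrative shapes and a stand-in `logits` tensor rather than a full model: the targets are the inputs shifted left by one, and a single cross-entropy loss is taken over every position. Each position’s prediction gets its own loss term; in a real transformer, later positions’ loss terms also backpropagate through attention into earlier positions, which is where the second task comes from.

```python
import torch
import torch.nn.functional as F

vocab_size, seq_len = 100, 6
tokens = torch.randint(0, vocab_size, (1, seq_len))                # (1, T)
logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)   # stand-in for model output

# Targets are the inputs shifted left by one: position i predicts token i+1.
pred = logits[:, :-1, :]                                           # predictions for positions 0..T-2
target = tokens[:, 1:]                                             # the tokens that actually follow

loss = F.cross_entropy(pred.reshape(-1, vocab_size), target.reshape(-1))
loss.backward()                                                    # gradients reach every position at once
print(loss.item())
```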