The main question I have been thinking about is what is a state for language and how that can be useful if so discovered in this way?
The way I would put it is that ‘state’ is misleading you here. It makes you think that it must be some sort of little Turing machine or clockwork, where it has a ‘state’, like the current state of the Turing machine tape or the rotations of each gear in a clockwork gadget, and the goal is to infer that. That the two coincide is an accident of these toy problems, which are so simple that there is nothing to know beyond the actual state.
As Ortega et al. highlight in those graphs, what you are really trying to define is the sufficient statistics: the summary of the data (history) which is 100% adequate for decision making, and where additionally knowing the original raw data doesn’t help you.
In the coin flip case, the sufficient statistics are simply the 2-tuple (heads,tails), and you define a very simple decision over all of the possible observed 2-tuples. Note that the sufficient statistic is less information than the original raw history, because you throw out the ordering. (A 2-tuple like ‘(3,1)’ is simpler than all of the histories it summarizes, like ‘[1,1,1,0]’, ‘[0,1,1,1]’, ‘[1,0,1,1]’, etc.) From the point of view of decision making, these all yield the same posterior distribution over the coin flip probability parameter, which is all you need for decision making (optimal action: ‘bet on the side with the higher probability’), and so that’s the sufficient statistic. If I tell you the history as a list instead of a 2-tuple, you cannot make better decisions. It just doesn’t matter if you got a tails first and then all heads, or all heads first then tails, etc.
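To make the ‘ordering doesn’t matter’ point concrete, here is a minimal sketch. It assumes a uniform Beta(1,1) prior over the coin’s heads probability, which is my choice for illustration and not something the toy problem dictates, and the function names are made up for this example. It updates the posterior one flip at a time, in order, for every distinct ordering of three heads and one tails, and shows that they all land on the same 2-tuple, the same posterior, and the same bet.

```python
from fractions import Fraction
from itertools import permutations


def sufficient_statistic(history):
    """Collapse a history of flips (1 = heads, 0 = tails) into (heads, tails)."""
    heads = sum(history)
    return heads, len(history) - heads


def posterior_after_sequential_updates(history, prior=(1, 1)):
    """Update a Beta(a, b) prior one flip at a time, in the order observed.

    The uniform Beta(1, 1) default is an assumption made for illustration.
    """
    a, b = prior
    for flip in history:
        if flip == 1:
            a += 1  # one more head observed
        else:
            b += 1  # one more tail observed
    return a, b


def posterior_mean_heads(a, b):
    """Mean of Beta(a, b): the posterior probability that the next flip is heads."""
    return Fraction(a, a + b)


def best_bet(a, b):
    """Bet on whichever side the posterior currently favors."""
    mean = posterior_mean_heads(a, b)
    if mean > Fraction(1, 2):
        return "heads"
    if mean < Fraction(1, 2):
        return "tails"
    return "either"


# Every distinct ordering of three heads and one tails.
histories = sorted(set(permutations([1, 1, 1, 0])))

for history in histories:
    a, b = posterior_after_sequential_updates(history)
    print(history, sufficient_statistic(history), (a, b),
          posterior_mean_heads(a, b), best_bet(a, b))
```

Each of the four orderings prints the same statistic, the same Beta parameters, and the same bet, so handing the agent the full list instead of the 2-tuple buys it nothing.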
It is not obvious that this is true: a priori, maybe the ordering was hugely important, and those different orderings correspond to different games. But the RNN there has learned that the differences are not important, and that for its purposes they are all the same.
And the 2-tuple here doesn’t correspond to any particular environment ‘state’. The environment doesn’t need to store that anywhere. The environment is just an RNG operating according to the coin flip probability, independently every turn of the game, with no memory. Nothing in the environment is counting heads & tails in a 2-tuple. That count exists solely in the RNN’s hidden state as it accumulates evidence over turns, optimally updates priors to posteriors on every observed coin flip, and possibly switches its bet.
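A rough sketch of that division of labor, with a hand-written counter standing in for the RNN’s hidden state. The class names, the 0.8 bias, and the seed are all hypothetical choices for illustration. The thing to notice is where the tally lives: the environment object holds only its bias and a random number generator, while the heads/tails counts exist only inside the agent.

```python
import random


class CoinEnvironment:
    """Holds only a fixed bias and an RNG; it never tallies past outcomes."""

    def __init__(self, p_heads, seed=0):
        self.p_heads = p_heads
        self._rng = random.Random(seed)

    def flip(self):
        # Each flip is drawn independently of everything that came before.
        return 1 if self._rng.random() < self.p_heads else 0


class CountingAgent:
    """Stand-in for the RNN hidden state: just the (heads, tails) counts."""

    def __init__(self):
        self.heads = 0
        self.tails = 0

    def bet(self):
        # Bet on the side seen more often so far (ties go to heads).
        return 1 if self.heads >= self.tails else 0

    def observe(self, flip):
        if flip == 1:
            self.heads += 1
        else:
            self.tails += 1


def play(env, agent, turns):
    """Run the betting game and return how many bets were correct."""
    wins = 0
    for _ in range(turns):
        guess = agent.bet()
        flip = env.flip()
        wins += int(guess == flip)
        agent.observe(flip)
    return wins


env = CoinEnvironment(p_heads=0.8, seed=1)  # bias and seed are arbitrary
agent = CountingAgent()
wins = play(env, agent, turns=200)

print("agent's 2-tuple:", (agent.heads, agent.tails))
print("correct bets:", wins)
print("environment attributes:", sorted(vars(env)))
```

The last line lists everything the environment object stores, and there is no count among its attributes: the 2-tuple is the agent’s bookkeeping, not a fact the environment keeps.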
So LLMs doing language tasks are doing the same thing, but on a vastly grander scale, and still highly incompletely. They are (trying to) infer sufficient statistics of whatever language-games they have been trained on, and then predicting accordingly.
What are those sufficient statistics in LLMs? Hard to say. In that coinflip example, it is so simple that we can easily derive by hand the conjugate statistics and know it is just a binomial and so we only need to track heads/tails as the one and only sufficient statistic, and we can then look in the hidden state to find where that is encoded in a converged optimal agent. In LLMs… not so much. There’s a lot going on.
Based on interpretability research and studies of how well they simulate people as well as just all of the anecdotal experience with the base models, we can point to a few latents like honesty, calibration, demographics, and so on. (See Janus’s “Simulator Theory” for a more poetic take, less focused on agency than the straight Bayesian meta-imitation learning take I’m giving here.) Meanwhile, there are tons of things about the inputs that the model wants to throw away, irrelevant details like the exact misspellings of words in the prompt (while recording that there were misspellings, as grist for the inference mill about the environment generating the misspelled text).
So conceptually, the sufficient statistics when you or I punch in a prompt to GPT-3 might look like some extremely long list of variables like, “English speaker, Millennial, American, telling the truth, reliable, above-average intelligence, Common Crawl-like text not corrupted by WET processing, shortform, Markdown formatting, only 1 common typo or misspelling total, …” and it will then tailor responses accordingly and maximize its utility by predicting the next token accurately (because the ‘coin flip’ there is simply betting on the logits with the highest likelihood etc.). Like the coin flip 2-tuple, most of these do not correspond to any real-world ‘state’: if you or I put in a prompt, there is no atom or set of atoms which corresponds to many of these variables. But they have consequences: if we ask about Tiananmen Square in English, for example, we’ll get a different answer than if we had asked in Mandarin, because the sufficient statistics there are inferred to be very different and yield a different array of latents which cause different outputs.
And that’s what “state” is for language: it is the model’s best attempt to infer a useful set of latent variables which collectively are sufficient statistics for whatever language-game or task or environment or agent-history or whatever the context/prompt encodes, which then supports optimal decision-making.
My earlier comment on meta-learning and Bayesian RL/inference for background: https://www.lesswrong.com/posts/TiBsZ9beNqDHEvXt4/how-we-picture-bayesian-agents?commentId=yhmoEbztTunQMRzJx