Thanks for thinking about this; I think this is an important topic!
Inside the AI’s chain-of-thought, each forward pass can generate many English tokens instead of one, allowing more information to pass through the bottleneck.
I wonder how one would do this. Do you mean allowing the model to output a distribution over tokens for each output position (and then also read that distribution back in)? I could imagine this being somewhere between normal CoT and latent (neuralese) CoT!
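To make the question concrete, here is a rough sketch of what I have in mind by "read in that distribution". It is only my assumption of one way it could work, not anything from the post: instead of embedding a single sampled token, the next forward pass receives the probability-weighted average of all token embeddings. The vocabulary, the embedding table, and the function names are hypothetical toy stand-ins.

```python
import math


def softmax(logits):
    """Convert raw logits into a probability distribution."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def soft_token_embedding(probs, embedding_table):
    """Read a whole distribution back in: the probability-weighted
    average of every token's embedding, instead of the embedding
    of one sampled token (assumed scheme, for illustration only)."""
    dim = len(embedding_table[0])
    out = [0.0] * dim
    for p, vec in zip(probs, embedding_table):
        for i in range(dim):
            out[i] += p * vec[i]
    return out


# Hypothetical 3-token vocabulary with 2-dimensional embeddings.
EMBEDDINGS = [
    [1.0, 0.0],  # token 0
    [0.0, 1.0],  # token 1
    [1.0, 1.0],  # token 2
]

probs = softmax([2.0, 1.0, 0.1])
next_input = soft_token_embedding(probs, EMBEDDINGS)
```

If the distribution is one-hot, this reduces to ordinary token embedding (normal CoT); the less peaked it is, the more information beyond a single token gets passed along, which is why I'd place it between normal CoT and neuralese.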
After the chain-of-thought ends, and the AI is giving its final answer, it generates only one English token at a time, to make each token higher quality. The architecture might still generate many tokens in one forward pass, but a simple filter repeatedly deletes everything except its first token from the context window.
If my interpretation of your idea above is correct, then I imagine this part would look just like top-k/top-p generation as it is done currently, which seems sensible.
I’m only ~30% certain that I correctly understood your idea, though, so I’d love it if you could clarify what this generate-many-tokens idea looks like!
Haha, my idea was just “maybe to solve this information bottleneck problem, we should solve the generate-many-tokens-in-one-pass problem.”
I haven’t really thought of any solution to the generate-many-tokens-in-one-pass problem yet :/
I’ll edit the post to mention this.
An attempt
One stupid attempt to solve the “generate-many-tokens-in-one-pass” problem is to start off with the main LLM outputting 1 token at a time and a small, cheap LLM outputting the next 5 tokens. You then let the small LLM eavesdrop on the residual stream of the main LLM, and use reinforcement learning on both the main LLM and the small LLM.
The hope is that the main LLM will eventually learn to use part of its residual stream to communicate to the small LLM, and tell the small LLM what the next 5 tokens should be, so the computations in the main LLM can directly influence 6 tokens of output.
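A rough sketch of only the inference side of this scheme might look like the following. The reinforcement learning part is not shown, and the two model interfaces (`MainModel`, `SmallModel`) and the toy models are hypothetical stand-ins I made up for illustration: the main model does one expensive pass and returns one token plus a residual-stream vector, and the small model reads that same vector to produce each of the extra tokens.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical interfaces, for illustration only:
# the main model returns (its one token, its residual-stream vector);
# the small model sees that vector plus the tokens so far and emits one token.
MainModel = Callable[[Sequence[int]], Tuple[int, List[float]]]
SmallModel = Callable[[List[float], Sequence[int]], int]


def generate_step(main: MainModel, small: SmallModel,
                  context: List[int], n_extra: int = 5) -> List[int]:
    """One forward pass of the main model yields 1 + n_extra tokens."""
    first_token, residual = main(context)  # the single expensive pass
    out = [first_token]
    for _ in range(n_extra):
        # The small model "eavesdrops" on the same residual vector for every
        # extra token, so the main model's computation can steer all of them.
        out.append(small(residual, context + out))
    return out


def toy_main(ctx: Sequence[int]) -> Tuple[int, List[float]]:
    # Toy stand-in: pretend the residual stream encodes a 6-token plan.
    plan = [len(ctx) + i for i in range(6)]
    return plan[0], [float(t) for t in plan]


def toy_small(residual: List[float], tokens: Sequence[int]) -> int:
    # Toy stand-in: residual[0] marks where this step started, so the
    # offset says which planned token to read out next.
    offset = len(tokens) - int(residual[0])
    return int(residual[offset])


tokens = generate_step(toy_main, toy_small, [7, 8, 9])
```

In the toy, the small model contributes no "thinking" of its own; all six tokens come from the plan the main model wrote into its residual vector, which is the behaviour the RL training would hopefully push toward.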
Slow thinking
I guess the “simple filter repeatedly deletes everything except its first token from the context window” was a bit unclear. I’ll rewrite that.
What I wanted to say was that when the AI is talking to you (rather than talking to itself in its chain-of-thought), we want the AI to slow down and do more computation for each token it outputs. In this case, we don’t want it outputting many tokens for each forward pass. We want to keep only the first “high quality” token and delete the rest.
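As a sketch of that filter, assuming a hypothetical model that proposes several tokens per forward pass (the `MultiTokenModel` interface and the toy model are made up for illustration): in final-answer mode, each forward pass contributes exactly one token to the context, and the other proposed tokens are thrown away before the next pass.

```python
from typing import Callable, List

# Hypothetical: a model that proposes several tokens per forward pass.
MultiTokenModel = Callable[[List[int]], List[int]]


def answer_slowly(model: MultiTokenModel, context: List[int],
                  max_tokens: int, eos: int) -> List[int]:
    """Final-answer mode: run one full forward pass per output token,
    keeping only the first proposed token and discarding the rest."""
    answer: List[int] = []
    for _ in range(max_tokens):
        proposed = model(context + answer)  # may contain many tokens
        first = proposed[0]                 # the "high quality" token
        answer.append(first)                # everything else is dropped
        if first == eos:
            break
    return answer


def toy_model(ctx: List[int]) -> List[int]:
    # Toy stand-in: a "good" first token followed by lower-quality extras (-1).
    return [len(ctx), -1, -1]


result = answer_slowly(toy_model, [0, 0], max_tokens=4, eos=99)
```

So the architecture is unchanged; only the loop around it differs between chain-of-thought mode (keep every proposed token) and answer mode (keep one per pass).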
I don’t think this is related to top-k/top-p generation, because those refer to how an LLM samples one token from its distribution. The k refers to the number of candidate tokens considered, not the number of tokens generated at once.
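To illustrate the difference, here is a minimal sketch of standard top-k and top-p (nucleus) sampling written from scratch (the function name `sample_one_token` is my own, not from any library). Both parameters only shrink the set of eligible candidates; the function still returns exactly one token.

```python
import math
import random
from typing import List, Optional


def sample_one_token(logits: List[float], k: Optional[int] = None,
                     p: Optional[float] = None,
                     rng: Optional[random.Random] = None) -> int:
    """Return ONE token id. k (top-k) and p (top-p / nucleus) only
    restrict which candidates are eligible; they never change how
    many tokens are produced."""
    rng = rng or random.Random()
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Candidate token ids sorted from most to least likely.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if k is not None:
        order = order[:k]  # keep only the k most likely tokens
    if p is not None:
        kept, cumulative = [], 0.0
        for i in order:
            kept.append(i)
            cumulative += probs[i]
            if cumulative >= p:  # smallest prefix with mass >= p
                break
        order = kept
    weights = [probs[i] for i in order]  # choices() renormalizes these
    return rng.choices(order, weights=weights, k=1)[0]


logits = [3.0, 1.0, 0.5, -2.0]
token = sample_one_token(logits, k=2, rng=random.Random(0))
```

With k=2 the result is always token 0 or token 1, but it is still a single token per call.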
Thank you so much for reading and for the reply :)