“bottleneck” of the CoT tokens. Whatever needs to be passed along from one sequential calculation step to the next must go through this bottleneck.
Does it, though? The keys and values from previous forward passes are still accessible, even if the generated token is not.
So the CoT tokens are not absolute information bottlenecks. But yes, replacing the token by a dot reduces the number of serial steps the model can perform (from mn to m+n, if there are m forward passes and n layers).
The “sequential calculation steps” I’m referring to are the ones that CoT adds above and beyond what can be done in a single forward pass. It’s the extra sequential computation added by CoT, specifically, that is bottlenecked on the CoT tokens.
There is of course another notion of “sequential calculation steps” involved: the sequential layers of the model. However, I don’t think the bolded part of this is true:
replacing the token by a dot reduces the number of serial steps the model can perform (from mn to m+n, if there are m forward passes and n layers)
If a model with N layers has been trained to always produce exactly M “dot” tokens before answering, then the number of serial steps is just N, not M+N.
One way to see this is to note that we don’t actually need to run M separate forward passes. We can just pre-fill a context window containing the prompt tokens followed by M dot tokens, and run 1 forward pass on the whole thing.
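To make that concrete, here’s a minimal sketch (not from the paper; the model, prompt, and dot count are arbitrary stand-ins) of running the prompt plus M dot tokens through one forward pass with Hugging Face transformers:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # illustrative stand-in; any causal LM works the same way
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

prompt = "What is 17 * 24?"   # illustrative prompt
M = 32                        # hypothetical fixed number of "dot" filler tokens
dot_id = tok.encode(".", add_special_tokens=False)[0]

prompt_ids = tok.encode(prompt, return_tensors="pt")
dots = torch.full((1, M), dot_id, dtype=torch.long)
input_ids = torch.cat([prompt_ids, dots], dim=1)

with torch.no_grad():
    # One forward pass over prompt + M dots: the serial depth is still the
    # model's N layers; the dots only widen the pass, they add no serial steps.
    out = model(input_ids, use_cache=True)

# The first answer token can be read off the final position, and generation
# can continue from the cached keys/values as usual.
first_answer_id = out.logits[0, -1].argmax()
```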
Having the dots does add computation, but it’s only extra parallel computation – there’s still only one forward pass, just a “wider” one, with more computation happening in parallel inside each of the individually parallelizable steps (tensor multiplications, activation functions).
(If we relax the constraint that the number of dots is fixed, and allow the model to choose it based on the input, that still doesn’t add much: note that we could do 1 forward pass on the prompt tokens followed by a very large number of dots, then find the first position where we would have sampled a non-dot from the output distribution, truncate the KV cache to end at that point, and sample normally from there.)
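Here’s a hedged sketch of that trick, under the same illustrative assumptions as above (greedy argmax as a proxy for sampling, "gpt2" as a stand-in model):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # illustrative stand-in
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

dot_id = tok.encode(".", add_special_tokens=False)[0]
prompt_ids = tok.encode("What is 17 * 24?", return_tensors="pt")
K = 256  # "very large number of dots", chosen arbitrarily here

input_ids = torch.cat(
    [prompt_ids, torch.full((1, K), dot_id, dtype=torch.long)], dim=1
)
with torch.no_grad():
    out = model(input_ids, use_cache=True)  # one wide forward pass

# Greedy proxy for "the first position where we would have sampled a non-dot":
# the prediction at position i is for the token at position i + 1, so scan the
# predictions made at the last prompt position and at each dot position.
preds = out.logits[0, prompt_ids.shape[1] - 1 : -1].argmax(dim=-1)
nondot = (preds != dot_id).nonzero()
stop = int(nondot[0]) if len(nondot) > 0 else K  # number of dots to keep

# Keep prompt + `stop` dots and sample normally from there. (For brevity this
# re-feeds the kept prefix to generate(); the trick described above would
# instead truncate out.past_key_values and reuse it.)
kept_ids = input_ids[:, : prompt_ids.shape[1] + stop]
continuation = model.generate(kept_ids, max_new_tokens=20, do_sample=True)
```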
If you haven’t read the paper I linked in OP, I recommend it – it’s pretty illuminating about these distinctions. See e.g. the stuff about CoT making LMs more powerful than TC0 versus dots adding more power within TC0.