In transformers, the compute cost of one part of the attention mechanism (which is itself only one part of the transformer architecture) grows as O(n^2) in the context length n, so for the transformer as a whole this is only true in the limit.
This is true, and a useful corrective. I’ll edit the post to make this clear.
In fact, I think that as models are scaled, the attention mechanism becomes an ever smaller part of the overall compute cost (this is empirical: I saw a table to that effect, though you could certainly scale differently). So with model scaling you get more and more leeway to increase the context length without raising compute cost (for both training and inference) too much.
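A rough way to see why the quadratic part shrinks as models get wider is to count matrix-multiply FLOPs per layer (2 FLOPs per multiply-add). This is a minimal sketch under simplifying assumptions, not a reproduction of the table discussed in this thread: it assumes a standard dense layer with a feedforward width of 4·d_model, ignores softmax, layer norms, embeddings and causal-masking savings, and the function name is hypothetical.

```python
def attention_score_share(n_ctx: int, d_model: int) -> float:
    """Fraction of one layer's forward matmul FLOPs spent on the two
    n x n products (Q @ K^T and softmax(...) @ V).

    Simplified count for a sequence of n_ctx tokens (assumes d_ff = 4 * d_model):
      linear in n:    QKV projection 6nd^2 + output projection 2nd^2
                      + feedforward 16nd^2 = 24nd^2
      quadratic in n: Q @ K^T 2n^2d + scores @ V 2n^2d = 4n^2d
    """
    linear = 24 * n_ctx * d_model**2
    quadratic = 4 * n_ctx**2 * d_model
    return quadratic / (linear + quadratic)


# Same 2k context, two illustrative widths
# (roughly the smallest and largest OPT models).
for d in (768, 12288):
    print(d, f"{attention_score_share(2048, d):.1%}")
```

Under this count the share simplifies to n / (6·d_model + n), so at a fixed context length it falls roughly as n / (6·d_model) once 6·d_model is much larger than n. Here that gives about 30.8% at width 768 and about 2.7% at width 12288.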
I’d love to learn more about this, do you remember where you saw that table?
It was posted by Stella Biderman, probably in the Eleuther Discord, many months ago. I took a screenshot, which I was somewhat surprisingly able to find again just now. At least for OPT at GPT-3 size, the feedforward layers apparently use 80% of the compute. The naming of the other columns confuses me: logically, one should be the key, query, and value computation and the other the actual quadratic dot product, but I can't tell which is which.
Apparently all OPT models were trained with a 2k-token context length. So based on this, assuming basic O(n^2) scaling, an 8k-token version of the 175B model would have the attention stage grow to about 35% of the FLOPs, and a 32k-token version would grow to almost 90% of the FLOPs. 8k tokens is somewhat excusable, but 32k tokens is still overwhelmingly significant even with a 175B-parameter model, costing around 840% more compute than a 2k-token model. That percentage will probably only drop to a reasonable level at around the 10T-parameter scale, at least under O(n^2) scaling. And that's all assuming the other parts of the model don't scale at all with the larger context length… A new approach is definitely going to be needed soon. Maybe H3?
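The extrapolation above can be reproduced with a few lines, as a sketch under the same assumptions the comment states: only the quadratic attention term is scaled by (new_ctx / base_ctx)^2, and every other cost is held fixed. The 2k-context attention share of 3.3% is an assumption, not a value read off the table. It is back-solved so that the 8k case lands near 35%. The helper name is hypothetical.

```python
def scale_context(attn_share_base: float, base_ctx: int, new_ctx: int):
    """Scale only the quadratic attention cost by (new_ctx / base_ctx)^2,
    holding all other costs fixed (the comment's simplifying assumption).

    Returns (attention share of total FLOPs at new_ctx,
             total cost relative to the base_ctx model).
    """
    factor = (new_ctx / base_ctx) ** 2
    attn = attn_share_base * factor   # quadratic part, grown by factor
    rest = 1.0 - attn_share_base      # everything else, unchanged
    total = attn + rest
    return attn / total, total


# ASSUMPTION: 3.3% attention share at 2k context, back-solved from the
# ~35% figure for 8k; it is not taken from the original table.
for ctx in (8192, 32768):
    share, total = scale_context(0.033, 2048, ctx)
    print(f"{ctx} tokens: attention {share:.1%} of FLOPs, "
          f"{total - 1:.1%} more compute than 2k")
```

With that assumed starting share, the 32k case comes out at roughly 90% of FLOPs and roughly 840% more compute, matching the figures in the comment. The result is very sensitive to the assumed base share, so it is worth replacing 0.033 with the actual value from the table if it is available.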