In a transformer, the compute cost for context length n grows at O(n^2)[4], so it’s a 16x increase in compute cost to go from 2000 tokens to 8000, and another 16x increase to go to 32000. To the best of my knowledge, there isn’t much additional cost to a longer context window—the number of parameters to encode more positions is very small for a model this big.
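A quick back-of-the-envelope check of those factors (a minimal Python sketch; the only assumption is that the quadratic attention term dominates the cost):

```python
# Under O(n^2) scaling, quadrupling the context length multiplies the
# attention compute by 4^2 = 16. (Assumes the quadratic term dominates.)
for old_ctx, new_ctx in [(2_000, 8_000), (8_000, 32_000)]:
    factor = (new_ctx / old_ctx) ** 2
    print(f"{old_ctx} -> {new_ctx} tokens: ~{factor:.0f}x attention compute")
# 2000 -> 8000 tokens: ~16x attention compute
# 8000 -> 32000 tokens: ~16x attention compute
```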
I do not understand this paragraph; it seems like the first sentence contradicts the second.
Edit: I think I understand. Are you saying there isn’t much additional cost on top of the cost mentioned in the previous sentence because the position encoding is tiny compared to everything else in the model?
It’s a bit like saying the extra maintenance cost of a big SUV, compared to an ordinary car, is small. The extra fuel cost is still big, but the extra maintenance cost isn’t much.
It’s ordinary computational complexity reasoning: if one part of your program scales like n^2 and another like n, then for large enough n the former will overtake the latter and pretty much dominate the total cost. That said, as someone pointed out, the specifics matter too. If your total cost were something like n^2 + 1,000,000,000n, n would have to get past roughly a billion before the quadratic term even matched the linear one, let alone dominated it. So depending on the details of the architecture, and how it was scaled up in ways other than just increasing context, the scaling might not actually look very quadratic at all.
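A minimal Python sketch of that crossover, using the same deliberately huge (and entirely made-up) linear coefficient, just to show how big n has to get before the quadratic term matters:

```python
# Total cost n^2 + c*n: the quadratic term only reaches half the total
# once n exceeds c, so for smaller n the growth looks essentially linear.
c = 1_000_000_000  # the deliberately huge linear coefficient from the example

def total_cost(n: int) -> int:
    return n * n + c * n

for n in (10_000, 1_000_000, 1_000_000_000, 10_000_000_000):
    quad_share = n * n / total_cost(n)
    print(f"n = {n:,}: quadratic term is {quad_share:.1%} of the total")
# Prints roughly 0.0%, 0.1%, 50.0%, 90.9%: the quadratic term only starts
# to dominate once n gets past c.
```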
Yep, exactly as you explain in your edit!