Interesting post, thanks for writing it!
I think that the QK section somewhat under-emphasises the importance of the softmax. My intuition is that models rarely care about as precise a task as counting the number of pairs of matching query-key features at each pair of token positions; instead, softmax is more of an "argmax-like" function that picks out a handful of important token positions (though I have not empirically tested this, and would love to be proven wrong!). This enables much cheaper and more efficient solutions, since you only need the correct answer to be (approximately) the argmax.
For example, ignoring floating-point precision, you can implement a duplicate token head with $d_{\text{head}}=2$ and arbitrarily high $d_{\text{vocab}}$. If there are $n$ vocab elements, map the $m$th token's query and key to the point $m/n$ of the way round the unit circle. The query-key dot product is then maximised exactly when the two tokens are equal.
If you further want the head to look at a resting position unless the duplicate token is there, you can increase to $d_{\text{head}}=3$ and add a dedicated BOS dimension with a score of $1-\epsilon$, so that only a perfect match scores higher than BOS. Then make the softmax temperature very low so that it acts as an argmax.
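To make this concrete, here is a minimal numpy sketch of that construction (the vocabulary size, the example context, and the specific choices of $\epsilon$ and temperature below are illustrative, not prescribed by the argument; in particular $\epsilon$ has to be smaller than the score gap to the nearest non-matching token on the circle):

```python
import numpy as np

n_vocab = 1000

def qk_vector(token_id):
    # token m sits m/n of the way round the unit circle (first two dimensions);
    # the third dimension is reserved for BOS
    angle = 2 * np.pi * token_id / n_vocab
    return np.array([np.cos(angle), np.sin(angle), 0.0])

bos_key = np.array([0.0, 0.0, 1.0])
# epsilon must be below the gap to the nearest non-matching token, 1 - cos(2*pi/n_vocab),
# so that only a perfect match scores higher than BOS
eps = 0.5 * (1 - np.cos(2 * np.pi / n_vocab))
temperature = eps / 100          # low enough that the softmax acts as an argmax

def attention(query_token, context_tokens):
    q = qk_vector(query_token)
    q[2] = 1 - eps               # query reads the BOS dimension with weight 1 - eps
    keys = [bos_key] + [qk_vector(t) for t in context_tokens]
    scores = np.array([k @ q for k in keys]) / temperature
    p = np.exp(scores - scores.max())        # numerically stable softmax
    return p / p.sum()

# attention over keys [BOS, 17, 423, 999, 423, 58]:
print(np.round(attention(423, [17, 423, 999, 423, 58]), 3))  # splits over the two 423 keys
print(np.round(attention(700, [17, 423, 999, 423, 58]), 3))  # no duplicate -> all on BOS
```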
Thanks for the comment!

Someone suggested this comment was inscrutable, so here's a summary:

I don't think the degree to which softmax acts like an argmax is a crux between us: we think our picture makes the most sense when softmax acts like argmax or top-k, so we hope you're right that softmax is argmax-ish. Instead, I think the property that enables your efficient solution is that the set of features 'this token is token $i$' is mutually exclusive, i.e. only one of these features can be active on an input at once. That means that in your example you don't have to worry about how to recover feature values when multiple features are present at once. For more general tasks implemented by an attention head, we do need to worry about what happens when multiple features are present at the same time, and then we need the f-vectors to form a nearly orthogonal basis, at which point I think your construction becomes a special case of ours.

In more detail:
In our discussion of softmax (buried in part 1 of section 4), we argue that our story makes the most sense precisely when the temperature is very low, in which case we only attend to the key(s) that satisfy the most skip feature-bigrams. Also, when features are very sparse, the number of skip feature-bigrams present in one query-key pair is almost always 0 or 1, and we aren't trying to track precisely whether it's, say, 34 or 35.
I agree that if softmax is essentially an argmax, then one implication is that we don't need error terms to be $o(1)$; instead, they can just be somewhat less than 1. However, at least in our general framework, this doesn't help us beyond changing the log factor hidden in the tilde in $\tilde{\Theta}(d_{\text{head}} d_{\text{resid}})$. There will still be some log factor, because we require the average error to be $o(1)$ to prevent the worst-case error from being greater than 1. Also, we may want to be able to accept 'ties', in which a small number ($>1$) of token positions are attended to together. To achieve this (assuming for simplicity that at most one SFB is present for each QK pair), we'd want the variation among the values which should be 1 to be much smaller than the gap between the smallest value which should be 1 and the largest value which should be 0.
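As a toy numerical illustration of that last condition (the scores below are made up, not derived from anything above): the softmax temperature has to sit above the spread of the 'should be 1' scores but below the gap down to the 'should be 0' scores for the tied positions to be attended to together.

```python
import numpy as np

def softmax(scores, temperature):
    z = np.asarray(scores, dtype=float) / temperature
    z -= z.max()                    # numerical stability
    p = np.exp(z)
    return p / p.sum()

should_be_one = [1.02, 0.98, 1.01]     # noisy scores for positions where an SFB is present
should_be_zero = [0.15, -0.08, 0.05]   # noisy scores for positions with no SFB
scores = should_be_one + should_be_zero

# Temperature above the ~0.04 spread of the 'one' scores but below the ~0.8 gap down to
# the 'zero' scores: the three tied positions share nearly all the attention.
print(np.round(softmax(scores, 0.2), 3))

# Temperature below the spread of the 'one' scores: the softmax singles out one 'winner'.
print(np.round(softmax(scores, 0.005), 3))
```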
A few comments about your toy example:
To tell a general story, I'd like to replace the word 'token' with 'feature' in your construction. In particular, I might want to express what the attention head does using the same features as the MLP. The choice of tokens in your example is special, because the set of features {this is token 1, this is token 2, …} is mutually exclusive; once I allow for the possibility that multiple features can be present at once (for example, if I want to talk in terms of the features involved in MLP computation), your construction breaks. To avoid this problem, I want the maximum dot product between f-vectors to be at most 1/(the maximum number of features that can be present at once). If several features can be present at once, this starts to look like an $\epsilon$-orthogonal basis again. I guess you could imagine a case where the residual stream is divided into subspaces, with a set of mutually exclusive features inside each subspace (à la the tegum products of TMS); in your picture, there would need to be a 2-dimensional subspace allocated to the 'which token' features anyway. But this tegum geometry would have to be specifically learned: such orthogonal subspaces do not arise generically, and we don't see a good reason to think they are likely to be learned by default for reasons unrelated to the attention head that uses them, even in the case where these sets of mutually exclusive features exist.
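Here is a small sketch (with made-up feature indices) of the failure mode I have in mind: once two of the unit-circle features are active at the same position, a key that shares no feature with the query can outscore a key that shares one.

```python
import numpy as np

n = 360                          # number of features, spread evenly round the unit circle

def f(m):
    angle = 2 * np.pi * m / n
    return np.array([np.cos(angle), np.sin(angle)])

query = f(0) + f(40)             # two features active at once at the query position
key_sharing_one_feature = f(0)
key_sharing_no_feature = f(20)   # sits between the two active query features

print(query @ key_sharing_one_feature)   # ~1.77
print(query @ key_sharing_no_feature)    # ~1.88 -> the non-matching key wins the argmax
```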
It takes us more than 2 dimensions, but in our framework it is possible to do a similar construction to yours in $O(\log m)$ dimensions, assuming $m$ random token vectors (i.e. without the need for any specific learned structure in the embeddings for this task): simply replace the rescaled projection matrix $R=\sqrt{d_{\text{resid}}/d_{\text{head}}}\,P_{d_{\text{head}}}$ with $R'=\sqrt{d_{\text{resid}}/n}\,P_n$, where $n$ is $O(\log m)$ and $P_n$ is a projection matrix onto an $n$-dimensional subspace. Now, with high probability, each vector has a larger dot product with its own projection than with any other vector's projection (we need $n$ to be this large to ensure that all the projected vectors have a similar length). Then use the same construction as in our post, and turn the softmax temperature down to zero.
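A rough numerical check of this (the dimensions and the constant in front of $\log m$ below are illustrative choices, not anything we derive):

```python
import numpy as np

rng = np.random.default_rng(0)
d_resid, m = 512, 2000
n = int(10 * np.log(m))                      # ~76 dimensions for m = 2000 tokens

# m random unit-norm token vectors in the residual stream
X = rng.standard_normal((m, d_resid))
X /= np.linalg.norm(X, axis=1, keepdims=True)

# Orthonormal basis V for a random n-dimensional subspace; R'x = sqrt(d_resid/n) * P_n x,
# written here in the subspace's own coordinates
V, _ = np.linalg.qr(rng.standard_normal((d_resid, n)))
Z = np.sqrt(d_resid / n) * (X @ V)

# QK scores between every projected query and key; with n ~ log(m) the diagonal
# (matching token) is the argmax of (almost) every row, so a near-zero-temperature
# softmax implements the duplicate token head without any learned structure in X.
scores = Z @ Z.T
print((scores.argmax(axis=1) == np.arange(m)).mean())   # ~1.0
```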