Someone suggested this comment was inscrutable, so here's a summary:
I don't think how argmax-like softmax is being is a crux between us: we think our picture makes the most sense when softmax acts like argmax or top-k, so we hope you're right that softmax is argmax-ish. Instead, I think the property that enables your efficient solution is that the set of features 'this token is token i' is mutually exclusive, i.e. only one of these features can activate on an input at once. That means that in your example you don't have to worry about how to recover feature values when multiple features are present at once. For more general tasks implemented by an attention head, we do need to worry about what happens when multiple features are present at the same time, and then we need the f-vectors to form a nearly orthogonal basis, and your construction becomes, I think, a special case of ours.
Thanks for the comment!
In more detail:
In our discussion of softmax (buried in part 1 of section 4), we argue that our story makes the most sense precisely when the temperature is very low, in which case we only attend to the key(s) that satisfy the most skip feature-bigrams. Also, when features are very sparse, the number of skip feature-bigrams present in one query-key pair is almost always 0 or 1, and we aren't trying to track super precisely whether it's, say, 34 or 35.
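To make the low-temperature regime concrete, here is a minimal sketch (the scores and temperature values are invented for illustration, not taken from the post): as the temperature drops, softmax concentrates all attention on the key(s) satisfying the most skip feature-bigrams, with exact ties sharing attention equally.

```python
import numpy as np

def softmax(scores, temperature):
    z = np.array(scores, dtype=float) / temperature
    z -= z.max()          # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()

# scores = number of skip feature-bigrams satisfied by each query-key pair
scores = [0, 1, 0, 1, 0]

# moderate temperature: attention is spread across all positions
print(np.round(softmax(scores, 1.0), 3))

# very low temperature: (near-)argmax; the two tied top keys each get ~1/2
print(np.round(softmax(scores, 0.01), 3))
```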
I agree that if softmax is just acting as an argmax, then one implication is that we don't need the error terms to be o(1); instead, they can just be somewhat less than 1. However, at least in our general framework, this doesn't help us beyond changing the log factor hidden in the tilde in ~Θ(d_resid · d_head). There will still be some log factor, because we require the average error to be o(1) to prevent the worst-case error from exceeding 1. Also, we may want to be able to accept 'ties', in which a small number (>1) of token positions are attended to together. To achieve this (assuming, for simplicity, that at most one SFB is present for each QK pair), we'd want the variation among the values that should be 1 to be much smaller than the gap between the smallest value that should be 1 and the largest value that should be 0.
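A tiny sketch of the error-tolerance point (the score values and noise bound are invented): with a gap of 1 between "should be 1" and "should be 0" scores, per-entry errors only need magnitude below half the gap, not o(1), for argmax to still pick the right position.

```python
import numpy as np

rng = np.random.default_rng(0)
true_scores = np.array([0.0, 0.0, 1.0, 0.0])   # one SFB present, at position 2

# errors need not be o(1): any per-entry error of magnitude < 0.5 (half the
# gap between the 1-scores and the 0-scores) leaves the argmax unchanged
noise = rng.uniform(-0.4, 0.4, size=4)
noisy = true_scores + noise

assert np.argmax(noisy) == 2
```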
A few comments about your toy example:
To tell a general story, I'd like to replace the word 'token' with 'feature' in your construction. In particular, I might want to express what the attention head does using the same features as the MLP. The choice of using tokens in your example is special, because the set of features {this is token 1, this is token 2, …} is mutually exclusive; but once I allow for the possibility that multiple features can be present at once (for example, if I want to talk in terms of features involved in MLP computation), your construction breaks. To avoid this problem, I want the maximum dot product between f-vectors to be at most 1/(the maximum number of features that can be present at once). Once several features can be present at once, this starts to look like an ϵ-orthogonal basis again. I guess you could imagine a case where the residual stream is divided into subspaces, and inside each subspace is a set of mutually exclusive features (à la tegum products of TMS). In your picture, there would need to be a 2-dimensional subspace allocated to the 'which token' features anyway. This tegum geometry would have to be specifically learned: these orthogonal subspaces do not arise generically, and we don't see a good reason to think that they are likely to be learned by default for reasons not to do with the attention head that uses them, even in the case that there are these sets of mutually exclusive features.
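To see why small pairwise dot products matter once features stop being mutually exclusive, here is a sketch (the dimensions, feature counts, and unit feature values are invented): with a nearly orthogonal set of random f-vectors, a dot-product readout recovers each feature's value up to interference from the other co-active features, which stays small as long as the cross dot products are small.

```python
import numpy as np

rng = np.random.default_rng(0)
d, m, k = 1024, 50, 5     # residual dims, total features, co-active features

F = rng.standard_normal((m, d))
F /= np.linalg.norm(F, axis=1, keepdims=True)   # m random unit f-vectors,
                                                # pairwise dots ~ 1/sqrt(d)

x = F[:k].sum(axis=0)     # residual stream with features 0..k-1 active at value 1
readout = F @ x           # dot-product readout of every feature's value

# each cross term is ~1/sqrt(d), so with (k-1) of them the active features
# read out near 1 and the inactive ones near 0
assert readout[:k].min() > 0.5
assert readout[k:].max() < 0.5
```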
It takes us more than 2 dimensions, but in our framework it is possible to do a similar construction to yours in O(log m) dimensions, assuming m random token vectors (i.e. without the need for any specific learned structure in the embeddings for this task): simply replace the rescaled projection matrix R = √(d_resid/d_head) · P_{d_head} with R′ = √(d_resid/n) · P_n, where n is O(log m) and P_n is a projection matrix onto an n-dimensional subspace. Now, with high probability, each vector has a larger dot product with its own projection than with any other vector's projection (we need n to be this large to ensure that the projected vectors all have similar lengths). Then use the same construction as in our post, and turn the softmax temperature down to zero.
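A quick numerical check of this claim (the dimensions, the constant in n = O(log m), and the use of a random Gaussian map in place of a projection are our own illustrative choices, not the post's): with n proportional to log m, every one of m random token vectors scores highest against its own projection, so zero-temperature softmax attends to the right token.

```python
import numpy as np

rng = np.random.default_rng(0)
d_resid, m = 512, 1000
n = int(40 * np.log(m))   # O(log m) dims; the constant 40 is an arbitrary choice

V = rng.standard_normal((m, d_resid))
V /= np.linalg.norm(V, axis=1, keepdims=True)       # m random unit token vectors

# random map to n dims standing in for the projection P_n (scaled so lengths
# are approximately preserved, playing the role of the sqrt(d_resid/n) factor)
P = rng.standard_normal((n, d_resid)) / np.sqrt(n)
U = V @ P.T                                         # projected token vectors

S = U @ U.T   # pairwise scores between projected tokens
# with high probability each row's largest entry is on the diagonal, so
# argmax (zero-temperature softmax) attends to the matching token
assert (S.argmax(axis=1) == np.arange(m)).all()
```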