Thanks for this post! I’m not sure how much I expect this to matter in practice, but I think that the underlying point of “sometimes the data distribution matters a lot, and ignoring it is suspect” seems sound and well made.
I personally think it’s clear that 1L attn-only models are not literally just doing skip trigrams. A quick brainstorm of other things I presume they’re doing:
Skip trigrams with positional decay—it’s easy enough to add a negative term to the attention scores that gets bigger the further away the source token is. For skip trigrams like “keep … in → mind” that clearly want to be trigrams, it seems like this has to be going on. You can mediate it with a high BOS attention, so it doesn’t attend locally when no skip trigram fires.
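As a rough sketch of the mechanism (all numbers made up, not fitted to any real model): the matching source token gets a score with a distance penalty, and a fixed BOS score acts as the fallback sink when the match is too far away.

```python
import numpy as np

def softmax(x):
    x = np.asarray(x, dtype=float)
    e = np.exp(x - x.max())
    return e / e.sum()

def attn_weights(src_pos, dst_pos, base_score=6.0, decay=0.5, bos_score=3.0):
    """Toy attention over [BOS, matching source token].
    The source score decays with distance; BOS is a constant sink."""
    scores = [bos_score, base_score - decay * (dst_pos - src_pos)]
    return softmax(scores)

# "keep ... in -> mind": a match 2 tokens back easily beats BOS...
print(attn_weights(src_pos=8, dst_pos=10))   # ~[0.12, 0.88]
# ...but the same match 40 tokens back loses to BOS, so the head does ~nothing.
print(attn_weights(src_pos=-30, dst_pos=10)) # ~[1.00, 0.00]
```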
Hierarchical skip trigrams (somewhat like your example) - if skip trigram 1 triggers, it stops skip trigram 2 from triggering.
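To gesture at how this could fall out of a single head (again with made-up scores): softmax competition alone does it, since a much higher score for trigram 1’s source token crowds out trigram 2’s source whenever both are present.

```python
import numpy as np

def softmax(x):
    x = np.asarray(x, dtype=float)
    e = np.exp(x - x.max())
    return e / e.sum()

# Slots: [BOS, source token of trigram 1, source token of trigram 2].
print(softmax([0.0, 9.0, 5.0]))  # trigram 2's source gets ~2% of the attention
print(softmax([0.0, 5.0]))       # with trigram 1 absent, it gets ~99%
```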
Dealing with the level of saturation—for a trigram A … B → C, it’s unclear what happens if there are multiple copies of A, and the model can choose how to mediate this.
Toy example—let’s say attention to A and to BOS are all that matters (everything else is −1000). If BOS is 0 and A is 10, then the model doesn’t care if there are multiple As, it saturates immediately. If BOS is 5 and A is 0, then it’s now basically linear in the number of As. (In practice I expect it’ll be somewhere in the middle)
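Putting that toy example into code (same made-up scores), showing how the BOS score controls whether the head’s total attention to A saturates or grows roughly linearly with the number of copies:

```python
import numpy as np

def total_attention_to_A(n_copies, a_score, bos_score):
    """Softmax over one BOS slot and n_copies slots scoring a_score;
    returns the total attention mass landing on the copies of A."""
    scores = np.array([bos_score] + [a_score] * n_copies, dtype=float)
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w[1:].sum()

for n in [1, 2, 4, 8]:
    print(f"{n} copies: BOS=0, A=10 -> {total_attention_to_A(n, 10.0, 0.0):.4f}"
          f" | BOS=5, A=0 -> {total_attention_to_A(n, 0.0, 5.0):.4f}")
```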
“context detection”—text tends to cluster into different contexts (books, Wikipedia, etc), which have very different unigram and bigram statistics, and a model could learn to do the correct unigram updates for any destination token, conditional on a bunch of relevant source tokens being there.
There are probably some tokens that are way more common in some contexts than others (” yield”, ” return”, “\n\t\t” in Python code, etc), and a model could learn skip trigrams from any dest token to these source tokens (probably implemented via the query bias), whose OV circuit just boosts the unigram direction.
In some sense this is an ensemble of skip trigrams, but I think it’s different because the natural way to implement it is to have the total attention paid to the special tokens saturate at uniform beyond a certain threshold of possible source tokens.
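A sketch of the Python-context version of this (scores invented): each context token gets the same moderate query-bias-driven score, so attention over them is uniform, and the total mass on them (which, if the OV circuit maps them all to the same “Python unigram update” direction, sets the size of the boost) saturates once a handful are present.

```python
import numpy as np

def python_unigram_boost(n_context_tokens, context_score=4.0, bos_score=2.0):
    """Total attention mass on context tokens (' yield', ' return', '\\n\\t\\t', ...)
    when each gets the same score; attention among them is uniform, and the
    total saturates as more of them appear. All numbers are made up."""
    scores = np.array([bos_score] + [context_score] * n_context_tokens, dtype=float)
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w[1:].sum()

for n in [0, 1, 2, 5, 20, 100]:
    print(f"{n:3d} Python-ish source tokens -> boost scale {python_unigram_boost(n):.3f}")
```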
Skip bigrams—when there’s a previous token head that acts semi-independently of the destination token, to implement A __ → C behaviour. I expect this arises e.g. for can|’t| vs don|’t|, where the previous token significantly disambiguates the current token.
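A toy version of the previous-token part (numbers invented): the score is roughly flat over positions plus a big bonus at offset -1, so the pattern barely depends on the destination token, and the OV circuit can then read off “can” vs “don” to disambiguate |’t|.

```python
import numpy as np

def prev_token_attention(dst_pos, positional_bonus=8.0):
    """Toy previous-token head: constant scores over positions 0..dst_pos,
    plus a large bonus at offset -1 (all numbers made up)."""
    scores = np.zeros(dst_pos + 1)           # causal mask: positions 0..dst_pos
    scores[dst_pos - 1] += positional_bonus  # the previous token wins
    w = np.exp(scores - scores.max())
    return w / w.sum()

print(prev_token_attention(dst_pos=7).round(3))  # ~all weight on position 6
```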