I finally got around to reading the Mamba paper. H/t Ryan Greenblatt and Vivek Hebbar for helpful comments that got me unstuck.
TL;DR: authors propose a new deep learning architecture for sequence modeling with scaling laws that match transformers while being much more efficient to sample from.
A brief historical digression
As of ~2017, the three primary ways people had for doing sequence modeling were RNNs, Conv Nets, and Transformers, each with a unique “trick” for handling sequence data: recurrence, 1d convolutions, and self-attention.
RNNs are easy to sample from: to compute the logit for x_{t+1}, you only need the most recent hidden state h_t and the last token x_t, which means it’s both fast and memory efficient. RNNs generate a sequence of length L with O(1) memory and O(L) time. However, they’re super hard to train, because you need to sequentially generate all the hidden states and then (reverse) sequentially calculate the gradients. The way you actually did this is called backpropagation through time (you basically unroll the RNN over time), which requires constructing a graph of depth equal to the sequence length. Not only was this slow, but the graph being so deep caused vanishing/exploding gradients without careful normalization. The strategy that people used was to train on short sequences and finetune on longer ones. That being said, in practice, this meant you couldn’t train on long sequences (>a few hundred tokens) at all. The best LSTMs for modeling raw audio could only handle being trained on ~5s of speech, if you chunk up the data into 25ms segments.
Conv Nets had a fixed receptive field size and pattern, so weren’t that suited for long sequence modeling. Also, generating each token takes O(L) time, assuming the receptive field is about the same size as the sequence. But they had significantly more stability (the depth was small, and could be as low as O(log(L))), which meant you could train them a lot more easily. (Also, you could use FFT to efficiently compute the conv, meaning it trains one sequence in O(L log(L)) time.) That being said, you still couldn’t make them that big. The most impressive example was DeepMind’s WaveNet, a conv net used to model human speech that could handle sequences of up to 4800 samples … which was 0.3s of actual speech at 16k samples/second (note that most audio is sampled at 44k samples/second…), and even to get to that amount, they had to really gimp the model’s ability to focus on particular inputs.
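The O(log(L)) depth and the 0.3s figure can both be checked with a few lines of arithmetic. This is a minimal sketch: `receptive_field` is a made-up helper, and the doubling dilation schedule is an illustrative assumption rather than a claim about WaveNet's exact configuration.

```python
def receptive_field(dilations, kernel_size=2):
    """Receptive field (in samples) of stacked dilated causal 1d convolutions."""
    return 1 + sum((kernel_size - 1) * d for d in dilations)


# Doubling the dilation each layer covers 2**n samples with only n layers,
# which is where the O(log(L)) depth comes from.
doubling = [2 ** i for i in range(10)]  # 1, 2, 4, ..., 512
field = receptive_field(doubling)       # 10 layers cover 1024 samples

# Duration of a 4800-sample window at 16,000 samples/second.
seconds = 4800 / 16_000
```

With undilated layers (all dilations equal to 1) the same ten layers would only cover 11 samples, which is why dilation matters so much for long sequences.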
Transformers are easy to train, can handle variable length sequences, and also allow the model to “decide” which tokens it should pay attention to. In addition to both being parallelizable and having relatively shallow computation graphs (like conv nets), you could do the RNN trick of pretraining on short sequences and then finetuning on longer sequences to save even more compute. Transformers could be trained with comparable sequence length to conv nets but get much better performance; for example, OpenAI’s MuseNet was trained on length-4096 sequences of MIDI files. But as we all know, transformers have the unfortunate downside of being expensive to sample from: it takes O(L) time and O(L) memory to generate a single token (!).
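As a rough illustration of why sampling cost grows with context, here is a toy single-head attention decode loop. It is only a sketch under simplifying assumptions: there are no learned projections, each token's vector stands in for its query, key, and value, and `attend` and the toy vectors are made up for illustration.

```python
import math


def attend(query, keys, values):
    """Single-head softmax attention for one query over all cached positions."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    top = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / total for i in range(dim)]


keys, values, cache_sizes = [], [], []
for t in range(1, 6):
    vec = [float(t), 1.0]            # stand-in for the token's query/key/value
    keys.append(vec)
    values.append(vec)
    out = attend(vec, keys, values)  # touches every cached entry: O(t) work
    cache_sizes.append(len(keys))    # the cache grows by one entry per token
```

An RNN would replace the growing `keys`/`values` lists with a single fixed-size state, which is the O(1) versus O(L) contrast drawn above.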
The better performance of transformers over conv nets and their ability to handle variable length data let them win out.
That being said, people have been trying to get around the O(L) time and memory requirements for transformers since basically their inception. For a while, people were super into sparse or linear attention of various kinds, which could reduce the per-token compute/memory requirements to O(log(L)) or O(1).
The what and why of Mamba
If the input → hidden and hidden → hidden map for RNNs were linear (h_{t+1} = A h_t + B x_t), then it’d be possible to train an entire sequence in parallel: you can just … compose the transformation with itself a bunch (computing A^k for k in 2…L-1), and effectively unroll the graph into a convolution with the kernel B, A B, A^2 B, A^3 B, … A^{L-1} B. Not only can you FFT during training to get the O(L log (L)) time of a conv net forward/backward pass (as opposed to O(L^2) for the transformer), you still keep the O(1) per-token sampling time/memory of the RNN!
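To make the recurrence/convolution equivalence concrete, here is a minimal sketch for a scalar hidden state, so A and B are just numbers. The function names and toy values are made up for illustration. It uses a naive O(L^2) convolution rather than an FFT, starts from h_0 = 0, and indexes the kernel from k = 0 (that is, B, A B, A^2 B, …).

```python
def run_recurrence(a, b, xs):
    """Sequentially compute h_{t+1} = a * h_t + b * x_t, starting from h_0 = 0."""
    h, hs = 0.0, []
    for x in xs:
        h = a * h + b * x
        hs.append(h)
    return hs


def run_convolution(a, b, xs):
    """Compute the same states as a causal convolution with kernel k_j = a**j * b."""
    kernel = [a ** j * b for j in range(len(xs))]
    return [
        sum(kernel[j] * xs[t - j] for j in range(t + 1))
        for t in range(len(xs))
    ]


xs = [1.0, -2.0, 0.5, 3.0, 0.0, 1.5]
sequential = run_recurrence(0.9, 0.5, xs)
convolved = run_convolution(0.9, 0.5, xs)
max_gap = max(abs(s - c) for s, c in zip(sequential, convolved))
```

The two lists agree up to floating point error. Because the kernel depends only on A and B and not on the inputs, it can be precomputed once and applied to the whole sequence at training time.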
The problem is that linear hidden state dynamics are kinda boring. For example, you can’t even learn to update your existing hidden state in a different way if you see particular tokens! And indeed, previous results gave scaling laws that were much worse than transformers in terms of performance/training compute.
In Mamba, you basically learn a time varying A and B. The parameterization is a bit wonky here, because of historical reasons, but it goes something like: A_t = exp(-\delta(x_t) * exp(A)), B_t = \delta(x_t) B, so that h_{t+1} = A_t h_t + B_t x_t, where \delta(x_t) = softplus(W_\delta x_t). Also note that in Mamba, they constrain A to be diagonal and W_\delta to be low rank, for computational reasons.
Since exp(A) is diagonal and has only positive entries, we can interpret the model as follows: \delta controls how much to “learn” from the current example. With high \delta, A_t approaches 0 and B_t is large, causing h_{t+1} ~= B_t x_t, while with \delta approaching 0, A_t approaches 1 and B_t approaches 0, meaning h_{t+1} ~= h_t.
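The two limits can be checked numerically with a scalar toy version of the update. This is a sketch of the simplified parameterization described here, not the paper's exact one: it assumes B_t = \delta(x_t) B so that the update reads h_{t+1} = A_t h_t + B_t x_t, and the names `selective_step`, `a_log`, and `w_delta` are made up for illustration.

```python
import math


def softplus(z):
    """Numerically stable softplus: log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def selective_step(h, x, a_log, b, w_delta):
    """One step of h_{t+1} = A_t * h_t + B_t * x_t for a scalar state.

    A_t = exp(-delta * exp(a_log)) and B_t = delta * b, where
    delta = softplus(w_delta * x).
    """
    delta = softplus(w_delta * x)
    a_t = math.exp(-delta * math.exp(a_log))
    b_t = delta * b
    return a_t * h + b_t * x, a_t, b_t


# Large delta: the old state is (almost) forgotten, so h ~= B_t * x_t.
h_forget, a_big, b_big = selective_step(h=5.0, x=1.0, a_log=0.0, b=1.0, w_delta=20.0)
# Tiny delta: the old state is (almost) untouched, so h ~= h_t.
h_keep, a_small, b_small = selective_step(h=5.0, x=1.0, a_log=0.0, b=1.0, w_delta=-20.0)
```

Because \delta depends on x_t, the model gets to choose per token which of these two regimes it is closer to, which is exactly what a fixed A and B cannot do.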
Now, you can’t exactly unroll the hidden state as a convolution with a predefined convolution kernel anymore, but you can still efficiently compute the implied “convolution” using parallel scanning.
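A minimal sketch of why a scan still works: each step is an affine map h → a_t h + b_t (with b_t standing in for B_t x_t), and composing two affine maps gives another affine map, so the composition operator is associative. The version below is a simple Hillis-Steele style scan for a scalar state, chosen for readability; it is not the work-efficient, hardware-aware kernel the paper uses, and all names are hypothetical.

```python
def combine(first, second):
    """Compose two affine maps h -> a*h + b, applying `first` and then `second`."""
    a1, b1 = first
    a2, b2 = second
    return (a2 * a1, a2 * b1 + b2)


def scan_states(pairs):
    """Inclusive scan over (a_t, b_t) pairs using ceil(log2(L)) rounds.

    Within a round every position can be updated independently, which is
    what makes the computation parallelizable on real hardware; here the
    rounds are simply run in a Python loop.
    """
    elems = list(pairs)
    step = 1
    while step < len(elems):
        elems = [
            combine(elems[i - step], elems[i]) if i >= step else elems[i]
            for i in range(len(elems))
        ]
        step *= 2
    # With h_0 = 0, the state after step t is the offset of the composed map.
    return [b for _, b in elems]


def sequential_states(pairs):
    """Reference implementation: plain left-to-right recurrence from h_0 = 0."""
    h, hs = 0.0, []
    for a, b in pairs:
        h = a * h + b
        hs.append(h)
    return hs


pairs = [(0.9, 1.0), (0.1, -2.0), (0.99, 0.5), (0.5, 3.0), (0.7, 0.0)]
scan_gap = max(
    abs(p - s) for p, s in zip(scan_states(pairs), sequential_states(pairs))
)
```

The key point is that `combine` returns something the same size as its inputs (two numbers here, or a diagonal and a vector in the diagonal-A case), so intermediate results stay cheap to store and to compose.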
Despite being much cheaper to sample from, Mamba matches the pretraining flops efficiency of modern transformers (Transformer++ = the current SOTA open source Transformer with RMSNorm, a better learning rate schedule, and corrected AdamW hyperparameters, etc.). And on a toy induction task, it generalizes to much longer sequences than it was trained on.
So, about those capability externalities from mech interp...
Yes, those are the same induction heads from the Anthropic ICL paper!
Like the previous Hippo and Hyena papers, they cite mech interp as one of their inspirations, in that it inspired them to think about what the linear hidden state model could not model and how to fix that. I still don’t think mech interp has that much Shapley here (the idea of studying how models perform toy tasks is not new, and the authors don’t even use the induction metric or the RRT task from the Olsson et al paper), but I’m not super sure on this.
IMO, this line of work is the strongest argument for mech interp (or maybe interp in general) having concrete capabilities externalities. In addition, I think the previous argument Neel and I gave of “these advances are extremely unlikely to improve frontier models” feels substantially weaker now.
Is this a big deal?
I don’t know, tbh.
Another key note about mamba is that despite being RNN-like it doesn’t result in substantially higher effective serial reasoning depth (relative to transformers). This is because the state transition is linear[1]. However, it is architecturally closer to things that might involve effectively higher depth.
See also here.
And indeed, there is a fundamental tradeoff where if the state transition function is expressive (e.g. nonlinear), then it would no longer be possible to use a parallel scan because the intermediates for the scan would be too large to represent compactly or wouldn’t simplify the original functions to reduce computation. You can’t compactly represent f∘g (f composed with g) in a way that makes computing f(g(x)) more efficient for general choices of f and g (in the typical MLP case at least). Another simpler but less illuminating way to put this is that higher serial reasoning depth can’t be parallelized (without imposing some constraints on the serial reasoning).
I mean, yeah, as your footnote says:
Transformers do get more computation per token on longer sequences, but they also don’t get more serial depth, so I’m not sure if this is actually an issue in practice?
As an aside, I actually can’t think of any class of interesting functions with this property (a compact representation of f∘g that makes f(g(x)) cheaper to compute). When reading the paper, the closest I could think of were functions on discrete sets (lol), polynomials (but simplifying these is often more expensive than just computing the terms serially), and rational functions (ditto).
The induction heads extrapolation graph is a bit cherry picked/misleading IMO because it’s still the case that mamba can’t copy an arbitrary amount of text. (It can copy some fixed finite number of tokens if it is predictable that those tokens will be copied over other tokens.)
E.g., if you have N tokens repeated twice, mamba will fail to get perfect loss on the second repetition for sufficiently large N. To see this, note that the total hidden state is bounded so eventually there will be enough text to fill this state. This isn’t true for transformers.
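One way to make the counting argument explicit, under assumptions that are not in the comment itself: suppose the recurrent state consists of d numbers stored with b bits each, and the N tokens are drawn freely from a vocabulary of size V. Then the state can distinguish at most 2^{db} histories, while there are V^N possible sequences to copy:

```latex
\underbrace{d \cdot b}_{\text{bits in the state}}
\;<\;
\underbrace{N \log_2 V}_{\text{bits needed to specify } N \text{ tokens}}
\quad\Longleftrightarrow\quad
N \;>\; \frac{d\, b}{\log_2 V}
```

Past that N, two different first repetitions must land in the same state, so the second repetition cannot be predicted perfectly for both. A full attention cache grows with N and so is not subject to this particular bound.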
It’s unclear how much of an obstacle this is in practice, but it hints at ways in which transformers might be relatively fundamentally more capacity efficient. Note that this issue also applies to some sparse attention mechanisms like sliding window attention, but any attention mechanism that stores state for the entire sequence should be able to avoid this issue.
(This is a well known result, though I forget the current state of empirical results.)
Seems relevant—RNNs are not Transformers (Yet): The Key Bottleneck on In-context Retrieval:
‘Our theoretical analysis reveals that CoT improves RNNs but is insufficient to close the gap with Transformers. A key bottleneck lies in the inability of RNNs to perfectly retrieve information from the context, even with CoT: for several tasks that explicitly or implicitly require this capability, such as associative recall and determining if a graph is a tree, we prove that RNNs are not expressive enough to solve the tasks while Transformers can solve them with ease. Conversely, we prove that adopting techniques to enhance the in-context retrieval capability of RNNs, including Retrieval-Augmented Generation (RAG) and adding a single Transformer layer, can elevate RNNs to be capable of solving all polynomial-time solvable problems with CoT, hence closing the representation gap with Transformers.’
StripedHyena, Griffin, and especially Based suggest that combining RNN-like layers with even tiny sliding window attention might be a robust way of getting a large context, where the RNN-like layers don’t have to be as good as Mamba for the combination to work. There is a great variety of RNN-like blocks that haven’t been evaluated for hybridization with sliding window attention specifically, as in Griffin and Based. Some of them might turn out better than Mamba on scaling laws after hybridization, so Mamba being impressive without hybridization might be less important than this general point.
(Possibly a window of precise attention gives the RNN layers many attempts at both storing and retrieving any given observation, so interspersing layers with even a relatively tiny window is sufficient to significantly improve on the more sloppy RNN-like model without any attention, whereas a pure RNN-like model would have to capture what it needs from a token in the exact step it appears, and then the opportunity is mostly lost. StripedHyena’s attention wasn’t sliding window, so context didn’t scale any better, but with sliding window attention there are no context scaling implications from the attention layers.)
I have found this claim a bit handwavy, as I could imagine state space models being invented and improved to the current stage without the prior work of mech interp. More fundamentally, just “being inspired by” is not a quantitative claim after all, and mech interp is not the central idea here anyway.
On the other hand, though, much of the (shallow) interp can help with capabilities more directly, especially on inference speed. Recent examples I can think of are Attention Sinks, Activation Sparsity, Deja Vu, and several parallel and follow-up works. (Sudden Drops also has some evidence on improving training dynamics using insights from developmental interp, though I think it’s somewhat weak.)
Yeah, “strongest” doesn’t mean “strong” here!
Found this helpful, thanks!