Another key note about Mamba is that despite being RNN-like, it doesn't result in substantially higher effective serial reasoning depth (relative to transformers). This is because the state transition is linear[1]. However, it is architecturally closer to things that might involve effectively higher depth.
And indeed, there is a fundamental tradeoff: if the state transition function is expressive (e.g., nonlinear), then it is no longer possible to use a parallel scan, because the intermediates for the scan would either be too large to represent compactly or wouldn't simplify the original functions enough to reduce computation. You can't compactly represent f∘g (f composed with g) in a way that makes computing f(g(x)) more efficient for general choices of f and g (in the typical MLP case, at least). Another simpler but less illuminating way to put this is that higher serial reasoning depth can't be parallelized (without imposing some constraints on the serial reasoning).

[1] See also here.
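To make the tradeoff concrete, here is a minimal sketch, assuming the scalar/diagonal case h_t = a_t·h_{t-1} + b_t (the names `combine` and `scan_linear_rnn` are made up for illustration, not Mamba's actual implementation): two affine maps compose into a single affine map of the same size, and that compactness is exactly what an associative scan needs.

```python
# Minimal sketch: a linear recurrence h_t = a_t * h_{t-1} + b_t as a parallel
# (associative) scan. Each step is the affine map h -> a*h + b, and composing
# two affine maps yields another affine map with the same (a, b) shape, so
# the scan's intermediates never grow.
import jax
import jax.numpy as jnp

def combine(earlier, later):
    # later(earlier(h)) = a2*(a1*h + b1) + b2 = (a2*a1)*h + (a2*b1 + b2)
    a1, b1 = earlier
    a2, b2 = later
    return a2 * a1, a2 * b1 + b2

def scan_linear_rnn(a, b):
    # Inclusive prefix composition of the per-step affine maps; with h_0 = 0
    # the state at step t is just the composed offset term.
    _, b_pref = jax.lax.associative_scan(combine, (a, b))
    return b_pref  # all h_t, in O(log T) parallel depth

# Check against the obvious sequential recurrence.
a = jax.random.uniform(jax.random.PRNGKey(0), (8,))
b = jax.random.normal(jax.random.PRNGKey(1), (8,))
h, hs = 0.0, []
for a_t, b_t in zip(a, b):
    h = a_t * h + b_t
    hs.append(h)
print(jnp.allclose(scan_linear_rnn(a, b), jnp.array(hs)))  # True

# For a nonlinear transition like h -> w2 @ relu(w1 @ h + x_t), there is no
# same-size representation of the composite map: composing two MLPs just
# gives a deeper MLP. That's the "intermediates too large" problem above.
```

The `combine` step is exactly the compact representation of f∘g that the text above says you can't have for general f and g; affine maps are the special case where you can.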
Another simpler but less illuminating way to put this is that higher serial reasoning depth can’t be parallelized.[1]
[1] Transformers do get more computation per token on longer sequences, but they also don't get more serial depth, so I'm not sure if this is actually an issue in practice?
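For concreteness, a toy back-of-the-envelope version of that footnote's point (the constants and function names are made up, not measured): per-token attention cost grows with position t, while the number of sequential steps per token is pinned at the layer count.

```python
# Toy illustration only: rough per-token attention cost vs. serial depth.
def attn_flops_per_token(t, d_model, n_layers):
    # A token at position t attends over ~t predecessors: roughly t * d_model
    # multiply-adds each for the QK^T scores and the weighted value sum.
    return n_layers * 2 * t * d_model

def serial_depth_per_token(n_layers):
    # Sequential (non-parallelizable) steps are fixed by depth, not by t.
    return n_layers

for t in (128, 1024, 8192):
    print(t, attn_flops_per_token(t, d_model=1024, n_layers=24),
          serial_depth_per_token(n_layers=24))
```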
[C]ompactly represent f∘g (f composed with g) in a way that makes computing f(g(x)) more efficient for general choices of f and g.
As an aside, I actually can't think of any class of interesting functions with this property. When reading the paper, the closest I could think of were functions on discrete sets (lol), polynomials (but simplifying these is often more expensive than just computing the terms serially), and rational functions (ditto).
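The discrete-set case at least is easy to see (a toy sketch of my own, not from the paper): a function on {0, ..., n-1} is just a lookup table, and f∘g collapses to another table of the same size, so composition really is compact there.

```python
# Toy sketch: on a finite set, f∘g collapses to a same-size lookup table,
# so repeated composition (e.g. inside a scan) stays cheap.
import numpy as np

n = 5
rng = np.random.default_rng(0)
f = rng.integers(0, n, size=n)  # table with f[i] == f(i)
g = rng.integers(0, n, size=n)

fg = f[g]  # (f∘g)[i] == f[g[i]], still only n entries

assert all(fg[i] == f[g[i]] for i in range(n))
```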
I mean, yeah, as your footnote says:
Transformers do get more computation per token on longer sequences, but they also don’t get more serial depth, so I’m not sure if this is actually an issue in practice?