There’s an argument that I’ve been thinking about which I’d really like some feedback or pointers to literature on:
The TL;DR is that overcomplete bases necessitate linear representations.
Neural networks use overcomplete bases to represent concepts. Especially in vector spaces without a non-linearity, such as the transformer’s residual stream, far more things are stored than there are dimensions, and as Johnson-Lindenstrauss shows, there are exponentially many (in the dimension) almost-orthogonal directions to store them in. (Of course, we can’t assume that they’re stored linearly as directions, but if they were, there would be plenty of space.) (See also Toy models of transformers and my sparse autoencoder posts.)
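A quick numerical illustration of the "lots of space" point, assuming (hypothetically) that feature directions are random unit vectors. This is a concentration-of-measure demo, not a proof of Johnson-Lindenstrauss: it packs 4× more directions than dimensions and reports the largest pairwise |cosine|, which stays far below 1.

```python
import numpy as np


def max_abs_cosine(n_vectors: int, dim: int, seed: int = 0) -> float:
    """Largest |cos| between any two of n_vectors random unit vectors in R^dim."""
    rng = np.random.default_rng(seed)
    v = rng.standard_normal((n_vectors, dim))
    v /= np.linalg.norm(v, axis=1, keepdims=True)  # normalise each row to unit length
    gram = v @ v.T                                 # all pairwise cosines
    np.fill_diagonal(gram, 0.0)                    # ignore each vector's self-similarity
    return float(np.abs(gram).max())


if __name__ == "__main__":
    dim = 512
    for n in (512, 1024, 2048):  # up to 4x overcomplete
        print(f"{n} directions in {dim} dims: max |cos| = {max_abs_cosine(n, dim):.3f}")
```

Random directions are only one way to pick an overcomplete basis; a trained model could arrange its directions more carefully, which would only lower the worst-case overlap.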
Many different concepts may be active at once, and the model’s ability to read a representation needs to be robust to this kind of interference.
Highly non-linear information storage is going to be very fragile to interference because, by the definition of non-linearity, the model will respond differently to a change in the input depending on the existing level of that feature. For example, if the response is quadratic or higher-order in the feature direction, then the impact of turning that feature on will be very different depending on whether certain not-quite-orthogonal features are also on. If feature spaces are somehow curved, they will be similarly sensitive.
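A toy sketch of the quadratic example, under hypothetical assumptions: random unit feature directions, 20 other features active with random strengths, and a readout that is either linear or quadratic in the projection onto feature 0. It measures how much the *effect of switching feature 0 on* varies across backgrounds. For the linear readout the effect is always exactly 1 (the interference cancels in the difference); for the quadratic readout it is 1 + 2e, where e is the interference from the other active features, so it varies.

```python
import numpy as np


def effect_spread(readout, dim=256, n_features=1024, n_active=20, trials=500, seed=0):
    """Std. dev., across random backgrounds, of the change in readout when feature 0 turns on."""
    rng = np.random.default_rng(seed)
    w = rng.standard_normal((n_features, dim))
    w /= np.linalg.norm(w, axis=1, keepdims=True)  # hypothetical unit feature directions
    target = w[0]
    effects = []
    for _ in range(trials):
        others = rng.choice(np.arange(1, n_features), size=n_active, replace=False)
        background = rng.uniform(0.0, 1.0, size=n_active) @ w[others]  # other active features
        on = background + target                    # feature 0 switched on at strength 1
        effects.append(readout(target, on) - readout(target, background))
    return float(np.std(effects))


def linear(direction, x):
    """Response linear in the feature direction."""
    return float(direction @ x)


def quadratic(direction, x):
    """Response quadratic in the feature direction."""
    return float(direction @ x) ** 2


if __name__ == "__main__":
    print("linear   :", effect_spread(linear))
    print("quadratic:", effect_spread(quadratic))
```

Note that the linear readout's absolute value is still shifted by e; only the effect of toggling the feature is background-independent.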
Of course, linear representations will still be sensitive to this kind of interference, but I suspect there’s a mathematical proof that linear features are the most robust way to represent information in this kind of situation. I’m just not sure where to look for existing work, or how to start trying to prove it.