Reposting from a shortform post, but I’ve been thinking about a possible additional argument for why networks end up with linear representations, and I’d like some feedback on it:
The tl;dr is that overcomplete bases necessitate linear representations.
Neural networks use overcomplete bases to represent concepts. Especially in vector spaces without non-linearity, such as the transformer’s residual stream, there are simply many more things stored in there than there are dimensions, and as the Johnson-Lindenstrauss lemma shows, there are exponentially many almost-orthogonal directions available to store them in (of course, we can’t assume that they’re stored linearly as directions, but if they were then there’s lots of space). (See also Toy models of transformers, sparse coding work.)
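To make the “lots of space” claim concrete, here’s a minimal numpy sketch (the dimension and the number of directions are arbitrary choices for illustration, not anything measured from a real model): random directions in a high-dimensional space are already nearly orthogonal, even when there are several times more of them than there are dimensions.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 512         # toy "residual stream" dimension (arbitrary, for illustration)
n = 4 * d       # more candidate feature directions than dimensions

# Random unit vectors standing in for feature directions.
features = rng.standard_normal((n, d))
features /= np.linalg.norm(features, axis=1, keepdims=True)

# Off-diagonal cosine similarities concentrate near 0 in high dimensions,
# i.e. the n directions are almost orthogonal even though n >> d.
cos = features @ features.T
off_diag = np.abs(cos[~np.eye(n, dtype=bool)])
print(f"{n} directions in {d} dims: mean |cos| = {off_diag.mean():.3f}, "
      f"max |cos| = {off_diag.max():.3f}")
```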
Many different concepts may be active at once, and the model’s ability to read a representation needs to be robust to this kind of interference.
Highly non-linear information storage is going to be very fragile to this interference because, by the definition of non-linearity, the model’s response to a change along a feature direction depends on the existing level of that feature. For example, if the response is quadratic or higher in the feature direction, then the effect of turning that feature on will be very different depending on whether certain not-quite-orthogonal features are also on. If feature spaces are somehow curved, they will be similarly sensitive.
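Here’s a minimal numpy sketch of that fragility (the dimension, the quadratic readout, and the number of interfering features are all illustrative assumptions, not claims about real transformers): the linear readout’s response to switching the feature on is identical no matter how much else is active, while the quadratic readout’s response drifts with the interference.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 512                                   # toy residual-stream dimension (arbitrary)
f = rng.standard_normal(d)
f /= np.linalg.norm(f)                    # the feature direction we want to read out

# Many other almost-orthogonal features that happen to be active at the same time.
others = rng.standard_normal((200, d))
others /= np.linalg.norm(others, axis=1, keepdims=True)

def linear_read(x):
    return f @ x

def quadratic_read(x):
    return (f @ x) ** 2                   # toy stand-in for a readout that is non-linear in the feature direction

for n_active in (0, 50, 200):
    base = others[:n_active].sum(axis=0)  # interference from the other active features
    # Change in each readout caused by turning our feature on:
    lin_delta = linear_read(base + f) - linear_read(base)
    quad_delta = quadratic_read(base + f) - quadratic_read(base)
    print(f"{n_active:3d} interfering features: linear delta = {lin_delta:.2f}, "
          f"quadratic delta = {quad_delta:.2f}")
```

The linear delta stays at exactly 1.0 in every case; the quadratic delta wanders further the more interfering features are switched on.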
Of course, linear representations will still be sensitive to this kind of interference, but I suspect there’s a mathematical proof for why linear features are the most robust way to represent information in this kind of situation. I’m just not sure where to look for existing work or how to start trying to prove it.
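One possible starting point (just a first-order sketch with a fixed readout direction w, nowhere near a proof): the marginal effect of a feature under a linear readout is independent of whatever else is in the residual stream, while any curvature introduces a cross term with the existing state.

```latex
% Linear readout: the effect of adding feature direction f with strength \alpha
% to a state x does not depend on x at all.
r_{\mathrm{lin}}(x) = w^\top x,
\qquad
r_{\mathrm{lin}}(x + \alpha f) - r_{\mathrm{lin}}(x) = \alpha\, w^\top f .

% Quadratic readout: the same perturbation picks up a cross term with the
% existing state, so interference from other active features leaks in.
r_{\mathrm{quad}}(x) = (w^\top x)^2,
\qquad
r_{\mathrm{quad}}(x + \alpha f) - r_{\mathrm{quad}}(x)
  = 2\alpha\,(w^\top x)(w^\top f) + \alpha^2 (w^\top f)^2 .
```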
This is an interesting idea. I feel this also has to be related to increasing linearity with scale and generalization ability: if you have a memorised solution, then nonlinear representations are fine because you can easily tune the ‘boundaries’ of the nonlinear representation to precisely delineate the datapoints (in fact the nonlinearity of the representation can be used to strongly reduce interference when memorising, as is done in the recent research on modern Hopfield networks). On the other hand, if you require a reasonably large-scale smoothness of the solution space, as you would expect from a generalising solution in a flat basin, then this cannot work, and you need to accept interference between nearly orthogonal features as the cost of preserving generalisation of the behaviour across the many different inputs which activate the same vector.
Yes, it makes a lot of sense that linearity would come hand in hand with generalization. I’d recently been reading Krotov on non-linear Hopfield networks but hadn’t made the connection. They say that they’re planning on using them to create more theoretically grounded transformer architectures, and your comment makes me think these wouldn’t succeed, but then the article also says:
This idea has been further extended in 2017 by showing that a careful choice of the activation function can even lead to an exponential memory storage capacity. Importantly, the study also demonstrated that dense associative memory, like the traditional Hopfield network, has large basins of attraction of size O(Nf). This means that the new model continues to benefit from strong associative properties despite the dense packing of memories inside the feature space.
which perhaps corresponds to them also being able to find good linear representations and to mix generalization and memorization like a transformer?