In this toy model, is it really the case that the datapoint feature solutions are “more memorizing, less generalizing” than the axis-aligned feature solutions? I don’t feel totally convinced of this.
Two ways to look at the toy problem:
(1) There are N sparse features, one per input and output channel.
(2) There are T sparse features, one per data point, and each one is active only on its data point. The features are related to the input basis by some matrix M ∈ ℝ^(N×T).
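To make the contrast concrete, here's a minimal NumPy sketch of the two descriptions (the sizes and sparsity level are made up for illustration and are not the Colab's settings). Both stories fit the training set exactly; they only come apart on new data.

```python
import numpy as np

# Sketch of the two descriptions of the same dataset.
# N, T, and p_active are illustrative assumptions, not the Colab's settings.
rng = np.random.default_rng(0)
N, T, p_active = 512, 64, 0.02

# Hypothesis (1): N axis-aligned sparse features; each data point activates a few of them.
sparse_codes = rng.uniform(size=(T, N)) * (rng.random((T, N)) < p_active)
X = sparse_codes @ np.eye(N)                  # feature directions = standard basis

# Hypothesis (2): T "datapoint features", exactly one active per example,
# with directions given by the columns of M in R^(N x T).
M = X.T                                       # column t is the direction of data point t
datapoint_codes = np.eye(T)                   # example t activates only feature t
print(np.allclose(datapoint_codes @ M.T, X))  # True: both hypotheses fit the training data
```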
There are some details of the toy model that put (2) on a “different footing” from (1).
Since the input and output use the same basis, if we make a change of basis, we have to change back again at the end. And because the weights are tied, these two operations have to be transposes, i.e. the change of basis has to be a rotation.
As illustrated in the Colab, requiring the data to be orthonormal is sufficient for this. The experiment constrained the data to unit norm, and unit-norm data is close to orthogonal with high probability when T ≪ N.
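Here's a quick numerical sanity check of that claim (sizes and sparsity are made-up illustrative values, not the Colab's): with very sparse features, two random unit-norm data points rarely share active features, so the Gram matrix of the T data points is close to the identity when T ≪ N, i.e. M has approximately orthonormal columns, as the rotation argument needs.

```python
import numpy as np

# Check that T unit-norm sparse data points in R^N are nearly orthogonal for T << N.
# Sizes and sparsity are illustrative assumptions.
rng = np.random.default_rng(0)
N, T, p_active = 10_000, 100, 0.01

X = rng.uniform(size=(T, N)) * (rng.random((T, N)) < p_active)  # sparse non-negative data
X /= np.linalg.norm(X, axis=1, keepdims=True)                   # unit-norm rows

gram = X @ X.T                                                  # pairwise inner products
max_overlap = np.abs(gram - np.eye(T)).max()
print(f"max off-diagonal |<x_i, x_j>|: {max_overlap:.3f}")      # small, so Gram ≈ identity
```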
Now, it happens that (1) is the true data-generating process, but the model has no way of guessing that. In the finite-data case, the data may be consistent with multiple data-generating processes, and a solution that generalizes well with respect to one of them may generalize poorly with respect to another.
To designate one data-generating process as the relevant one for generalization, we have to make a value judgment about which hypotheses are better, among those that explain the data equally well.
In particular, when T<N, hypothesis (2) seems more parsimonious than hypothesis (1): it explains the data just as well with fewer features! The features aren’t axis-aligned like in (1), but features in real problems won’t be axis-aligned either.
In some sense, it does feel like there's a suspicious lack of generalization in (2): no generalization is made between the training examples, so any knowledge you gain about a feature from seeing one example goes unused on the rest of the training set. But if your dataset is small enough that the data is almost entirely orthogonal, hypothesis (1) has the same problem: the feature weights in each training example have almost no overlap with those in the other examples.
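To illustrate the "no transfer between examples" point, here's a rough sketch that drops the ReLU and bias and treats the datapoint-feature solution as a projection onto the span of the training points (again with made-up sizes): it reproduces the training set essentially perfectly, but captures almost nothing of a fresh sample from the same sparse-feature process.

```python
import numpy as np

# Sketch (ignoring the ReLU and bias): the datapoint-feature solution acts like a
# projection onto the span of the training points, so it reconstructs them exactly
# but transfers almost nothing to a new sample. Sizes/sparsity are illustrative.
rng = np.random.default_rng(0)
N, T, p_active = 10_000, 100, 0.01

def sample(n):
    x = rng.uniform(size=(n, N)) * (rng.random((n, N)) < p_active)
    return x / np.linalg.norm(x, axis=1, keepdims=True)

X_train, x_test = sample(T), sample(1)

P = X_train.T @ np.linalg.pinv(X_train.T)        # projector onto span of the training points
train_err = np.linalg.norm(X_train @ P - X_train) / np.linalg.norm(X_train)
test_err  = np.linalg.norm(x_test @ P - x_test) / np.linalg.norm(x_test)
print(f"relative train error: {train_err:.3f}")  # ~0: the training set is memorized
print(f"relative test  error: {test_err:.3f}")   # ~1: almost nothing carries over to new data
```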
In this toy model, is it really the case that the datapoint feature solutions are “more memorizing, less generalizing” than the axis-aligned feature solutions? I don’t feel totally convinced of this.
Well, empirically in this setup, (1) does generalize and gets a lower test loss than (2). In fact, it's the only version that does better than random. 🙂
But I think what you're maybe saying is that from the neural network's perspective, (2) is a very reasonable hypothesis when T < N, regardless of what is true in this specific setup. And you could perhaps imagine other data-generating processes which would look similar for small datasets but generalize differently. I think there's something to that, and it depends a lot on your intuitions about what natural data is like.
Some important intuitions for me are:
Many natural language features are extremely sparse. For example, it seems likely that LLMs have features for particular people, for particular street intersections, for specific restaurants… Each of these features occurs very, very rarely (many are probably present in fewer than 1 in 10 million tokens).
Simultaneously, there are an enormous number of features (see above!).
While the datasets aren't actually small, repetition effectively puts many data points in the small-data regime (see Adam's repeated data experiment).
Thus, my intuition is that something directionally like this setup (a large number of extremely sparse features), together with studying how representations change with dataset size, is quite relevant. But that's all just based on intuition!
(By the way, I think there is a very deep observation about the duality between (1) and (2) when T < N. See the observations about duality in https://arxiv.org/pdf/2210.16859.pdf.)
Interesting stuff!