On the surprising effectiveness of linear regression as a toy model of generalization.
Another shortform today (since Sunday is the day of rest). This time it’s really a hot take: I’m not confident about the model described here being correct.
Neural networks aren’t linear—that’s the whole point. They notice interesting, compositional, deep information about reality. So when people use linear regression as a qualitative comparison point for behaviors like generalization and learning, I tend to get suspicious. Nevertheless, the track record of linear regression as a model for “qualitative” asymptotic behaviors is hard to deny. Linear regression models (neatly analyzable using random matrix theory) give surprisingly accurate models of double descent, scaling phenomena, etc. (at least when comparing to relatively shallow settings like MNIST or modular addition).
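To make the “surprisingly accurate models of double descent” claim concrete, here is a minimal sketch of the kind of toy experiment this usually refers to: min-norm least squares with a growing number of features, whose test error spikes near the interpolation threshold and then falls again. The data-generating process and all constants below are invented purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_train, n_test, noise, n_trials = 60, 40, 2000, 0.5, 20

def trial(p):
    """Min-norm least squares using only the first p of d coordinates as features."""
    w_true = rng.normal(size=d) / np.sqrt(d)
    X_tr = rng.normal(size=(n_train, d))
    y_tr = X_tr @ w_true + noise * rng.normal(size=n_train)
    X_te = rng.normal(size=(n_test, d))
    y_te = X_te @ w_true + noise * rng.normal(size=n_test)
    w_hat = np.linalg.pinv(X_tr[:, :p]) @ y_tr          # minimum-norm solution
    return np.mean((X_te[:, :p] @ w_hat - y_te) ** 2)

# Test error should rise sharply as p approaches n_train (the interpolation
# threshold) and come back down in the overparameterized regime p > n_train.
for p in [5, 10, 20, 30, 38, 40, 42, 50, 60]:
    print(f"p={p:3d}   mean test MSE = {np.mean([trial(p) for _ in range(n_trials)]):.3f}")
```

The printed curve should peak around p ≈ n_train and recover afterwards, which is the double-descent shape that the random-matrix analyses capture.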
I recently developed a cartoon internal model for why this may be the case. I’m not sure if it’s correct, but I’ll share it here.
The model assumes a few facts about the algorithms implemented by NNs (all of which I believe much more strongly than I believe the comparison to linear regression itself):
The generalized Linear Representation Hypothesis. An NN’s internal computation can be locally factorized into extracting a large collection of low-level features lying in distinct low-dimensional linear subspaces, and then applying (generally nonlinear) postprocessing to these features independently or in small batches. Note that this is much weaker than stronger versions (such as the one inherent in SAEs) that posit 1-dimensional features. In my experience a version of this hypothesis is almost universally believed by engineers, and it also agrees with all known toy algorithms discovered so far.
Smoothness-ish of the data manifold. Inside each low-dimensional “feature subspace”, the data is kinda smooth: locally approximately linear in most directions, and in the directions where it’s not smooth, it still might behave sorta smoothly in aggregate.
Linearity-ish of the classification signal. Even in cases like MNIST or transformer learning where the training data is discrete (and the algorithm is meant to approximate it by a continuous function), there is a sense in which it’s locally well-approximated by a linear function. E.g. perhaps some coarse-graining of the discrete data varies continuously and roughly linearly, or at least the class boundary can be locally well approximated by a linear hyperplane (so that a local linear function can attain 100% accuracy). More generally, we can assume a similar local linearity property of the layer-to-layer forward functions, when restricted to either a single feature space or a small group of interacting feature spaces.
(At least partial) locality of the effect of weight modification. When I read it, this paper left a lasting impression on me. I’m actually not super excited about its main claims (I’ll discuss polytopes later), but a very cool piece of the analysis here is locally modelling ReLU learning as building a convex function as a max of linear functions (and explaining why non-ReLU learning should exhibit a softer version of the same behavior). This is a somewhat “shallow” point of view on learning, but probably captures a nontrivial part of what’s going on, and this predicts that every new weight update only has local effect—i.e., is felt in a significant way only by a small number of datapoints (the idea being that if you’re defining a convex function as the max of a bunch of linear functions, shifting one of the linear functions will only change the values in places where this particular linear function was dominant). The way I think about this phenomenon is that it’s a good model for “local learning”, i.e., learning closer to memorization on the memorization-generalization spectrum that only updates the behavior on a small cluster of similar datapoints (e.g. the LLM circuit that completes “Barack” with “Obama”). There are possibly also more diffuse phenomena (like “understanding logic”, or other forms of grokking “overarching structure”), but most likely both forms of learning occur (and it’s more like a spectrum than a dichotomy). (A small numerical sketch of the “max of affine pieces” locality point appears right after this list.)
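Here is the promised sketch of assumption #4: a toy convex function built as a max of affine pieces (standing in for a small local patch of a ReLU network; the setup and constants are mine, not the paper’s). Nudging a single piece changes the output only on the region where that piece ends up dominant.

```python
import numpy as np

rng = np.random.default_rng(1)

# A convex piecewise-linear function on R: f(x) = max_i (a_i * x + b_i).
n_pieces = 8
a, b = rng.normal(size=n_pieces), rng.normal(size=n_pieces)

x = np.linspace(-5.0, 5.0, 2001)
pieces = a[None, :] * x[:, None] + b[None, :]          # shape (n_inputs, n_pieces)
f_before = pieces.max(axis=1)

# "Weight update": nudge one affine piece upward.  Pick a piece that is actually
# dominant somewhere, so the update has a visible (but local) effect.
k = int(pieces.argmax(axis=1)[len(x) // 2])
b_new = b.copy()
b_new[k] += 0.3
pieces_after = a[None, :] * x[:, None] + b_new[None, :]
f_after = pieces_after.max(axis=1)

changed = ~np.isclose(f_before, f_after)
print("fraction of inputs whose output moved:", round(float(changed.mean()), 3))
print("all of them lie where piece k dominates after the update:",
      bool(np.all(pieces_after.argmax(axis=1)[changed] == k)))
```

The second printed line should be True: the update is only felt on the (typically small) region where the shifted piece is the active one, which is the cartoon of “local learning” above.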
If we buy that these four phenomena occur, or even “occur sometimes” in a way relevant for learning, then it naturally follows that a part of the “shape” of learning and generalization is well described qualitatively by linear regression. Indeed, the model then becomes that (by assumption #4 above) many weight updates exclusively “focus on a single local batch of input points” in some low-dimensional feature manifold. For this particular weight update, locality of the update and smoothness of the data manifold (#2) together imply that we can model it as learning a function on a low-dimensional linear space (since smooth manifolds are locally well-approximated by a linear space). Finally, local linearity of the classification function (#3) implies that we’re learning a locally linear function on this local batch of datapoints. Thus we see that, under this collection of assumptions, the local learning subproblems essentially boil down to linear regression.
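A minimal numerical sketch of that conclusion (the helix manifold and the sine target below are invented for illustration): a single linear regression that is poor globally becomes nearly exact once restricted to a small patch of a curved low-dimensional data manifold.

```python
import numpy as np

rng = np.random.default_rng(2)

# A curved 1-D "data manifold" embedded in R^3 (a helix), carrying a nonlinear
# target that depends only on the manifold coordinate t.
t = rng.uniform(-3.0, 3.0, size=2000)
X = np.stack([np.cos(t), np.sin(t), 0.3 * t], axis=1)
y = np.sin(2.0 * t)

def linear_fit_mse(X, y):
    """Ordinary least squares with an intercept; returns the in-sample MSE."""
    A = np.column_stack([X, np.ones(len(X))])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return np.mean((A @ coef - y) ** 2)

# Globally, one linear regression misses the nonlinearity; restricted to a small
# patch of the manifold, the same model class is nearly exact.
for name, mask in [("global", np.ones(len(t), dtype=bool)),
                   ("local patch", np.abs(t - 1.0) < 0.2)]:
    print(f"{name:12s} linear-fit MSE = {linear_fit_mse(X[mask], y[mask]):.6f}"
          f"   (target variance = {np.var(y[mask]):.6f})")
```

On the local patch the fit error should be a tiny fraction of the target’s variance, while globally it is not: the learning problem only looks like linear regression after zooming in.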
Note that the “low-dimensional feature space” assumption, #1, is necessary for any of this to even make sense. Without it, the whole picture is a non-starter and the other assumptions (#2–#4) don’t make sense, since a sub-exponentially large collection of points on a high-dimensional data manifold with any degree of randomness (something that is true of the data samples in any nontrivial learning problem) will be very far away from each other, and the notion of “locality” becomes meaningless. (Note also that a weaker hypothesis than #1 would suffice—in particular, it’s enough that there are low-dimensional “feature mappings” where some clustering occurs at some layer, and these don’t a priori have to be linear.)
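This distance-concentration point is easy to check numerically. The sketch below (sample sizes are arbitrary) compares the closest pair of points to a typical pair as the ambient dimension grows; once the ratio is near 1, “nearby” points are no closer than random ones and locality carries no information.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 500   # far fewer samples than exponential in the dimension

def pairwise_dists(X):
    """Euclidean distance matrix via the Gram-matrix identity (memory-friendly)."""
    sq = np.sum(X ** 2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * X @ X.T
    return np.sqrt(np.maximum(d2, 0.0))

# Ratio of the closest pair's distance to the mean pairwise distance:
# near 0 means genuine local structure is possible, near 1 means every point is
# roughly equally far from every other point.
for d in [2, 10, 100, 1000]:
    X = rng.uniform(size=(n, d))
    D = pairwise_dists(X)
    off_diag = D[~np.eye(n, dtype=bool)]
    print(f"d={d:5d}   min / mean pairwise distance = {off_diag.min() / off_diag.mean():.3f}")
```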
What is this model predicting? Generally I think abstract models like this aren’t very interesting until they make a falsifiable prediction or at least lead to some qualitative update on the behavior of NNs. I haven’t thought about this very much, and would be excited if others have better ideas or can think of reasons why this model is incorrect. But one thing this model likely predicts is that a better model for a NN than a single linear regression model is a collection of qualitatively different linear regression models at different levels of granularity. In other words, depending on how sloppily you chop your data manifold up into feature subspaces, and how strongly you use the “locality” magnifying glass on each subspace, you’ll get a collection of different linear regression behaviors; you then predict that at every level of granularity, you will observe some combination of linear and nonlinear learning behaviors.
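One way to operationalize “a collection of linear regression models at different levels of granularity” on a toy problem (entirely synthetic; not a claim about how the real experiment should be run): fit one linear regressor per cell of an increasingly fine partition of a low-dimensional feature, and watch how much of the behavior each level of granularity captures.

```python
import numpy as np

rng = np.random.default_rng(4)

# A 1-D "feature" coordinate with a nonlinear target plus noise.
def sample(n):
    s = rng.uniform(-2.0, 2.0, size=n)
    return s, np.sin(3.0 * s) + 0.1 * rng.normal(size=n)

s_tr, y_tr = sample(2000)
s_te, y_te = sample(2000)

def piecewise_linear_test_mse(k):
    """Split the feature range into k bins and fit one linear regressor per bin."""
    edges = np.linspace(-2.0, 2.0, k + 1)
    pred = np.zeros_like(y_te)
    for i in range(k):
        tr = (s_tr >= edges[i]) & (s_tr <= edges[i + 1])
        te = (s_te >= edges[i]) & (s_te <= edges[i + 1])
        A_tr = np.column_stack([s_tr[tr], np.ones(tr.sum())])
        coef, *_ = np.linalg.lstsq(A_tr, y_tr[tr], rcond=None)
        pred[te] = np.column_stack([s_te[te], np.ones(te.sum())]) @ coef
    return np.mean((pred - y_te) ** 2)

# One global linear regression vs. collections of local ones at finer granularity.
for k in [1, 2, 4, 8, 16, 32]:
    print(f"{k:3d} local linear model(s)   test MSE = {piecewise_linear_test_mse(k):.4f}")
```

At coarse granularity the residual nonlinearity dominates; at fine granularity the local linear pieces approach the noise floor. The (hedged) prediction above is that something with this flavor, with different linear behaviors coexisting at different scales, should show up when you probe a real network this way.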
This point of view makes me excited about work by Ari Brill (which, as far as I know, is unpublished—I heard a talk on it at the ILIAD conference—see the Saturday schedule, first talk in Bayes). If I understood the talk correctly, he models a data manifold as a certain stochastic fractal in a low-dimensional space and makes scaling predictions about generalization behavior depending on properties of the fractal, by thinking of the fractal as a hierarchy of smooth but noisy features. Finding similarly-flavored scaling behavior for “linear regression subphenomena” in a real-life machine learning problem would positively update me toward the model above being correct.
Ari’s work is on arXiv here
Loving this!
Epic.
A couple of things that come to mind:
Linear features = sufficient statistics of exponential families?
The simplest case is that of Gaussians and their covariance matrix (which comes down to linear regression; a short derivation is sketched after this list).
Formalized by the GPD theorem.
See the generalization by John.
Exponential families are a fairly good class, but they are not closed under hierarchical structure. A basic example: a mixture of Gaussians is not an exponential family, i.e. it is not described in terms of just linear regression.
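To spell out the Gaussian case mentioned above (a standard fact, included as a sketch of what “comes down to linear regression” means here): for a jointly Gaussian pair, conditioning is exactly linear regression.

```latex
% Jointly Gaussian (X, Y) with means \mu_X, \mu_Y and covariance blocks
% \Sigma_{XX}, \Sigma_{XY} = \Sigma_{YX}^\top, \Sigma_{YY}:
\mathbb{E}[\,Y \mid X = x\,] = \mu_Y + \Sigma_{YX}\,\Sigma_{XX}^{-1}\,(x - \mu_X),
\qquad
\operatorname{Var}(Y \mid X = x) = \Sigma_{YY} - \Sigma_{YX}\,\Sigma_{XX}^{-1}\,\Sigma_{XY}.
```

The coefficient Σ_YX Σ_XX^{-1} is the population least-squares coefficient of Y on X, and the sufficient statistics of the Gaussian family are just first and second moments, i.e. linear and quadratic features; mixtures of Gaussians break exactly this structure, as the previous point notes.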
The centrality of ReLU neural networks.
Understanding ReLU neural networks is probably 80-90% of understanding NN architectures. At sufficient scale, pure MLPs have scaling laws as good as or better than transformers.
There are several lines of evidence that gradient descent has an inherent bias towards splines / piecewise-linear functions / tropical polynomials; see e.g. here and references therein.
Serious analysis of ReLU neural networks can be done through tropical methods. A key paper is here (a tiny numerical illustration of the max-of-affine-functions identity appears at the end of this comment). You say:
“very cool piece of the analysis here is locally modelling ReLU learning as building a convex function as a max of linear functions (and explaining why non-ReLU learning should exhibit a softer version of the same behavior). This is a somewhat “shallow” point of view on learning, but probably captures a nontrivial part of what’s going on, and this predicts that every new weight update only has local effect—i.e., is felt in a significant way only by a small number of datapoints (the idea being that if you’re defining a convex function as the max of a bunch of linear functions, shifting one of the linear functions will only change the values in places where this particular linear function was dominant). The way I think about this phenomenon is that it’s a good model for “local learning”, i.e., learning closer to memorization on the memorization-generalization spectrum that only updates the behavior on a small cluster of similar datapoints (e.g. the LLM circuit that completes “Barack” with “Obama”).”
I suspect the notions one should be looking at are the activation polytope and the activation fan in section 5 of the paper. The hypothesis would be something about efficiently learnable features having a ‘locality’ constraint on these activation polytopes, i.e. they are ‘small’, ‘active on only a few data points’.
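As a tiny, self-contained illustration of the tropical viewpoint (my own toy check; it does not touch the paper’s activation polytopes or fans): a one-hidden-layer ReLU network with nonnegative output weights is exactly a maximum of 2^h affine functions, one per on/off pattern of the hidden units, i.e. a tropical polynomial.

```python
import numpy as np
from itertools import product

rng = np.random.default_rng(5)

# One-hidden-layer ReLU net with scalar input/output and nonnegative output weights:
#   f(x) = sum_j c_j * relu(w_j * x + b_j),   c_j >= 0, so f is convex.
h = 3
w, b = rng.normal(size=h), rng.normal(size=h)
c = np.abs(rng.normal(size=h))

def f(x):
    return np.sum(c * np.maximum(w * x[..., None] + b, 0.0), axis=-1)

def f_tropical(x):
    # Since c_j >= 0, c_j * relu(z_j) = max(c_j * z_j, 0); a sum of such maxima
    # equals the maximum, over all 2^h on/off patterns of the hidden units, of the
    # corresponding affine functions.
    vals = []
    for mask in product([0.0, 1.0], repeat=h):
        m = np.array(mask)
        vals.append(np.sum(m * c * w) * x + np.sum(m * c * b))
    return np.max(vals, axis=0)

xs = np.linspace(-5.0, 5.0, 1001)
print("max |net - tropical polynomial| on a grid:",
      float(np.max(np.abs(f(xs) - f_tropical(xs)))))
```

With mixed-sign output weights the same bookkeeping gives a difference of two such maxima, i.e. a tropical rational function, which is the general setting the tropical-geometry papers work in.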