This does a great job of importing and translating a set of intuitions from a much more established and rigorous field. However, as with all works framing deep learning as a particular instance of some well-studied problem, it’s vital to keep the context in mind:
Despite literally thousands of papers claiming to “understand deep learning” from experts in fields as various as computational complexity, compressed sensing, causal inference, and—yes—statistical learning, NO rigorous, first-principles analysis has ever computed any aspect of any deep learning model beyond toy settings. ALL published bounds are vacuous in practice.
It’s worth exploring why, despite strong results in their own settings, and despite strong “intuitive” parallels to deep learning, this remains true. The issue is that all these intuitive arguments have holes big enough to accommodate, well, the end of the world. There are several such challenges in establishing a tight correspondence between “classical machine learning” and deep learning, but I’ll focus on one that has absorbed considerable effort: defining simplicity.
This notion is essential. If we consider a truly arbitrary function, there need be no relationship between its behavior on one input and its behavior on another—this is the No Free Lunch Theorem. If we want our theory to have content (that is, to constrain the behavior of a Deep Learning system at all), we’ll need to narrow the range of possibilities. Tools from statistical learning like the VC dimension are useless as-is due to overparameterization, as you mention. We’ll need a notion of simplicity that captures what sorts of computational structures SGD finds in practice. Maybe circuit size, or minima sharpness, or noise sensitivity… how hard could it be?
Well no one’s managed it. To help understand why, here are two (of many) barrier cases:
Sparse parity with noise. For an input bitstring x, y is defined as the xor of a fixed, small subset of indices. E.g., if the indices are 1, 3, 9 and x is 101000001, then y = 1 xor 1 xor 1 = 1. Some small (tending to zero) rate of measurement error is assumed. Though this problem is approximated almost perfectly by extremely small and simple boolean circuits (a log-depth tree of xor gates with inputs on the chosen subset), it is believed to require an exponential amount of computation to predict even marginally better than random! Neural networks require exponential size to learn it in practice.
Deep Learning Fails
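For concreteness, here’s a minimal sketch of a data generator for this problem; the secret index set, noise rate, and names are illustrative choices of mine, not anything from the post:

```python
import numpy as np

rng = np.random.default_rng(0)
n_bits = 20
secret = [0, 2, 8]        # xor of bits 1, 3, 9 in 1-indexed terms
noise_rate = 0.01         # small measurement error on the label

def sample(batch_size):
    X = rng.integers(0, 2, size=(batch_size, n_bits))   # uniform random bitstrings
    y = X[:, secret].sum(axis=1) % 2                     # parity of the secret subset
    flip = rng.random(batch_size) < noise_rate           # occasionally flip the label
    return X, y ^ flip

X, y = sample(4)
# A log-depth tree of xor gates over the three secret wires computes y exactly,
# yet recovering `secret` from (X, y) samples alone is believed to be very hard.
```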
The protein folding problem. Predict the shape of a protein from its amino acid sequence. Hundreds of scientists spent decades searching for regularities, and failed. Generations of supercomputers were built to simulate the subtle evolution of molecular structure, and failed. Billions of pharmaceutical dollars were invested—hundreds of billions were on the table for success. The data is noisy and multi-modal. Protein language models learn it all anyway.
Deep Learning Succeeds!
For what notion is the first problem complicated, and the second simple?
Again, without such a notion, statistical learning theory makes no prediction whatsoever about the behavior of DL systems on new examples. If a model someday produced a sequence of actions which caused the extinction of the human race, we couldn’t object on principle, only say “so power-seeking was simpler after all”. And even with such a notion, we’d still have to prove that Gradient Descent tends to find it in practice, along with a dozen other difficulties...
Without a precise mathematical framework to which we can defer, we’re left with Empirics to help us choose between a bunch of sloppy, spineless sets of intuitions. Much less pleasant. Still, here are a few observations which push me towards Deep Learning as a “computationally general, pattern-finding process” rather than function approximation:
Neural networks optimized only for performance show surprising alignment with representations in the human brain, even exhibiting one-to-one matches between particular neurons in ANNs and neurons in living humans. This is an absolutely unprecedented level of predictivity, despite the models not being designed for this and taking no brain data as input.
LLMs have been found to contain rich internal structure, such as grammatical parse trees, inference-time linear models, and world models. This sort of mechanistic picture is missing from any theory that considers only input/output behavior.
Small changes in loss (i.e. function approximation accuracy) have been associated with large qualitative changes in ability and behavior—such as learning to control robotic manipulators using code, productively recurse to subagents, use tools, or solve theory of mind tasks.
I know I’ve written a lot, so I appreciate your reading it. To sum up:
Despite intuitive links, efforts to apply statistical learning theory to deep learning have failed, and seem to face substantial difficulties.
So we have to resort to experiment, where I feel this intuitive story doesn’t fit the data, and I’ve provided some challenge cases.
This is false. From the abstract of Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer:
Hyperparameter (HP) tuning in deep learning is an expensive process, prohibitively so for neural networks (NNs) with billions of parameters. We show that, in the recently discovered Maximal Update Parametrization (muP), many optimal HPs remain stable even as model size changes. This leads to a new HP tuning paradigm we call muTransfer: parametrize the target model in muP, tune the HP indirectly on a smaller model, and zero-shot transfer them to the full-sized model, i.e., without directly tuning the latter at all. We verify muTransfer on Transformer and ResNet. For example, 1) by transferring pretraining HPs from a model of 13M parameters, we outperform published numbers of BERT-large (350M parameters), with a total tuning cost equivalent to pretraining BERT-large once; 2) by transferring from 40M parameters, we outperform published numbers of the 6.7B GPT-3 model, with tuning cost only 7% of total pretraining cost. A Pytorch implementation of our technique can be found at this http URL and installable via `pip install mup`.
muP comes from a principled mathematical analysis of how different ways of scaling various architectural hyperparameters alongside model width influence activation statistics.
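To make the recipe in that abstract concrete, here is a rough, hypothetical sketch of the muTransfer workflow; `build_model`, `tune_hyperparams`, and `train` are stand-in functions of mine, not the actual `mup` API (which provides its own helpers for setting up the parametrization):

```python
# Hypothetical sketch of the muTransfer recipe from the abstract.
# `build_model`, `tune_hyperparams`, and `train` are stand-ins, not the real `mup` API.

def mu_transfer(build_model, tune_hyperparams, train,
                proxy_width=256, target_width=8192):
    # 1. Parametrize a small proxy model in muP (in practice the `mup` package
    #    sets this up; here it is abstracted into `parametrization="muP"`).
    proxy = build_model(width=proxy_width, parametrization="muP")
    # 2. Tune HPs (learning rate, init scale, ...) on the cheap proxy.
    best_hps = tune_hyperparams(proxy)
    # 3. Build the full-sized model in muP and reuse the proxy's HPs zero-shot,
    #    i.e. without tuning the large model directly.
    target = build_model(width=target_width, parametrization="muP")
    return train(target, **best_hps)
```

The point is only the shape of the procedure: all tuning happens on the proxy, and muP is what makes the optima line up across widths.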
I was trying to make a more specific point. Let me know if you think the distinction is meaningful -
So there are lots of “semi-rigorous” successes in Deep Learning. One I understand better than muP is good old Xavier initialization. Assuming that the activations at a given layer are normally distributed, we should scale our weights like 1/sqrt(n) so the activations don’t diverge from layer to layer (since a sum of n independent normals scales like sqrt(n)). This is exactly true for the first gradient step, but can become false at any later step once the weights are no longer independent and can “conspire” to blow up the activations anyway. So it’s not a “proof”, but it’s successful in practice.
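Here’s a tiny numerical check of that argument; the depth, width, and use of plain numpy are my own illustrative choices:

```python
import numpy as np

# Toy check of the 1/sqrt(n) argument above: feed a standard-normal vector
# through many linear layers and compare the resulting activation scales.
rng = np.random.default_rng(0)
n, depth = 512, 20
x = rng.standard_normal(n)

for std, label in [(1.0, "weights ~ N(0, 1)"), (1.0 / np.sqrt(n), "weights ~ N(0, 1/n)")]:
    h = x.copy()
    for _ in range(depth):
        W = rng.standard_normal((n, n)) * std   # fresh i.i.d. weights, as at initialization
        h = W @ h                               # linear layer, no nonlinearity
    print(f"{label}: activation std after {depth} layers ~ {h.std():.3g}")

# With std 1 the scale grows by roughly sqrt(n) per layer; with std 1/sqrt(n) it stays O(1).
# The argument relies on the weights being independent of the activations,
# which is exactly the assumption that can break once training begins.
```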
My understanding of muP is similar, in that it is “precise” only if certain correlations are well controlled (I’m fuzzy here). But it’s still very successful in practice. The proof isn’t airtight, though, and we still need that last step—checking “in practice”.
This is very different from the situation within statistical learning itself, which has many beautiful and unconditional results that we would very much like to port over to deep learning. My central point is that in the absence of a formal correspondence, we have to bridge the gap with evidence. That’s why the last part of my comment was some evidence I think speaks against the intuitions of statistical learning theory. I consider Xavier initialization, muP, scaling laws, etc. to be examples where this bridge was successfully crossed—but the bridge was still necessary! And so we’re reduced to “arguing over evidence between paradigms” when we’d prefer to “prove results within a paradigm”.
For what notion is the first problem complicated, and the second simple?
I might be out of my depth here, but—could it be that sparse parity with noise is just objectively “harder than it sounds” (because every bit of noise inverts the answer), whereas protein folding is “easier than it sounds” (because if it weren’t, evolution wouldn’t have solved it)?
Just because the log-depth xor tree is small, doesn’t mean it needs to be easy to find, if it can hide amongst vastly many others that might have generated the same evidence … which I suppose is your point. (The “function approximation” frame encourages us to look at the boolean circuit and say, “What a simple function, shouldn’t be hard to noisily approximate”, which is not exactly the right question to be asking.)
I think it’s also a question of the learning techniques used. It seems like generalizable solutions to xor involve at least one of the two following:
noticing that many of the variables in the input can be permuted without changing the output and therefore it’s only the counts of the variables that matter
noticing how the output changes or doesn’t change when you start with one input and flip the variables one at a time
But current neural networks can’t use either of these techniques, partly because they don’t align well with the training paradigm. Both kind of require the network to be able to pick the inputs it sees the ground truth for, whereas training on a distribution hands it randomized (input, output) pairs.
To “humanize” these problems, we can:
Break intuitive permutation invariance by using different symbols in each slot. So instead of an input looking like 010110 or 101001 or 111000, it might look like A😈4z👞i or B🔥😱Z😅0 or B😈😱Z😅i.
Break the ability to notice the effects of single bitflips by presenting only random strings rather than neighboring strings.
This makes it intuitively much harder to me.
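Here’s a small sketch of both transformations together; the symbol pools and dimensions are arbitrary choices of mine:

```python
import random

# Sketch of the two "humanizing" transformations above; symbol pools are arbitrary.
random.seed(0)
n_bits = 6
# Each slot gets its own pair of symbols, hiding the permutation symmetry.
slot_symbols = [random.sample("ABZz04i😈🔥😱😅👞", 2) for _ in range(n_bits)]

def humanize(bits):
    return "".join(slot_symbols[i][b] for i, b in enumerate(bits))

# Show i.i.d. random strings only, so single-bitflip neighbors are never available.
examples = [[random.randint(0, 1) for _ in range(n_bits)] for _ in range(3)]
print([humanize(b) for b in examples])
```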