More Recent Progress in the Theory of Neural Networks
Thanks to Dan Roberts and Sho Yaida for comments on a draft of this post.
In this post, I would like to draw attention to the book Principles of Deep Learning Theory (PDLT), which I think represents a significant advance in our understanding of how neural networks work [1]. Among other things, this book explains how to write a closed-form formula for the function learned by a realistic, finite-width neural network at the end of training [2] to an order of approximation that suffices to describe representation learning, and how that formula can be interpreted as the solution to a regression model. This makes manifest the intuition that NNs are doing something like regression, but where they learn the features appropriate for a given dataset rather than having them be hand-engineered from the start.
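To make the regression analogy concrete, here is a minimal sketch of its leading-order (infinite-width, NTK) version. This is the baseline PDLT goes beyond, not the book's next-to-leading-order formula. A network linearized around its initialization is a linear model whose fixed features are the parameter gradients at initialization. Trained to convergence on a mean-squared-error loss, it predicts f(x) = f_0(x) + Θ(x, X) Θ(X, X)^{-1} (y - f_0(X)), i.e. kernel regression with the neural tangent kernel Θ. The toy network, the tanh activation, the data, and all function names are my own hypothetical choices.

```python
import numpy as np

# Hypothetical toy model (not from PDLT): a one-hidden-layer tanh network in NTK
# parametrization, f(x) = sum_i a_i * tanh(w_i . x / sqrt(d)) / sqrt(n),
# with a_i, w_ij ~ N(0, 1) at initialization.

def init_params(d, n, rng):
    return rng.standard_normal((n, d)), rng.standard_normal(n)

def forward(W, a, X):
    n, d = W.shape
    H = np.tanh(X @ W.T / np.sqrt(d))  # hidden activations, shape (num_points, n)
    return H @ a / np.sqrt(n), H

def empirical_ntk(W, a, X1, X2):
    """Theta(x, x') = grad_theta f(x) . grad_theta f(x'), summed over all of a and W."""
    n, d = W.shape
    _, H1 = forward(W, a, X1)
    _, H2 = forward(W, a, X2)
    G1 = (1 - H1**2) * a / np.sqrt(n)  # df/dz_i, where z_i = w_i . x / sqrt(d)
    G2 = (1 - H2**2) * a / np.sqrt(n)
    return H1 @ H2.T / n + (G1 @ G2.T) * (X1 @ X2.T) / d

def linearized_fit_predict(W, a, X_train, y_train, X_test):
    """End-of-training prediction of the linearized model under MSE loss (kernel regression)."""
    f0_train, _ = forward(W, a, X_train)
    f0_test, _ = forward(W, a, X_test)
    K = empirical_ntk(W, a, X_train, X_train)
    k = empirical_ntk(W, a, X_test, X_train)
    return f0_test + k @ np.linalg.solve(K, y_train - f0_train)

rng = np.random.default_rng(0)
X = rng.standard_normal((20, 3))  # made-up inputs
y = np.sin(X[:, 0])               # made-up targets
W, a = init_params(d=3, n=512, rng=rng)
pred = linearized_fit_predict(W, a, X, y, X)
print("max training error:", np.max(np.abs(pred - y)))
```

At this order the features never change during training. As I understand the book, the 1/width corrections are what let the effective features themselves adapt to the dataset.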
I’ve condensed some main points from the 400-page book into an 8-page summary here:
Review of select results from PDLT
(Other good places to learn about the book, though perhaps with less of a focus on AI-safety-relevant parts, include this series of five lectures given by the authors at a deep learning summer school or this one-hour lecture for a non-expert audience.)
For those who have been following the discussions of ML theory on this forum, the method used in the book is to go to the next-to-leading order in a 1/width expansion. It thus builds on recent studies of infinitely wide NNs that were reviewed in the AF post Recent Progress in the Theory of Neural Networks [3]. However, by going beyond the leading order, the authors of PDLT are able to get around a key qualitative shortcoming of the earlier work, namely that infinitely wide NNs can’t learn features. The next-to-leading-order formula also introduces a sum over many steps of gradient descent. This gets around an objection [4] that the NTK/infinite-width limit may not be applicable to realistic models, since in that limit we can land on the fully trained model after just one fine-tuned training step.
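The claim that infinitely wide networks can't learn features can be probed numerically in a toy setting. Train the same kind of network at several widths and measure how much its empirical NTK (its effective feature kernel) moves during training. The sketch below assumes a hypothetical two-layer tanh network in NTK parametrization, full-batch gradient descent on a mean-squared-error loss, and made-up data. The relative change should shrink as the width grows, which is the leading-order statement. The nonzero change at finite width is the kind of effect the 1/width corrections are meant to capture.

```python
import numpy as np

# Hypothetical toy setup (not from PDLT): a two-layer tanh network in NTK parametrization,
# f(x) = sum_i a_i * tanh(w_i . x / sqrt(d)) / sqrt(n), with a_i, w_ij ~ N(0, 1) at init.

def forward(W, a, X):
    n, d = W.shape
    H = np.tanh(X @ W.T / np.sqrt(d))  # hidden activations, shape (num_points, n)
    return H @ a / np.sqrt(n), H

def empirical_ntk(W, a, X):
    """Empirical NTK on the inputs X: sum over parameters of df(x)/dtheta * df(x')/dtheta."""
    n, d = W.shape
    _, H = forward(W, a, X)
    G = (1 - H**2) * a / np.sqrt(n)  # df/dz_i, where z_i = w_i . x / sqrt(d)
    return H @ H.T / n + (G @ G.T) * (X @ X.T) / d

def relative_ntk_change(n, steps=200, lr=0.5, seed=0):
    """Full-batch GD on L = sum (f - y)^2 / (2N); returns ||Theta_T - Theta_0||_F / ||Theta_0||_F."""
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((16, 3))  # made-up data; the same X for every width given the seed
    y = np.sin(X[:, 0])
    num_points, d = X.shape
    W, a = rng.standard_normal((n, d)), rng.standard_normal(n)
    K0 = empirical_ntk(W, a, X)
    for _ in range(steps):
        f, H = forward(W, a, X)
        r = (f - y) / num_points  # dL/df
        grad_a = H.T @ r / np.sqrt(n)
        grad_W = ((r[:, None] * (1 - H**2)) * a).T @ X / np.sqrt(n * d)
        a, W = a - lr * grad_a, W - lr * grad_W  # both gradients use the old parameters
    return np.linalg.norm(empirical_ntk(W, a, X) - K0) / np.linalg.norm(K0)

for n in (16, 256, 4096):
    print(f"width {n}: relative NTK change {relative_ntk_change(n):.4f}")
```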
I think that this work could have significant implications for AGI forecasting and safety (via interpretability), and deserves to be better appreciated in this community. For example:
In AGI forecasting, an important open question is whether the strong scaling hypothesis holds for any modern architectures. (For example, the forecasts in Ajeya Cotra’s Bio-Anchors report are conditioned on assuming that 2020 algorithms can scale to TAI.) A longstanding challenge for this field is that as long as we treat neural networks as black boxes or random program search, it’s hard to reason about this question in a principled way. But I think that by identifying a space of functions that realistic NNs end up learning in practice (<< the space of all neural networks with finely-tuned weights!), the approach of PDLT gives us a way to start reasoning about it. In particular, despite the existence of the universal approximation theorem, I think the results of PDLT can be used to rule out the (strawmannish) hypothesis that feedforward MLPs can scale to AGI (see my review of the Bio-Anchors report for more on this point). So it could be really interesting to generalize PDLT to other architectures.
In mechanistic interpretability, a basic open question is what the fundamental degrees of freedom are that we should be trying to interpret. A lot of work has been done under the assumption that we should look at the activations of individual neurons. But naively there’s no reason that semantically meaningful properties of a dataset must align with individual neurons after training, and there are even some interesting counterexamples [5]. By finding a dual description of a trained NN as a trained regression model, PDLT seems to hint that a (related, but) different set of degrees of freedom, the effective features (32) in the above-linked note, may be more natural objects to look at. It would be really interesting to see if this turns out to be the case [6].
More generally, a dream for interpretability research would be to reverse-engineer our future AI systems into human-understandable code. If we take this dream seriously, it may be helpful to split it into two parts. The first is understanding what “programming language” an architecture + learning algorithm will end up using at the end of training. The second is what “program” a particular training regimen will lead to in that language [7]. It seems to me that by focusing on specific trained models, most interpretability research discussed here is of the second type. But by constructing an effective theory for an entire class of architectures that’s agnostic to the choice of dataset, PDLT is a rare example of the first type. So trying to develop it further and/or generalize it to new architectures as they come along could be not only useful but also complementary to other agendas.
[1] See here for an earlier, though shorter, discussion of this book on LW.
[2] As a function of its architecture + weight initialization + training set + learning algorithm. The formula is (.154) on page 1 of the note that I link to below.
[3] See also yesterday’s post Neural Tangent Kernel Distillation.
[4] As raised e.g. by Paul Christiano here.
[5] See the recent Anthropic paper Toy Models of Superposition for a discussion of this and related issues.
[6] For example, one could try to see which dataset examples or synthetic data would maximally activate the effective features, generalizing the experiments done on neurons in the Circuits thread. (However, a caveat for this project idea is that there are a huge number of effective features in PDLT, scaling as the number of weights instead of the number of neurons! So one might have to start with a really small toy model, or be clever about picking a subset / combination of features to visualize.)
It could also be interesting to understand how the PDLT effective theory fits conceptually with other ideas from the interpretability and broader “science of ML” literature. For example, how should we think of compositional circuits in the dual frame, which seems to put all effective features on the same footing instead of having some be built from others? (Perhaps this isn’t a meaningful question to ask since the compositional circuits in vision models are themselves a fuzzy emergent description. But in that case, can we generalize PDLT to transformers, and then understand crisp emergent circuits like modular arithmetic circuits in the dual frame?) Or since the dual model has a huge number of effective features, might there be a lottery ticket hypothesis at the level of the features?
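To get a feel for the gap between the number of neurons and the number of weights mentioned above, here is the arithmetic for a small fully connected network. The layer sizes are hypothetical. Treating the parameter count (weights plus biases) as a proxy for the number of effective features is a rough assumption based on the scaling stated in the footnote.

```python
def count_neurons_and_parameters(widths):
    """widths = [input_dim, hidden_1, ..., hidden_k, output_dim] for a fully connected MLP."""
    hidden_neurons = sum(widths[1:-1])
    # weights plus biases for every layer
    parameters = sum(n_in * n_out + n_out for n_in, n_out in zip(widths, widths[1:]))
    return hidden_neurons, parameters

# hypothetical toy MLP: 784 inputs, two hidden layers of width 128, 10 outputs
neurons, params = count_neurons_and_parameters([784, 128, 128, 10])
print(neurons, params)  # 256 118282
```

Even for this small model, the parameter count is several hundred times the number of hidden neurons.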
[7] Or to put it another way, if we want to understand what cognition a trained AI system is performing at inference time, there’s both a kinematic aspect of claiming that some degrees of freedom in the system (or the way that they activate or something) approximately encode some human-understandable concepts in a Platonic latent space (e.g. paperclips), and a dynamical aspect of how those concepts get put together (e.g. into a plan to turn us into paperclips). The space of allowed dynamics is what I’m calling the “programming language” per Chris’s metaphor.
Sho and I want to thank jylin04 for this really nice post and endorse the distillation of our key results in her 8-page summary. We also agree that it would be interesting to make further connections between our work—in particular the effective theory framework—and interpretability, and we’d be really glad to explore and discuss that further.
The book’s results hold for a specific kind of neural network training parametrisation, the “NTK parametrisation”, which has been argued (convincingly, to me) to be rather suboptimal. With different parametrisation schemes, neural networks learn features even in the infinite-width limit.
You can show that neural network parametrisations can essentially be classified into those that will learn features in the infinite width limit, and those that will converge to some trivial kernel. One can then derive a “maximal update parametrisation”, in which infinite width networks actually train better than finite width networks (including finite width networks using NTK parametrisation, IIRC).
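Here is a toy way to see this classification. Take one gradient step on a single example with a two-layer network and measure how far the hidden preactivations move. Below, "ntk" scales the output by 1/√n with a fixed learning rate. "mean_field" instead scales it by 1/n with a learning rate proportional to n. For two-layer networks, the latter is, as far as I understand, what maximal update parametrisation reduces to, but treat that identification, and every other choice here, as an assumption of the sketch. In the first case the per-unit movement should fall off with width; in the second it should stay of order one. Since depth is fixed at two, the sketch does not probe limits in which depth grows with width.

```python
import numpy as np

def first_step_feature_shift(n, parametrization, lr=1.0, seed=0):
    """Mean |change| of the hidden preactivations z_i = w_i . x / sqrt(d) after one GD step
    on a single example (only the first-layer update affects z), divided by |f - y| so that
    the random value of f at initialization matters less.

    Hypothetical toy choices (a_i, w_ij ~ N(0, 1) in both cases):
      "ntk":        f = sum_i a_i tanh(z_i) / sqrt(n), learning rate lr
      "mean_field": f = sum_i a_i tanh(z_i) / n,       learning rate lr * n
    """
    if parametrization == "ntk":
        scale, eta = 1 / np.sqrt(n), lr
    elif parametrization == "mean_field":
        scale, eta = 1 / n, lr * n
    else:
        raise ValueError(f"unknown parametrization: {parametrization}")
    rng = np.random.default_rng(seed)
    d = 3
    x = np.ones(d)  # a single fixed input with |x|^2 / d = 1
    y = 1.0
    W, a = rng.standard_normal((n, d)), rng.standard_normal(n)
    z = W @ x / np.sqrt(d)
    r = scale * a @ np.tanh(z) - y  # dL/df for L = (f - y)^2 / 2
    grad_W = np.outer(r * scale * a * (1 - np.tanh(z) ** 2), x) / np.sqrt(d)
    z_new = (W - eta * grad_W) @ x / np.sqrt(d)
    return np.mean(np.abs(z_new - z)) / abs(r)

for n in (64, 1024, 16384):
    print(f"width {n}: ntk {first_step_feature_shift(n, 'ntk'):.4f}, "
          f"mean_field {first_step_feature_shift(n, 'mean_field'):.4f}")
```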
Maximal update parametrisation also allows you to derive correct hyperparameters for a model. You can just find the right hyperparameters for a small model, which is far less costly, then scale up. If you used maximal update parametrisation, the results will still be right for the large model.
So I wouldn’t use these results as evidence for or against any scaling hypotheses, and I somewhat doubt they make a good jumping-off point for figuring out how to disentangle a neural network into parts. I don’t think this analysis is really doing anything more fundamental than investigating how a certain incorrect choice of parameter scaling harms you more, the smaller you make the depth-to-width ratio of your network.
For more on this, I recommend the Tensor Programs papers and related talks by Greg Yang.
TL;DR: I think these results only really concern networks trained in a way that’s shooting yourself in the foot.
Thank you for the comment! Let me reply to your specific points.
First, and as a TL;DR: whether NTK parameterization is “right” or “wrong” is perhaps an issue of prescriptivism vs. descriptivism. Regardless of which one is “better”, the NTK parameterization is (close to what is) commonly used in practice, so if you’re interested in modeling what practitioners do, it’s a very useful setting to study. Additionally, one disadvantage of maximal update parameterization from the point of view of interpretability is that it’s in the strong-coupling regime, where many of the nice tools we use in our book (e.g., to write down the solution at the end of training) cannot be applied. So perhaps if your interest is safety, you’d be shooting yourself in the foot if you use maximal update parameterization! :)
Second, it is a common misconception that the NTK parameterization cannot learn features and that maximal update parameterization is the only parameterization that learns features. As discussed in the post above, all networks in practice have finite width; the infinite-width limit is a formal idealization. At finite width, either parameterization learns features. Moreover, in the formal infinite-width limit, it is true that *infinite-width with fixed depth* doesn’t learn features, but you can also take a limit that scales up both depth and width together where NTK parameterization learns features. Indeed, one of the main results of the book is to say that, for NTK parameterization, the depth-to-width aspect ratio is the key hyperparameter that controls the theory describing how realistic networks behave.
Third, the scaling up of hyperparameters is an aspect that follows from understanding either parameterization, NTK or maximal update; a benefit of this kind of theory, from the practical perspective, is certainly learning how to correctly scale up to larger models.
Fourth, I agree that maximal update parameterization is also interesting to study, especially so if it becomes dominant among practitioners.
Finally, perhaps it’s worth adding that the other author of the book (Sho) is posting a paper next week on relating these two parameterizations. There, he finds that an entire one-parameter family of parametrizations, interpolating between NTK parametrization and maximal update parametrization, can learn features if depth is scaled properly with width. (Edit: here’s a link, https://arxiv.org/abs/2210.04909) Curiously, as mentioned in the first point above, the maximal update parametrization is in the strong-coupling regime, making it difficult to use theory to interpret. In terms of which parameterization is prescriptively better from a capabilities perspective, I think that remains an empirical question...
Aren’t Standard Parametrisation and other parametrisations with a kernel limit commonly used mostly in cases where you’re far away from reaching the depth-to-width≈0 limit, so expansions like the one derived for the NTK parametrisation aren’t very predictive anymore, unless you calculate infeasibly many terms in the expensive perturbative series?
As far as I’m aware, when you’re training really big models where the limit behaviour matters, you use parametrisations that don’t get you too close to a kernel limit in the regime you’re dealing with. Am I mistaken about that?
As for NTK being more predictable and therefore safer, it was my impression that the expansion is more predictive the closer you are to the kernel limit, that is, the further away you are from doing the kind of representation learning AI Safety researchers like me are worried about. As I leave that limit behind, I’ve got to take into account ever higher-order terms in your expansion, as I understand it. To me, that seems like the system is just getting more predictable in proportion to how much I’m crippling its learning capabilities.
Yes, of course NTK parametrisation and other parametrisations with a kernel limit can still learn features at finite width; I never doubted that. But it generally seems like adding more parameters should make your system work better, not worse, and if it doesn’t, the default assumption should be that you’re screwing up. If it were the case that there’s no parametrisation in which you can avoid converging to a trivial limit as you heap more parameters onto the width of an MLP, that would be one thing, and I think it’d mean we’d have learned something fundamental and significant about MLP architectures. But if it’s only a certain class of parametrisations, and other parametrisations seem to handle you piling on more parameters just fine, both in theory and in practice, my conclusion would be that what you’re seeing is just a result of choosing a parametrisation that doesn’t handle your limit gracefully. Specifically, as I understand it, Standard parametrisation, for example, just doesn’t let enough gradient reach the layers before the final layer if the network gets too large. As the network keeps getting wider, those layers are increasingly starved of updates until they stop doing anything altogether, and you end up training what’s basically a one-layer network in disguise. So you get a kernel limit.
TL;DR: Sure you can use NTK parametrisation for things, but it’s my impression that it does a good job precisely in those cases where you stay far away from the depth-to-width≈0 limit regime in which the perturbative expansion is a useful description.
Thank you for the discussion!
Let us start by stressing that, of course, the maximal-update parametrization is definitely an intriguing recent development, and it would be very interesting to find tools to be able to understand the strongly-coupled regime in which it resides.
Now, it seems like there are two different issues tangled in this discussion: (i) is one parameterization “better” than another in practice?; and (ii) is our effective theory analysis useful in practically interesting regimes?
The first item is perhaps more an empirical question, whose answer will likely emerge in coming years. But, even if maximal-update parametrization turns out to be universally better for every task, its strongly-coupled nature makes it very difficult to analyze, which perhaps makes it more problematic from a safety/interpretability perspective.
For the second item, we hope we will address concerns in the details of our reply below.
We’d like to also emphasize that, even if you are against NTK parameterization in practice and don’t think it’s relevant at all—a position we don’t hold, but maybe one might—perhaps it’s still worth pointing out that our work provides a simple solvable model of representation learning from which we might learn some general principles that may be applicable to safety and interpretability.
With those said, let us respond to your comments point by point.
We aren’t sure if that’s accurate: empirically, as nicely described in Jennifer’s 8-page summary (in Sec. 1.5), many practical networks, from a simple MLP to the not-very-simple GPT-3, seem to perform well in a regime where the depth-to-width aspect ratio is small (like 0.01 or at most 0.1). So the leading-order perturbative description would be fairly accurate for describing these practically useful networks.
Moreover, one of the takeaways from “effective theory” descriptions is that we understand the truncation error: in particular, the errors from truncation will be of order (depth-to-width aspect ratio)^2. So this means we can estimate what we would miss by truncating the series and learn that sometimes—if not most of the time—we really don’t have to compute these extra terms.
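For a sense of the numbers, the aspect ratio r = L/n and the rough size of the neglected terms, r², can be computed directly. GPT-3 175B has 96 layers and a model dimension of 12288, as reported in the GPT-3 paper. Identifying its "width" with the model dimension is a loose assumption, and the MLP entry is hypothetical. The r² column also ignores the O(1) prefactor, which this arithmetic does not determine.

```python
def aspect_ratio(depth, width):
    """Depth-to-width aspect ratio r = L / n."""
    return depth / width

# GPT-3 175B figures from the GPT-3 paper (treating d_model as the width is a loose assumption);
# the MLP is a hypothetical example.
examples = {"GPT-3 175B": (96, 12288), "hypothetical MLP": (10, 1000)}

for name, (depth, width) in examples.items():
    r = aspect_ratio(depth, width)
    # truncation error is O(r^2); the O(1) prefactor is not computed here
    print(f"{name}: r = {r:.4f}, r^2 = {r**2:.1e}")
# GPT-3 175B: r = 0.0078, r^2 = 6.1e-05
# hypothetical MLP: r = 0.0100, r^2 = 1.0e-04
```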
It is true that decreasing the depth-to-width aspect ratio reduces the representation-learning capability of the network and—to the extent that representation learning is useful for the task—doing so would degrade the performance. But (i) let us reiterate that, as alluded to above, empirically networks seem to operate well in the perturbative regime where the aspect ratio is small and (ii) the converse is not true (i.e., it is not beneficial to keep increasing the aspect ratio indefinitely), as we illustrate in responding to the following point.
Actually, that last point is not always the case. One of the results from our book is that while increasing the depth-to-width ratio leads to more representation learning, it also leads to more fluctuations in gradients from random seed to random seed. Thus, the deeper your network is for fixed width, the harder it is to train, in the sense that different realizations will not only behave differently, but also will likely not be critical (i.e., it will not be on what is sometimes referred to as the “edge of chaos” and it will suffer from exploding/vanishing gradients). And this last observation is true for both the NTK parametrization and maximal-update parametrization, so by your logic, we would be screwing up no matter which parametrization we use. :)
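A toy experiment shows the fluctuation side of this tradeoff. At criticality, the squared norm of the last-layer activations is preserved on average, but its spread across random initializations grows with L/n. The sketch below uses a bias-free ReLU MLP with W ~ N(0, 2/fan_in), which is the critical initialization for ReLU, and a fresh Gaussian input per seed. Note that it measures forward-pass (kernel) fluctuations rather than the gradient fluctuations discussed above, and the depths, widths, and seed counts are arbitrary choices.

```python
import numpy as np

def normalized_sq_norm(depth, width, rng):
    """||h_L||^2 / ||x||^2 for one random bias-free ReLU MLP at critical init, W ~ N(0, 2 / fan_in)."""
    x = rng.standard_normal(width)
    h = x
    for _ in range(depth):
        W = rng.standard_normal((width, width)) * np.sqrt(2.0 / width)
        h = np.maximum(W @ h, 0.0)
    return (h @ h) / (x @ x)

def seed_statistics(depth, width, n_seeds=200, seed=0):
    """Mean and variance of the normalized squared norm across independent initializations."""
    rng = np.random.default_rng(seed)
    samples = np.array([normalized_sq_norm(depth, width, rng) for _ in range(n_seeds)])
    return samples.mean(), samples.var()

for depth, width in [(2, 200), (20, 200), (20, 20)]:
    mean, var = seed_statistics(depth, width)
    print(f"L/n = {depth / width:.2f}: mean {mean:.2f}, seed-to-seed variance {var:.3f}")
```

The mean should stay near 1 in every row, while the variance should grow sharply from the first row to the last.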
As it turns out, this tradeoff between the benefit of representation learning and the cost of seed-to-seed fluctuations leads to the concept of the optimal aspect ratio where networks should perform the best. Empirical results indirectly indicate that this optimal aspect ratio may be in the perturbative regime; in the Appendix of our book, we also did a calculation using tools from information theory that gives evidence that the optimal depth-to-width ratio is in the perturbative regime.
We don’t think this is the case. Both NTK and maximal-update parametrizations can avoid converging to kernel limits and can allow features to evolve: for the NTK parametrization, we need to keep increasing the depth in proportion to the width; for the maximal-update parametrization, we need to keep the depth fixed while increasing the width.
Sho and Dan
Hi jylin04. Fantastic post! It touches on many more aspects of interpretability than my post about the book. I also enjoyed your summary PDF!
I’d love to contribute to any theory work in this direction, if I can. Right now I’m stuck around p. 93 of the book. (I’ve read everything, but I’m now trying to re-derive the equations and have trouble figuring out where a certain term goes. I am also building a Mathematica package that takes care of some of the more tedious parts of the calculations.) Maybe we could get in touch?