Dialogue introduction to Singular Learning Theory

Alice: A lot of people are talking about Singular Learning Theory. Do you know what it is?

Bob: I do. (pause) Kind of.

Alice: Well, I don’t. Explanation time?

Bob: Uh, I’m not really an expert on it. You know, there’s a lot of material out there that--

Alice: that I realistically won’t ever actually look at. Or, I’ve looked at them a little, but I still have basically no idea what’s going on. Maybe if I watched a dozen hours of introductory lectures I’d start to understand it, but that’s not currently happening.

What I really want is a short overview of what’s going on. That’s self-contained. And easy to follow. Aimed at a non-expert. And which perfectly answers any questions I might have. So, I thought I’d ask you!

Bob: Sorry, I’m actually really not--

Alice: Pleeeease?

[pause]

Bob: Ah, fine, I’ll try.

So, you might have heard of ML models being hard to interpret. Singular Learning Theory (SLT) is an approach for understanding models better. Or, that’s one motivation, at least.

Alice: And how’s this different from a trillion other approaches to understanding AI?

Bob: A core perspective of SLT is studying how the model develops during training. Contrast this with, say, mechanistic interpretability, which mostly looks at the fully trained model. SLT is also more concerned with higher-level properties.

As a half-baked analogy, you can imagine two approaches to studying how humans work: You could just open up a human and see what’s inside. Or, you could notice that, hey, you have these babies, which grow up into children, go through puberty, et cetera, what’s up with that? What are the different stages of development? Where do babies come from? And SLT is more like the second approach.

Alice: This makes sense as a strategy, but I strongly suspect you don’t currently know what an LLM’s puberty looks like.

Bob: (laughs) No, not yet.

Alice: So what do you actually have?

Bob: The SLT people have some quite solid theory, and some empirical work building on top of that. Maybe I’ll start from the theory, and then cover some of the empirical work.

Alice: (nods)

I. Theoretical foundations

Bob: So, as you know, nowadays the big models are trained with gradient descent. As you also know, there’s more to AI than gradient descent. And for a moment we’ll be looking at the Bayesian setting, not gradient descent.

Alice: Elaborate on “Bayesian setting”?

Bob: Imagine a standard deep learning setup, where you want your neural network to classify images, predict text or whatever. You want to find parameters for your network so that it has good performance. What do you do?

The gradient descent approach is: Randomly initialize the parameters, then slightly tweak them on training examples in the direction of better performance. After a while your model is probably decent.

The Bayesian approach is: Consider all possible settings of the parameters. Assign some prior to them. For each model, check how well it predicts the correct labels on some training examples. Perform a Bayesian update on the prior. Then sample a model from the posterior. With lots of data you will probably obtain a decent model.
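On a toy one-parameter problem you can write the whole Bayesian recipe in a few lines. Here’s a rough sketch (the model, a biased coin with heads-probability sigmoid(theta), and all the numbers are invented purely for illustration):

```python
# A minimal sketch of the Bayesian recipe on a toy one-parameter problem:
# discretize the parameter, put a prior on it, update on the data, and
# sample a model from the posterior. Everything here is made up.
import numpy as np

rng = np.random.default_rng(0)

def p_heads(theta):
    # the "model": a biased coin with heads-probability sigmoid(theta)
    return 1.0 / (1.0 + np.exp(-theta))

# "all possible settings of the parameters" (discretized) and a flat prior
thetas = np.linspace(-3.0, 3.0, 601)
log_prior = np.zeros_like(thetas)

# toy training data, generated from theta = 1.0
data = rng.random(200) < p_heads(1.0)

# Bayesian update: add the log-likelihood of every observation
log_post = log_prior.copy()
for heads in data:
    log_post += np.log(p_heads(thetas) if heads else 1.0 - p_heads(thetas))
posterior = np.exp(log_post - log_post.max())
posterior /= posterior.sum()

# "training" ends by sampling a model from the posterior
sampled_theta = rng.choice(thetas, p=posterior)
print(f"sampled parameter: {sampled_theta:.2f}")
```

The point is just that “training” here means conditioning on the data and sampling, not iterating gradient steps.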

Alice: Wait, isn’t the Bayesian approach very expensive computationally?

Bob: Totally! Or, if your network has 7 parameters, you can pull it off. If it has 7 billion, then no. There are way too many models, we can’t do the updating, not even approximately.

Nevertheless, we’ll look at the Bayesian setting—it’s theoretically much cleaner and easier to analyze. So forget about computational costs for a moment.

Alice: Will the theoretical results also apply to gradient descent and real ML models, or be completely detached from practice?

Bob: (winks)

Alice: You know what, maybe I’ll just let you talk.

Bob: There’s a really fascinating phenomenon in the Bayesian setting: you can have abrupt “jumps” in the model you sample (which people call phase changes). Let me explain.

Suppose you do Bayesian updates on, say, 10000 data points. Maybe your posterior then looks like this:

And you might think: probably with more data the posterior will just converge around the optimal model. You feed in 500 more data points:

Huh, what’s that? You feed in another 500 data points:

You don’t necessarily get gradual convergence around the “best model”! Instead, you can have an abrupt jump: over the course of relatively few more examples, your posterior has totally shifted from one place to another, and the types of models you get by sampling the posterior at 11000 data points are completely different from the ones at 10000 points.

Alice: Wait, are these real graphs?

Bob: Real in the sense of being the result of Bayesian updates, yes, but I specifically crafted a loss function to demonstrate my point, and it’s all very toy.

There are more natural examples when you have more than one parameter, though, and the naive view of gradual convergence is definitely false.

Alice: Why does this happen?

Bob: Loosely, there are two things that determine the size of a bump (assuming it’s sensible to decompose the posterior/model space into “bumps”):

  1. The height of the bump: how well does the best model of the bump perform?

  2. The width of the bump: once you stray away from the local best model, how quickly does performance drop?

At one extreme, you have an excellent model, but the parameters have to be very precisely right, or it breaks down fast. This would be a narrow, tall spike in likelihood.

At another extreme, you have a mediocre model, but the parameters are robust to small changes. This would be a wide, low bump in likelihood.

Alice: I would expect that when you increase the amount of data, you start moving towards the “excellent-but-fragile” models.

Bob: That’s right. Intuitively, at the beginning the prior favors wide, low bumps: if you have no data, you cannot locate the good-but-specific models. But the Bayesian updates favor the good-but-specific models, and eventually they start to take over. This can happen rather quickly—the posterior bumps don’t have to be of comparable sizes for long.

Like all things in life, the performance-specificity tradeoff is a spectrum, and you can have multiple jumps from one bump to another as you shove in more data.
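If you want numbers: treating each bump as exactly quadratic (a big simplification, as we’ll get to), the posterior mass of a bump reduces to a Gaussian integral, and you can watch the crossover happen. A made-up two-bump example:

```python
# Back-of-the-envelope sketch of the height-vs-width tradeoff. Treat each bump
# as exactly quadratic: per-example loss L_i + a_i * (theta - theta_i)^2 in
# bump i. With a flat prior, the posterior mass of bump i after n examples is
# then proportional to the Gaussian integral
#     exp(-n * L_i) * sqrt(pi / (n * a_i)).
# All numbers below are invented for illustration.
import math

bumps = {
    "wide, mediocre":    {"L": 0.700, "a": 0.1},  # low curvature, worse loss
    "narrow, excellent": {"L": 0.696, "a": 1e6},  # high curvature, better loss
}

def masses(n):
    L_min = min(b["L"] for b in bumps.values())
    # shifting all losses by L_min changes only a common factor, but avoids underflow
    weights = {
        name: math.exp(-n * (b["L"] - L_min)) * math.sqrt(math.pi / (n * b["a"]))
        for name, b in bumps.items()
    }
    total = sum(weights.values())
    return {name: w / total for name, w in weights.items()}

for n in (1000, 2000, 3000, 4000, 5000):
    split = masses(n)
    print(f"n = {n}: " + ", ".join(f"{name} {frac:.1%}" for name, frac in split.items()))
```

In this made-up example the posterior flips from almost entirely the wide bump to almost entirely the narrow one over a window of roughly a thousand extra data points around n = 2000.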

Alice: This talk about “fragility” or “specificity” feels a bit vague to me, though. Care to clarify?

Bob: Sure. This is slightly more technical, so buckle up.

Setting the prior aside, we are interested only in models’ predictive performance, i.e. the average negative log-probability assigned to the correct labels, i.e. the loss function. (This corresponds to the likelihood factor in Bayesian updates.)

Consider a model that’s locally optimal for this loss function. Here, let me sketch a couple of plots:

Here lower means better performance, so our bumps have turned into basins.

What I mean by “fragility” or “specificity” is roughly: how steeply does loss increase as we move away from the local optimum?

More precisely: If the loss of the model is $L_0$, we look at the models which have loss at most $L_0 + \varepsilon$ for some small error threshold $\varepsilon$ (and which are part of the same basin), and specifically at their volume: how “many” such models are there?
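In symbols: writing $L(w)$ for the loss as a function of the parameters $w$, the quantity of interest is the volume

$$V(\varepsilon) = \mathrm{vol}\{\, w \text{ in the basin} : L(w) \le L_0 + \varepsilon \,\},$$

viewed as a function of the error threshold $\varepsilon$.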

Alice: To check I understand: if I draw the parameter-axis here, and the threshold for additional error $\varepsilon$ here, the parameters that result in at most $\varepsilon$ worse loss are the red segment here. And “volume” is just length in this case: the length of the red segment.

Bob: Yep. Similarly with two parameters, we look at the parameters for which we have only slightly worse loss (the red region here) and its “volume”, in this case area.

And the key question is: how does the length/area/volume behave for small $\varepsilon$?

Alice: Oh, I think I can solve this! I’ve seen this type of argument a couple of times before:

If one considers the partial derivatives at the local optimum, then the first derivatives are zero, and the second derivatives are non-negative. It’s very unlikely that any of them are exactly zero, so assume that they are positive--

Bob: (smiles)

Alice: --and thus the basin can be locally modeled as a high-dimensional parabola. That is, with $d$ parameters $x_1, \ldots, x_d$, near the optimum the loss takes the shape

$$L_0 + c_1 x_1^2 + \cdots + c_d x_d^2,$$

where $c_1, \ldots, c_d$ are positive constants. The constants don’t really matter, they just stretch the picture, so I’m gonna assume $c_1 = \cdots = c_d = 1$.

Alice: How “many” values of $x = (x_1, \ldots, x_d)$ are there such that the loss is at most $L_0 + \varepsilon$, i.e. $x_1^2 + \cdots + x_d^2 \le \varepsilon$?

Around $\varepsilon^{d/2}$, give or take a constant factor.

Alice: (muttering to herself) Indeed, we must have $x_i^2 \le \varepsilon$ for each $i$, so any single variable ranges over an interval of length at most $2\sqrt{\varepsilon}$, and thus in total the region has volume at most $(2\sqrt{\varepsilon})^d$. On the other hand, if we have $|x_i| \le \sqrt{\varepsilon/d}$ for every $i$, then $x_1^2 + \cdots + x_d^2 \le \varepsilon$, so the volume is at least $(2\sqrt{\varepsilon/d})^d$.

Bob: Good, great. Indeed, when the second derivatives don’t vanish, $\varepsilon^{d/2}$ is correct.

Alice: Yeah. Probably a similar argument works for the case where some second derivatives are zero, but that case should be really unlikely and so doesn’t matter.

Bob: (smiles widely)

Alice: What?

Bob: Do you know what the “singular” in “Singular Learning Theory” stands for?

Alice: Uh oh.

Bob: In general, basins are nasty. The high-dimensional parabola approximations are utterly false. This isn’t some pedantic nitpick—it’s just a completely wrong picture.

(I kind of led you astray with the pictures I drew above, sorry.)

To illustrate, here’s just one relatively benign example from two dimensions:

And it gets worse when you have billions of parameters. Welcome to deep learning.

Alice: Let me guess: basins of different shapes have different rates of volume-expansion?

Bob: Spot on. For the high-dimensional parabola, the volume-expansion-exponent was $d/2$, but in general it can be less than that. If the exponent is smaller than $d/2$ (when the model is singular), the basin has more almost-equally-performant models. This corresponds to a larger Bayesian posterior. Exponentially so, due to the nature of these things.
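Here’s a quick Monte Carlo comparison of a regular two-dimensional basin with a singular one (a toy comparison I’m making up on the spot, not the example in the picture):

```python
# Rough Monte Carlo sketch of volume scaling: how "many" parameters have loss
# at most epsilon, for a regular 2-D basin (L = x^2 + y^2) versus a singular
# one (L = x^2 * y^2). Samples are drawn uniformly from the box [-1, 1]^2,
# whose area is 4. Toy example, made up for illustration.
import numpy as np

rng = np.random.default_rng(0)
x, y = rng.uniform(-1.0, 1.0, size=(2, 4_000_000))

losses = {
    "regular  L = x^2 + y^2": x**2 + y**2,
    "singular L = x^2 * y^2": x**2 * y**2,
}

eps1, eps2 = 1e-2, 1e-4
for name, L in losses.items():
    v1 = (L <= eps1).mean() * 4.0
    v2 = (L <= eps2).mean() * 4.0
    # if V(eps) ~ eps^lambda, then lambda ~ log(V1 / V2) / log(eps1 / eps2)
    exponent = np.log(v1 / v2) / np.log(eps1 / eps2)
    print(f"{name}: V(1e-2) = {v1:.4f}, V(1e-4) = {v2:.5f}, "
          f"apparent exponent ~ {exponent:.2f}")

# Expected: an exponent of about 1 (= d/2) for the regular basin, but well
# below 1 for the singular one. (The true scaling there is eps^(1/2) times a
# log factor, which drags this crude two-point estimate a bit below 1/2.)
```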

This is the key insight of singular learning theory: singular models really matter.

And indeed, for the Bayesian setting we have hard proof of this. You really are selecting models based on both predictive accuracy and volume-expansion-exponent (better known as the learning coefficient). If you hear people talking about “Watanabe’s free energy formula”, it’s precisely about this.
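Roughly, and sweeping all the technical conditions under the rug, the formula says that the “free energy” of a region of parameter space around an optimal parameter $w^*$ (lower free energy meaning more posterior mass) behaves like

$$F_n \approx n L_n(w^*) + \lambda \log n,$$

where $n$ is the number of data points, $L_n(w^*)$ is the loss of the best model in the region, and $\lambda$ is the learning coefficient. The first term rewards accuracy, the second rewards a small learning coefficient, and since the first grows like $n$ while the second only grows like $\log n$, accuracy eventually wins. That’s the mechanism behind the jumps from earlier.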

II. Practical side

Alice: While what you say makes sense, and has new points I hadn’t thought about, I can’t help but think: this is not SGD. How useful is this in practice, really?

Bob: Tricky question.

Clearly, gradient descent is a very different process. Most importantly, it always performs local updates on the model, doesn’t explore the whole parameter space, could be “blocked” from parts of the space by local barriers, and so on. And maybe this matters quite a lot.

On the other hand, some of the insights of SLT do carry over to gradient descent. Most compellingly, there’s some empirical work demonstrating its usefulness. There are also some theoretical arguments about how simplified models of SGD correspond to Bayesian learning. And on a general level, given that SLT is the right way of thinking about the Bayesian setting, it’s reasonable to think about it in the case of SGD as well.

Alice: Say more about the empirical work.

Bob: Applying these methods to deep learning has only started very recently. Which is to say: there’s a lot to be done.

In any case, let me talk about a couple of articles I’ve liked.

There’s this paper called “Dynamical versus Bayesian Phase Transitions in a Toy Model of Superposition”. They look at a toy learning problem—how to store many features in a small number of dimensions—and how both Bayesian methods and stochastic gradient descent learn.

They find that the SLT picture gives non-trivial insight into SGD: the learning trajectory has a couple of sharp drops in loss, accompanied by sharp changes in the (local) learning coefficient. Corresponding phase changes occur when using Bayesian learning.

Alice: Wait, are you saying the phase changes are the same for Bayesian learning and SGD?

Bob: Not quite: in the paper they are unable to find all of the Bayesian phase changes in SGD. This makes sense: the SGD trajectory is local, after all, and doesn’t look at the whole parameter space. Thus, SGD might be “missing” some phase changes that Bayesian learning has.

However, they do hypothesize that any SGD phase change can be found from Bayesian learning as well. Any time you see a phase change in SGD, there’s—the hypothesis goes—a “Bayesian reason” for it.

Alice: What else have you got?

Bob: There’s a post about the learning coefficient in a modular addition network.

They demonstrate that networks that memorize the data and networks that generalize have vastly different learning coefficients. Thus, you can get information about generalization behavior without actually evaluating the model on new inputs!

They also verify that the learning coefficient approximation methods work well for medium-sized networks.
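By the way, the estimator behind these experiments is simple enough to sketch. Roughly (and this is a simplified sketch with a made-up toy model and untuned hyperparameters, not the exact code from the post): run SGLD around the trained parameters to sample from a tempered, localized posterior, and compare the average sampled loss to the loss at the optimum.

```python
# Simplified sketch of estimating the local learning coefficient (LLC) with
# SGLD, in the spirit of  lambda_hat = n * beta * (E[L_n(w)] - L_n(w*))  with
# beta = 1 / log(n). The model, data, and hyperparameters are made up for
# illustration and would need tuning in a real experiment.
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

# --- a tiny toy regression task and model (stand-ins, nothing special) ---
n = 1000
X = torch.randn(n, 4)
y = X @ torch.tensor([1.0, -2.0, 0.0, 0.5]) + 0.1 * torch.randn(n)

model = nn.Sequential(nn.Linear(4, 16), nn.Tanh(), nn.Linear(16, 1))
mse = nn.MSELoss()

def avg_loss():
    return mse(model(X).squeeze(-1), y)

# --- step 1: train to (approximately) a local optimum w* ---
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
for _ in range(2000):
    opt.zero_grad()
    avg_loss().backward()
    opt.step()
loss_star = avg_loss().item()
w_star = [p.detach().clone() for p in model.parameters()]

# --- step 2: SGLD around w*, tempered at beta = 1/log(n), with a quadratic
#     pull of strength gamma toward w* to keep the chain local ---
beta = 1.0 / math.log(n)
gamma = 100.0   # localization strength (made-up value)
eps = 1e-5      # SGLD step size (made-up value)

loss_samples = []
for step in range(3000):
    model.zero_grad()
    avg_loss().backward()
    with torch.no_grad():
        for p, p0 in zip(model.parameters(), w_star):
            drift = n * beta * p.grad + gamma * (p - p0)
            p.add_(-0.5 * eps * drift + math.sqrt(eps) * torch.randn_like(p))
        if step >= 1000:  # discard burn-in
            loss_samples.append(avg_loss().item())

# --- step 3: the estimator ---
mean_loss = sum(loss_samples) / len(loss_samples)
lambda_hat = n * beta * (mean_loss - loss_star)
print(f"estimated local learning coefficient: {lambda_hat:.2f}")
```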

Alice: I was about to ask: These examples seem to be about rather small networks and toy settings. Is there anything on larger models?

Bob: See “The Developmental Landscape of In-Context Learning”. They train a transformer with 3 million parameters on Internet text.

They approximate the local learning coefficient throughout training and, using that and other methods, are able to identify discrete phases in the language model’s development. These phases include things like learning the frequencies of bigrams and forming induction heads.

Alice: That’s… actually pretty compelling. Anything with even larger models?

Bob: Not yet, as far as I know. I hope there will be!

Alice: So do I—maybe we’ll soon identify the puberty stage of LLMs.

Bob: Yes, that.

Alice: I still have a few questions. Isn’t the learning coefficient only meaningful for local optima, even though we presumably can’t find the local optima of real-life big models? And I’m still a bit confused about interpreting the learning coefficient: sure, we can plot the learning coefficient during training and notice something’s changed, but what then? Also, does estimating it require much additional compute? Oh, and about the applicability to SGD, how sure--

Bob: (hastily) Ah, yeah, sorry, I’d love to answer your questions, but I have to go now. Maybe if you have further questions, you could ask people in the comment section or elsewhere.

Alice: Right, I’ll do that. Thanks for the explanation!