we’re going to need to connect our understanding of neural networks to our understanding of the real world
The NTK and related theories aim to go from “SGD finds a giant blob of parameters that performs well on the data for some reason” to “SGD finds a solution with such-and-such clean mathematical characterization”. To fully explain the success of deep learning you do then have to relate the clean mathematical characterization to the real world, but I think this can be done separately to some extent and is less of a bottleneck on progress. My #2 use case for interpretability would be doing stuff like this—basically conceptual/experimental investigation of the types of solutions favored by a given mathematical theory, with the goal of obtaining a high-level story about “why it works in the real world”. Plus attempts to carry out alignment/interpretability/ELK tasks in the simplified setting.
This is the sort of thing which I can much more easily imagine leading to alignment breakthroughs
Hmm, it’s been a while since I looked at this paper but if I recall it doesn’t really try to make any specific predictions about the inductive bias of neural nets in practice, it’s more like a series of suggestive analogies. That’s fine, but I think that sort of thing is more likely to be productive if guided by a more detailed theory.
I can’t speak for Richard, but I think I have a similar issue with NTK and adjacent theory as it currently stands (beyond the usual issues). I’m significantly more confident in a theory of deep learning if it cleanly and consistently explains (or better yet, predicts) unexpected empirical phenomena. The one that sticks out most prominently in my mind, that we see constantly in interpretability, is this strange correspondence between the algorithmic “structure” we find in trained models (both ML and biological!) and “structure” in the data generating process.
That training on Othello move sequences gets you an algorithmic model of the game itself is surprising from most current theoretical perspectives! So in that sense I might be suspicious of a theory of deep learning that fails to “connect our understanding of neural networks to our understanding of the real world”. This is the single most striking thing to come out of interpretability, in my opinion, and I’m worried about a “deep learning theory of everything” if it doesn’t address this head on.
That said, NTK doesn’t promise to be a theory of everything, so I don’t mean to hold it to an unreasonable standard. It does what it says on the tin! I just don’t think it’s explained a lot of the remaining questions I have. I don’t think we’re in a situation where “we can explain 80% of a given model’s behavior with the NTK” or similar. And this is relevant for e.g. studying inductive biases, as you mentioned.
But I strong upvoted your comment, because I do think deep learning theory can fill this gap—I’m personally trying to work in this area. There are some tractable-looking directions here, and people shouldn’t neglect them!
I intended my comment to apply to “theories of deep learning” in general; the NTK was only meant as an example. I agree that the NTK has problems such that it can at best be a ‘provisional’ grand theory. The big question is how to think about feature learning. At this point, though, there are a lot of contenders for “feature learning theories”: the Maximal Update Parameterization, Depth Corrections to the NTK, Perturbation Theory, Singular Learning Theory, Stochastic Collapse, SGD-Induced Sparsity, and so on.
So although I don’t think the NTK can be a final answer, I do like the idea of studying it in more depth—it provides a feature-learning-free baseline against which we can compare actual neural networks and other potential ‘grand theories’. Exactly which phenomena can we not explain with the NTK, and which theory best predicts them?
Strong upvote to Zach’s comment; it basically encapsulates my view (except that I don’t know what the “tractable-looking directions” he mentions are. Zach, can you elaborate?)
I’d turn that around: is there any explanation of why LLMs can do real-world task X and not real-world task Y that appeals to NTKs? (Not a rhetorical question: there may well be, I just haven’t seen one.)
Yeah, I can expand on that. This is obviously going to be fairly opinionated, but there are a few things I’m excited about in this direction.
The first thing that comes to mind here is singular learning theory. I think all of my thoughts on DL theory are fairly strongly influenced by it at this point. It definitely doesn’t have all the answers at the moment, but it’s the single largest theory I’ve found that makes deep learning phenomena substantially “less surprising” (bonus points for these ideas preceding deep learning). For instance, one of the first things that SLT tells you is that the effective parameter count (RLCT) of your model can vary depending on the training distribution, allowing it to basically do internal model selection—the absence of bias-variance tradeoff, and the success of overparameterized models, aren’t surprising when you internalize this. The “connection to real world structure” aspect hasn’t been fully developed here, but it seems heavily suggested by the framework, in multiple ways—for instance, hierarchical statistical models are naturally singular statistical models, and the hierarchical structure is reflected in the singularities. (See also Tom Waring’s thesis).
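For reference, the sense in which the RLCT acts as an effective parameter count shows up in the standard asymptotic expansion of the Bayesian free energy in singular learning theory. The symbols below are introduced here for illustration and do not come from the comment: $n$ is the sample size, $L_n$ the empirical negative log-likelihood, $w_0$ an optimal parameter, $\lambda$ the RLCT, $m$ its multiplicity, and $d$ the nominal parameter count.

```latex
F_n = n L_n(w_0) + \lambda \log n - (m - 1) \log\log n + O_p(1),
\qquad \lambda \le \frac{d}{2}
```

For a regular model, $\lambda = d/2$ and $m = 1$, which recovers the usual BIC penalty; for singular models $\lambda$ can be much smaller and depends on the true distribution.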
Outside of SLT, there are a few other areas I’m excited about; I’ll highlight just one. You mentioned Lin, Tegmark, and Rolnick: the broader literature on depth separations and the curse of dimensionality seems quite important. The approximation abilities of NNs are usually glossed over with universal approximation arguments, but this can’t be enough. For generic Lipschitz functions, universal approximation takes exponentially many parameters in the input dimension (this is a provable lower bound). So there has to be something special about the functions we care about in the real world. See this section of my post for more information. I’d highlight Poggio et al. here, which is the paper in the literature closest to my current view on this.
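A rough way to see the exponential scaling is to count the cells a naive piecewise-constant approximation needs. This is only a sketch of the matching upper-bound construction, not the lower-bound proof mentioned above; the function name and the choice of the sup norm on the unit cube are assumptions made for illustration.

```python
import math

def grid_cells_needed(lipschitz_const, eps, dim):
    """Cells in a uniform grid on [0, 1]^dim such that a piecewise-constant
    approximation of an L-Lipschitz function (w.r.t. the sup norm) has
    error at most eps. A cell of side h has worst-case error L * h / 2."""
    cells_per_axis = math.ceil(lipschitz_const / (2.0 * eps))
    return cells_per_axis ** dim

# Same accuracy target, growing input dimension.
counts = {d: grid_cells_needed(1.0, 0.25, d) for d in (1, 2, 10, 100)}
for d, c in counts.items():
    print(f"dim={d:>3}: {c} cells")
```

Even at this coarse accuracy the count doubles with every added dimension, which is the behavior that generic Lipschitz functions force and that real-world targets apparently avoid.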
This isn’t a complete list, even of theoretical areas that I think could specifically help address the “real world structure” connection, but these are the two I’d feel bad not mentioning. This doesn’t include some of the more empirical findings in science of DL that I think are relevant, like simplicity bias, mode connectivity, grokking, etc. Or work outside DL that could be helpful to draw on, like Boolean circuit complexity, algorithmic information theory, natural abstractions, etc.
FWIW most potential theories of deep learning are able to explain these, I don’t think this distinguishes SLT particularly much.
Agreed—that alone isn’t particularly much, just one of the easier things to express succinctly. (Though the fact that this predates deep learning does seem significant to me. And the fact that SLT can delineate precisely where statistical learning theory went wrong here seems important too.)
Another is that SLT can explain phenomena like phase transitions, as observed in e.g. toy models of superposition, at a quantitative level. There’s also been a substantial chunk of non-SLT ML literature that has independently rediscovered small pieces of SLT, like failures of information geometry, importance of parameter degeneracies, etc. More speculatively, but what excites me most, is that empirical phenomena like grokking, mode connectivity, and circuits seem to intuitively fit in SLT nicely, though this hasn’t been demonstrated rigorously yet.
I don’t think there are any. Of course much the same could be said of other deep learning theories and most (all?) interpretability work. The difference, as far as I can tell, is that there is a clear pathway to getting such explanations from the NTK: you’d want to do a spectral analysis of the sorts of functions learnable by transformer-NTKs. It’s just that nobody has bothered to do this! That’s why I think this line of research is neglected relative to interpretability or developing a new theoretical analysis of deep learning. Another obvious thing to try: NTKs often empirically perform comparably well to finite networks, but are usually a few percentage points worse in accuracy. Can we say anything about the examples where the NTK fails? Do they particularly depend on ‘feature learning’? I think NTKs are a good complement to mechinterp in this regard, since they treat the weights at each neuron as independent of all others, so they provide a good indicator of exactly which examples may require interacting ‘circuits’ to be correctly classified.
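To make "spectral analysis" concrete, here is a minimal sketch on a toy model. It uses a two-layer ReLU network rather than a transformer, and the function name, the width, and the inputs on the unit circle are all assumptions chosen for illustration. It builds the empirical NTK at initialization and eigendecomposes it; in the kernel regime, gradient descent on squared loss fits the component of the target along each eigenvector at a rate proportional to its eigenvalue.

```python
import numpy as np

def empirical_ntk(X, W, a):
    """Empirical NTK of f(x) = a . relu(W x) / sqrt(m) at parameters (W, a).

    K[i, j] is the inner product of the parameter gradients of f at
    inputs X[i] and X[j]."""
    m = W.shape[0]
    pre = X @ W.T                       # (n, m) pre-activations
    act = np.maximum(pre, 0.0)          # relu
    gate = (pre > 0).astype(float)      # relu derivative
    k_readout = act @ act.T / m                          # gradients w.r.t. a
    k_hidden = ((gate * a**2) @ gate.T) * (X @ X.T) / m  # gradients w.r.t. W
    return k_readout + k_hidden

rng = np.random.default_rng(0)
n, d, m = 64, 2, 2048
theta = np.linspace(0.0, 2.0 * np.pi, n, endpoint=False)
X = np.stack([np.cos(theta), np.sin(theta)], axis=1)  # points on the circle
W = rng.normal(size=(m, d))
a = rng.normal(size=m)

K = empirical_ntk(X, W, a)
eigvals, eigvecs = np.linalg.eigh(K)                # ascending order
eigvals, eigvecs = eigvals[::-1], eigvecs[:, ::-1]  # largest first
print("top eigenvalues:", np.round(eigvals[:5], 3))
```

The columns of `eigvecs` with large eigenvalues are the functions this kernel learns quickly; the analogous question for transformer NTKs is which real-world tasks load on the top of the spectrum.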
A note is that as it turns out, OthelloGPT learned a bag of heuristics, and there was no clean algorithm:
https://www.lesswrong.com/posts/gcpNuEZnxAPayaKBY/othellogpt-learned-a-bag-of-heuristics-1
What is the work that finds the algorithmic model of the game itself for Othello? I’m aware of (but not familiar with) some interpretability work on Othello-GPT (Neel Nanda’s and Kenneth Li), but thought it was just about board state representations.
Yeah, that was what I was referring to. Maybe “algorithmic model” isn’t the most precise—what we know is that the NN has an internal model of the board state that’s causal (i.e. the NN actually uses it to make predictions, as verified by interventions). Theoretically it could just be forming this internal model via a big lookup table / function approximation, rather than via a more sophisticated algorithm. Though we’ve seen from modular addition work, transformer induction heads, etc that at least some of the time NNs learn genuine algorithms.
I think that means one of the following should be surprising from theoretical perspectives:
That the model learns a representation of the board state
Or that a linear probe can recover it
That the board state is used causally
Does that seem right to you? If so, which is the surprising claim?
(I am not that informed on theoretical perspectives)
I think the core surprising thing is the fact that the model learns a representation of the board state. The causal / linear probe parts are there to ensure that you’ve defined “learns a representation of the board state” correctly—otherwise the probe could just be computing the board state itself, without that knowledge being used in the original model.
This is surprising to some older theories like statistical learning, because the model is usually treated as effectively a black box function approximator. It’s also surprising to theories like NTK, mean-field, and tensor programs, because they view model activations as IID samples from a single-neuron probability distribution—but you can’t reconstruct the board state via a permutation-invariant linear probe. The question of “which neuron is which” actually matters, so this form of feature learning is beyond them. (Though there may be e.g. perturbative modifications to these theories to allow this in a limited way).
Permutation-invariance isn’t the reason that this should be surprising. Yes, the NTK views neurons as being drawn from an IID distribution, but once they have been so drawn, you can linearly probe them as independent units. As an example, imagine that our input space consisted of five pixels, and at initialization neurons were randomly sensitive to one of the pixels. You would easily be able to construct linear probes sensitive to individual pixels even though the distribution over neurons is invariant over all the pixels.
The reason the Othello result is surprising to the NTK is that neurons implementing an “Othello board state detector” would be vanishingly rare in the initial distribution, and the NTK thinks that the neuron function distribution does not change during training.
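The five-pixel example can be sketched directly. This is an illustrative toy, with all names and sizes chosen here rather than taken from the comment: each neuron copies one randomly assigned pixel, so the distribution over neurons is symmetric across pixels, yet a linear probe for any individual pixel is easy to build once the neurons have been drawn.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pixels, n_neurons = 5, 50

# Each neuron is sensitive to exactly one pixel. Shuffling a balanced
# assignment keeps it random while guaranteeing every pixel is covered.
pixel_of_neuron = rng.permutation(np.arange(n_neurons) % n_pixels)
W = np.eye(n_pixels)[pixel_of_neuron]   # (n_neurons, n_pixels), one-hot rows

def activations(x):
    """Hidden layer: neuron i outputs the value of its assigned pixel."""
    return W @ x

def probe_for_pixel(k):
    """Linear probe that averages the neurons assigned to pixel k."""
    mask = (pixel_of_neuron == k).astype(float)
    return mask / mask.sum()

x = rng.normal(size=n_pixels)
h = activations(x)
recovered = np.array([probe_for_pixel(k) @ h for k in range(n_pixels)])
print(np.allclose(recovered, x))
```

Permuting the neurons changes which entries of each probe are nonzero, but a probe for every pixel still exists, which is the point being made against the permutation-invariance argument.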
Yeah, that’s probably the best way to explain why this is surprising from the NTK perspective. I was trying to include mean-field and tensor programs as well (where that explanation doesn’t work anymore).
Yeah, this is a good point. What I meant to specify wasn’t that you can’t recover any permutation-sensitive data at all (trivially, you can recover data about the input), but that any learned structures must be invariant to neuron permutation. (Though I’m feeling sketchy about the details of this claim). For the case of NTK, this is sort of trivial, since (as you pointed out) it doesn’t really learn features anyway.
By the way, there are actually two separate problems that come from the IID assumption: the “independent” part, and the “identically-distributed” part. For space I only really mentioned the second one. But even if you deal with the identically distributed assumption, the independence assumption still causes problems. This prevents a lot of structure from being representable. For example, a layer where “at most two neurons are activated on any input from some set” can’t be represented with independently distributed neurons. More generally, a lot of circuit-style constructions require this joint structure. IMO this is actually the more fundamental limitation, though it takes longer to dig into.
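A small calculation illustrates the “at most two neurons” example. It makes a simplifying assumption that is not in the comment: each neuron's activation on a given input is modeled as an independent Bernoulli event with the same probability. Under that assumption, the constraint is violated with positive probability as soon as there are three or more neurons, and the violation probability grows with width.

```python
from math import comb

def prob_more_than_two_active(n_neurons, p_active):
    """Probability that more than two of n independent neurons fire,
    when each fires with probability p_active (binomial tail)."""
    p_at_most_two = sum(
        comb(n_neurons, k) * p_active**k * (1 - p_active) ** (n_neurons - k)
        for k in range(3)
    )
    return 1.0 - p_at_most_two

for n in (2, 3, 10, 100):
    print(n, prob_more_than_two_active(n, 0.05))
```

Enforcing the constraint exactly requires the neurons' activations to be jointly dependent, which is the kind of structure an independence assumption rules out.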
The same argument applies—if the distribution of intermediate neurons shifts so that Othello-board-state-detectors have a reasonably high probability of being instantiated, it will be possible to construct a linear probe detecting this, regardless of the permutation-invariance of the distribution.
This is a more reasonable objection (although actually, I’m not sure if independence does hold in the tensor programs framework; probably?)
Yeah, this “if” was the part I was claiming permutation invariance causes problems for—that identically distributed neurons probably couldn’t express something as complicated as a board-state-detector. As soon as that’s true (plus assuming the board-state-detector is implemented linearly), agreed, you can recover it with a linear probe regardless of permutation-invariance.
I probably should’ve just gone with that one, since the independence barrier is the one I usually think about, and harder to get around (related to non-free-field theories, perturbation theory, etc).
My impression from reading through one of the tensor program papers a while back was that it still makes the IID assumption, but there could be some subtlety about that I missed.
Thanks! The permutation-invariance of a bunch of theories is a helpful concept