During training on the ‘old task’, NTK stays in the ‘tangent space’ of the network’s initialization. This means that, to first order, none of the functions/derivatives computed by the individual neurons change at all, only the output function does.
Eh? Why does this follow? Derivatives make sense; the derivatives staying approximately-constant is one of the assumptions underlying NTK to begin with. But the functions computed by individual neurons should be able to change for exactly the same reason the output function changes, assuming the network has more than one layer. What am I missing here?
The asymmetry between the output function and the intermediate neuron functions comes from backprop: from the fact that the gradients are backprop-ed through weight matrices with entries of magnitude O(N^{-1/2}). So the gradient of the output w.r.t. itself is obviously 1, then the gradient of the output w.r.t. each neuron in the preceding layer is O(N^{-1/2}), since you’re just multiplying by a vector with those entries. Then by induction all other preceding layers’ gradients are the sum of N random things of size O(1/N), and so are of size O(N^{-1/2}) again. So taking a step of backprop will change the output function by O(1) but the intermediate functions by O(N^{-1/2}), vanishing in the large-width limit.
(This is kind of an oversimplification since it is possible to have changing intermediate functions while doing backprop, as mentioned in the linked paper. But this is the essence of why it’s possible in some limits to move around using backprop without changing the intermediate neurons)
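A quick numerical sanity check of the scaling argument above, as a minimal sketch rather than anything taken from the linked paper. It assumes a one-hidden-layer tanh network in NTK parameterization (a much simpler setting than the multi-layer induction), and the function name `one_step_changes`, the input, and the step size are all illustrative choices. One gradient step that pushes the output up should move the output by an amount that stays O(1) as the width grows, while the typical hidden unit moves by roughly N^{-1/2}.

```python
import numpy as np

def one_step_changes(width, d=10, lr=1.0, seed=0):
    """Take one gradient step that increases the output of
    f(x) = a . tanh(W x / sqrt(d)) / sqrt(width)   (NTK parameterization)
    and return (|change in f|, mean |change in a hidden unit|)."""
    rng = np.random.default_rng(seed)
    x = np.ones(d)                       # fixed illustrative input with |x|^2 = d
    W = rng.standard_normal((width, d))  # O(1) entries; the 1/sqrt scaling lives in f
    a = rng.standard_normal(width)

    def forward(W, a):
        h = np.tanh(W @ x / np.sqrt(d))
        return h, a @ h / np.sqrt(width)

    h0, f0 = forward(W, a)
    grad_a = h0 / np.sqrt(width)                                # df/da
    grad_W = np.outer(a * (1 - h0**2), x) / np.sqrt(width * d)  # df/dW
    h1, f1 = forward(W + lr * grad_W, a + lr * grad_a)
    return abs(f1 - f0), np.mean(np.abs(h1 - h0))

for width in (100, 10_000):
    df, dh = one_step_changes(width)
    print(f"width={width:6d}  output change={df:.3f}  hidden change={dh:.5f}")
```

If the argument is right, multiplying the width by 100 should leave the output change roughly where it was and shrink the hidden change by roughly a factor of 10.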
Ok, that’s at least a plausible argument, although there are some big loopholes. Main problem which jumps out to me: what happens after one step of backprop is not the relevant question. One step of backprop is not enough to solve a set of linear equations (i.e. to achieve perfect prediction on the training set); the relevant question is what happens after one step of Newton’s method, or after enough steps of gradient descent to achieve convergence.
What would convince me more is an empirical result—i.e. looking at the internals of an actual NTK model, trying the sort of tricks which work well for interpreting normal NNs, and seeing how well they work. Just relying on proofs makes it way too easy for an inaccurate assumption to sneak in—like the assumption that we’re only using one step of backprop. If anyone has tried that sort of empirical work, I’d be interested to hear what it found.
The result that NTK does not learn features in the large N limit is not in dispute at all—it’s right there on page 15 of the original NTK paper, and indeed holds after arbitrarily many steps of backprop. I don’t think that there’s really much room for loopholes in the math here. See Greg Yang’s paper for a lengthy proof that this holds for all architectures. Also worth noting that when people ‘take the NTK limit’ they often don’t initialize an actual net at all, they instead use analytical expressions for what the inner product of the gradients would be at N=infinity to compute the kernel directly.
Alright, I buy the argument on page 15 of the original NTK paper.
I’m still very skeptical of the interpretation of this as “NTK models can’t learn features”. In general, when someone proves some interesting result which seems to contradict some combination of empirical results, my default assumption is that the proven result is being interpreted incorrectly, so I have a high prior that that’s what’s happening here. In this case, it could be that e.g. the “features” relevant to things like transfer learning are not individual neuron activations—e.g. IIRC much of the circuit interpretability work involves linear combinations of activations, which would indeed circumvent this theorem.
This whole class of concerns would be ruled out by empirical results—e.g. experimental evidence on transfer learning with NTKs, or someone applying the same circuit interpretability techniques to NTKs which are applied to standard nets.
I don’t think taking linear combinations will help, because adding terms to the linear combination will also increase the magnitude of the original activation vector. E.g., if you add together N^{1/2} units, the magnitude of the sum of their original activations will with high probability be Θ(N^{1/4}), dwarfing the O(1) change due to change in the activations. But regardless, it can’t help with transfer learning at all, since the tangent kernel (which determines learning in this regime) doesn’t change by definition.
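The magnitudes in that estimate can be checked with a few lines of NumPy. This is only a sketch under stated assumptions: the initial activations are modeled as i.i.d. standard normals (a stand-in for "random O(1) activations"), each unit is assumed to move by exactly N^{-1/2}, and all the moves are assumed to line up, which is the most favorable case for the change. The helper name `sum_vs_change` is hypothetical.

```python
import numpy as np

def sum_vs_change(width, trials=2000, seed=0):
    """Compare the typical size of a sum of sqrt(width) initial activations
    with the largest change that sum could undergo during training."""
    rng = np.random.default_rng(seed)
    m = int(round(np.sqrt(width)))          # number of units combined: N^(1/2)
    h0 = rng.standard_normal((trials, m))   # assumed i.i.d. O(1) initial activations
    typical_sum = np.sqrt(np.mean(h0.sum(axis=1) ** 2))  # RMS of the summed activations
    max_change = m * width ** -0.5          # every unit shifts by N^(-1/2), all aligned
    return typical_sum, max_change

for width in (10**4, 10**6):
    s, c = sum_vs_change(width)
    print(f"N={width:8d}  typical sum={s:7.2f}  N^(1/4)={width ** 0.25:7.2f}  max change={c:.2f}")
```

The sum of the original activations grows like N^{1/4}, while even the fully aligned change stays at 1.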
What empirical results do you think are being contradicted? As far as I can tell, the empirical results we have are ‘NTK/GP have similar performance to neural nets on some, but not all, tasks’. I don’t think transfer/feature learning is addressed at all. You might say these results are suggestive evidence that NTK/GP captures everything important about neural nets, but this is precisely what is being disputed with the transfer learning arguments.
I can imagine doing an experiment where we find the ‘empirical tangent kernel’ of some finite neural net at initialization, solve the linear system, and then analyze the activations of the resulting network. But it’s worth noting that this is not what is usually meant by ‘NTK’—that usually includes taking the infinite-width limit at the same time. And to the extent that we expect the activations to change at all, we no longer have reason to think that this linear system is a good approximation of SGD. That’s what the above mathematical results mean—the same mathematical analysis that implies that network training is like solving a linear system, also implies that the activations don’t change at all.
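For concreteness, here is a rough sketch of what that experiment could look like on a toy scale. Everything specific in it is an assumption made for illustration: a one-hidden-layer tanh network in NTK parameterization, a hand-written Jacobian, random inputs, and a made-up target. It builds the empirical tangent kernel at initialization, solves the linear system on the training set, maps the solution back to a minimum-norm parameter update, and measures how far the hidden activations moved.

```python
import numpy as np

def forward(X, W, a):
    """f(x) = a . tanh(W x / sqrt(d)) / sqrt(N) for each row x of X."""
    N, d = W.shape
    return np.tanh(X @ W.T / np.sqrt(d)) @ a / np.sqrt(N)

def jacobian(X, W, a):
    """Row i holds d f(x_i) / d(parameters), ordered as (a, then W flattened)."""
    N, d = W.shape
    H = np.tanh(X @ W.T / np.sqrt(d))                  # (n, N) hidden activations
    J_a = H / np.sqrt(N)                               # df/da
    J_W = (a * (1 - H**2))[:, :, None] * X[:, None, :] / np.sqrt(N * d)  # df/dW
    return np.concatenate([J_a, J_W.reshape(len(X), -1)], axis=1)

rng = np.random.default_rng(0)
N, d, n = 512, 5, 20
W, a = rng.standard_normal((N, d)), rng.standard_normal(N)
X_train = rng.standard_normal((n, d))
y_train = np.sin(X_train[:, 0])            # illustrative target, not from any paper
X_test = rng.standard_normal((5, d))

J_train, J_test = jacobian(X_train, W, a), jacobian(X_test, W, a)
K = J_train @ J_train.T                    # empirical tangent kernel on the training set
alpha = np.linalg.solve(K, y_train - forward(X_train, W, a))

# Predictions of the linearized model on train and test inputs.
f_lin_train = forward(X_train, W, a) + K @ alpha
f_lin_test = forward(X_test, W, a) + J_test @ J_train.T @ alpha

# Minimum-norm parameter update implied by the kernel solution.
dtheta = J_train.T @ alpha
a_new, W_new = a + dtheta[:N], W + dtheta[N:].reshape(N, d)
act_change = np.mean(np.abs(np.tanh(X_train @ W_new.T / np.sqrt(d))
                            - np.tanh(X_train @ W.T / np.sqrt(d))))
print("mean |activation change|:", act_change)
```

Repeating this at several widths and comparing `act_change` (and whatever interpretability probe one prefers) against an SGD-trained copy of the same network would be one way to run the comparison.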
They wouldn’t be random linear combinations, so the central limit theorem estimate wouldn’t directly apply. E.g. this circuit transparency work basically ran PCA on activations. It’s not immediately obvious to me what the right big-O estimate would be, but intuitively, I’d expect the PCA to pick out exactly those components dominated by change in activations—since those will be the components which involve large correlations in the activation patterns across data points (at least that’s my intuition).
I think this claim is basically wrong:
And to the extent that we expect the activations to change at all, we no longer have reason to think that this linear system is a good approximation of SGD.
There’s a very big difference between “no change to first/second order” and “no change”. Even in the limit, we do expect most linear combinations of the activations to change. And those are exactly the changes which would potentially be useful for transfer learning. And the tangent kernel not changing does not imply that transfer learning won’t work, for two reasons: starting at a better point can accelerate convergence, and (probably more relevant) the starting point can influence the solution chosen when the linear system is underdetermined (which it is, if I understand things correctly).
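The second reason can be illustrated with a plain linear system, which is all the linearized regime is. The sketch below uses the standard fact that gradient descent on a least-squares loss, started from w0, converges to the interpolating solution closest to w0. The matrix, targets, and the "pretrained" starting point are random stand-ins, not anything derived from a real network.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 5, 50                                # 5 equations, 50 unknowns: underdetermined
J = rng.standard_normal((n, p))             # stand-in for the fixed tangent features
y = rng.standard_normal(n)

def converged_solution(w0):
    """Limit of gradient descent on ||J w - y||^2 started from w0:
    the interpolating solution closest to w0."""
    return w0 + np.linalg.pinv(J) @ (y - J @ w0)

w_scratch = converged_solution(np.zeros(p))
w_pre = converged_solution(rng.standard_normal(p))  # stand-in for a pretrained start

j_new = rng.standard_normal(p)              # tangent features of an unseen input
print("train fit (scratch, pretrained):",
      np.allclose(J @ w_scratch, y), np.allclose(J @ w_pre, y))
print("prediction on unseen input:", j_new @ w_scratch, "vs", j_new @ w_pre)
```

Both runs fit the training equations exactly, yet they disagree on the unseen input, because the component of the starting point outside the span of the training features is never touched.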
I do think the empirical results pretty strongly suggest that the NTK/GP model captures everything important about neural nets, at least in terms of their performance on the original task. If that’s true, and NTKs can’t be used for transfer learning, then that would imply that transfer learning in normal nets works for completely different reasons from good performance on the original task, and that good performance on the original task has nothing to do with learning features. Those both strike me as less plausible than these proofs about “NTK not learning features” being misinterpreted.
(I also did a quick google search for transfer learning with NTKs. I only found one directly-relevant study, which is on way too small and simple a system for me to draw much of a conclusion from it, but it does seem to have worked.)
BTW, thanks for humoring me throughout this thread. This is really useful, and my understanding is updating considerably.
Hmm, so regarding the linear combinations, it’s true that there are some linear combinations that will change by Θ(1) in the large-width limit. Just use the vector of partial derivatives of the output at some particular input; this sum will change by the amount that the output function moves during the regression. Indeed, I suspect (but don’t have a proof) that these particular combinations will span the space of linear combinations that change non-trivially during training. I would dispute “we expect most linear combinations to change” though: the CLT argument implies that we should expect almost all combinations to not appreciably change. I’m not sure what effect this would have on the PCA, and I still think it’s plausible that it doesn’t change at all (actually, I think Greg Yang states that it doesn’t change in section 9 of his paper, though I haven’t read that part super carefully).
And the tangent kernel not changing does not imply that transfer learning won’t work
So I think I was a bit careless in saying that the NTK can’t do transfer learning at all; a more exact statement might be “the NTK does the minimal amount of transfer learning possible”. What I mean by this is, any learning algorithm can do transfer learning if the task we are ‘transferring’ to is sufficiently similar to the original task, for instance if it’s just the exact same task but with a different data sample. I claim that the ‘transfer learning’ the NTK does is of this sort. As you say, since the tangent kernel doesn’t change at all, the net effect is to move where the network starts in the tangent space. Disregarding convergence speed, the impact this has on generalization is determined by the values set by the old function on axes of the NTK outside of the span of the partial derivatives at the new function’s data points. This means that, for the NTK to transfer anything from one task to another, it’s not enough for the tasks to both feature, for instance, eyes. The eyes have to correlate with the output in the exact same way in both tasks. Indeed, the transfer learning could actually hurt the generalization. Nor is its effect invariant under simple transformations like flipping the sign of the target function (this would change beneficial transfer to harmful). By default, for functions that aren’t simple multiples, I expect the linear correlation between values on different axes to be about 0, even if the functions share many meaningful features. So while the NTK can do ‘transfer learning’ in a sense, it’s about as weak as possible, and I strongly doubt that this sort of transfer is sufficient to explain transfer learning’s successes in practice (but don’t have empirical proof).
I do think the empirical results pretty strongly suggest that the NTK/GP model captures everything important about neural nets, at least in terms of their performance on the original task.
It’s true that NTK/GP perform pretty closely to finite nets on the tasks we’ve tried them on so far, but those tasks are pretty simple and we already had decent non-NN solutions. Generally the pattern is “GP matches NNs on really simple tasks, NTK on somewhat harder ones”. I think the data we have is consistent with this breaking down as we move to the harder problems that have no good non-NN solutions. I would be very interested in seeing an experiment with NTK on, say, ImageNet for this reason, but as far as I know no one’s done so because of the prohibitive computational cost.
I only found one directly-relevant study, which is on way too small and simple a system for me to draw much of a conclusion from it, but it does seem to have worked.
Thanks for the link—will read this tomorrow.
BTW, thanks for humoring me throughout this thread. This is really useful, and my understanding is updating considerably.
And thank you for engaging in detail. I have also found this very helpful in forcing me to clarify (partially to myself) what my actual beliefs are.
So I read through the Maddox et al. study, and it definitely does not show that the NTK can do transfer learning. They pre-train using SGD on a single task, then use the NTK computed on the trained network to do Bayesian inference on some other tasks. They say in a footnote on page 9, “Note that in theory, there is no need to train the network at all. We found that it is practically useful to train the network to learn good representations.” This makes me suspect that they tried using the NTK to learn the transfer parameters but it didn’t work.
Regarding the empirical results about the NTK explaining the performance of neural nets, I found this study interesting. They computed the ‘empirical NTK’ on some finite-width networks and compared the performance of the solution found by SGD to that found by solving the NTK. For standard widths, the NTK solution performed substantially worse (up to a 20% drop in accuracy). The gap closed to some extent, but not completely, upon making the network much wider. The size of the gap also correlated with the complexity of the task (0.5% gap for MNIST, 13% for CIFAR, 18% for a subset of ImageNet). The trajectory of the weights also diverged substantially from the NTK prediction, even on MNIST. All of this seems consistent with the NTK being a decent first-order approximation that breaks down on the really hard tasks that require the networks to do non-trivial feature learning.
Ah, that is interesting. This definitely updates me moderately toward the “NTKs don’t learn features” hypothesis.
BTW, does this hypothesis also mean that feature learning should break down in ordinary nets as they scale up? Or does increasing the data alongside the parameter count counteract that?
I think nets are usually increased in depth as well as width when they are ‘scaled up’, so the NTK limit doesn’t apply—the convergence to NTK is controlled by the ratio of depth to width, only approaching a deterministic kernel if this ratio approaches 0.