Why I stopped being into basin broadness
There was a period where everyone was really into basin broadness for measuring neural network generalization. This mostly stopped being fashionable, but I’m not sure enough has been written up on why it didn’t do much, so I thought I should give my take on why I stopped finding it attractive. This is probably a repetition of what others have found, but I thought I might as well write it down.
Let’s say we have a neural network $f_w(x) \in \mathbb{R}^n$. We evaluate it on a dataset $(x, y) \sim D$ using a loss function $L(\hat{y}, y) \in \mathbb{R}$, to find an optimum $w^* = \arg\min_w \mathbb{E}_{(x,y)\sim D}[L(f_w(x), y)]$. Then there was an idea going around that the Hessian matrix (i.e. the second derivative of $\mathbb{E}_{(x,y)\sim D}[L(f_w(x), y)]$ with respect to $w$, evaluated at $w^*$) would tell us something about $w^*$ (especially about how well it generalizes).
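To make the object concrete, here is a minimal sketch, assuming a linear model $f_w(x) = w^T x$ with squared-error loss on random placeholder data (the data, the model, and the helper `numerical_hessian` are illustrative, not from any particular library). It finds $w^*$ by least squares and estimates the Hessian at $w^*$ by central finite differences; broadness is typically read off the Hessian's spectrum, with small eigenvalues corresponding to flat directions.

```python
import numpy as np

# Assumed toy setup (hypothetical): linear model f_w(x) = w^T x, squared-error
# loss L(y_hat, y) = (y_hat - y)^2, and a small random dataset.
rng = np.random.default_rng(0)
n, d = 20, 3
X = rng.normal(size=(n, d))                       # row i is the input x_i
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=n)


def empirical_loss(w):
    """(1/n) * sum_i L(f_w(x_i), y_i) for the squared-error loss."""
    return np.mean((X @ w - y) ** 2)


def numerical_hessian(f, w, eps=1e-4):
    """Central finite-difference Hessian of a scalar function f at w."""
    k = w.size
    E = np.eye(k) * eps
    H = np.zeros((k, k))
    for a in range(k):
        for b in range(k):
            H[a, b] = (f(w + E[a] + E[b]) - f(w + E[a] - E[b])
                       - f(w - E[a] + E[b]) + f(w - E[a] - E[b])) / (4 * eps**2)
    return H


# For this model the optimum w* has a closed form (ordinary least squares).
w_star, *_ = np.linalg.lstsq(X, y, rcond=None)

H = numerical_hessian(empirical_loss, w_star)
# Small eigenvalues correspond to flat ("broad") directions around w*.
eigenvalues = np.linalg.eigvalsh(H)
print(eigenvalues)
```

For this particular model the Hessian is exactly $\frac{2}{n} X^T X$, which gives an easy check on the finite-difference estimate; for a real network one would use automatic differentiation instead.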
If we number the dataset $(x_i, y_i)$, we can stack all the network outputs $\hat{y}_i(w) = f_w(x_i)$ into a vector $\hat{y}(w)$, which feeds into an empirical loss $\hat{L}(\hat{y}) = \frac{1}{n}\sum_{i=1}^{n} L(\hat{y}_i, y_i)$. The Hessian we talked about before is now just the Hessian of $\hat{L}(\hat{y}(w))$ with respect to $w$. Expanding this out is kind of clunky since it involves some convoluted tensors that I don’t know any syntax for, but clearly it consists of two terms:
The Hessian of $\hat{L}$ with a copy of the Jacobian of $\hat{y}$ on each end (this can just barely be written without crazy tensors: $(J_w \hat{y}(w))^T \big(H_{\hat{y}} \hat{L}(\hat{y})\big)\big|_{\hat{y}(w)} J_w \hat{y}(w)$)
The gradient of $\hat{L}$, contracted with a crazy second derivative of $\hat{y}$.
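The two terms can be checked numerically. Writing $g_i = \partial \hat{L} / \partial \hat{y}_i$ and assuming scalar outputs, the expansion is

$$H_w \hat{L}(\hat{y}(w)) = (J_w \hat{y})^T (H_{\hat{y}} \hat{L}) J_w \hat{y} + \sum_i g_i \, H_w \hat{y}_i(w).$$

The sketch below assumes a hypothetical three-weight network $f_w(x) = w_2 \tanh(w_0 x + w_1)$, squared-error loss, and made-up data, and uses finite differences throughout; it is an illustration rather than a general-purpose tool.

```python
import numpy as np

# Hypothetical toy network with three weights and a scalar output:
#   f_w(x) = w[2] * tanh(w[0] * x + w[1])
xs = np.array([-1.0, 0.3, 0.8, 1.5])   # made-up inputs x_i
ys = np.array([0.0, 1.0, 1.0, 0.0])    # made-up targets y_i
n = xs.size


def outputs(w):
    """Stacked network outputs y_hat_i(w) = f_w(x_i)."""
    return w[2] * np.tanh(w[0] * xs + w[1])


def loss_of_outputs(y_hat):
    """L_hat(y_hat) = (1/n) sum_i (y_hat_i - y_i)^2 (squared error assumed)."""
    return np.mean((y_hat - ys) ** 2)


def jacobian(f, w, eps=1e-5):
    """Central-difference Jacobian of a vector-valued f, shape (outputs, weights)."""
    cols = [(f(w + eps * e) - f(w - eps * e)) / (2 * eps) for e in np.eye(w.size)]
    return np.stack(cols, axis=1)


def hessian(f, w, eps=1e-4):
    """Central-difference Hessian of a scalar-valued f."""
    k = w.size
    E = np.eye(k) * eps
    H = np.zeros((k, k))
    for a in range(k):
        for b in range(k):
            H[a, b] = (f(w + E[a] + E[b]) - f(w + E[a] - E[b])
                       - f(w - E[a] + E[b]) + f(w - E[a] - E[b])) / (4 * eps**2)
    return H


w = np.array([0.7, -0.2, 1.3])
y_hat = outputs(w)

# Left-hand side: the full Hessian of L_hat(y_hat(w)) with respect to w.
H_total = hessian(lambda v: loss_of_outputs(outputs(v)), w)

# Term 1: J^T (H_{y_hat} L_hat) J; for squared error, H_{y_hat} L_hat = (2/n) I.
J = jacobian(outputs, w)
jacobian_term = J.T @ (2 / n * np.eye(n)) @ J

# Term 2: sum_i g_i * H_w y_hat_i(w), where g_i = dL_hat/dy_hat_i = (2/n)(y_hat_i - y_i).
g = 2 / n * (y_hat - ys)
second_term = sum(g_i * hessian(lambda v, i=i: outputs(v)[i], w) for i, g_i in enumerate(g))

print(np.max(np.abs(H_total - (jacobian_term + second_term))))
```

The printed discrepancy should be at the level of finite-difference error.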
Now, the derivatives of $\hat{L}$ are “obviously boring” because they don’t really refer to the neural network weights, which is confirmed if you think about it in concrete cases, e.g. if $L(\hat{y}, y) = -y \log(\hat{y}) - (1 - y)\log(1 - \hat{y})$ with $y = 1$ or $y = 0$, the derivatives just quantify how far $\hat{y}$ is from $y$. This obviously isn’t relevant for neural network generalization, except in the sense that it tells you which direction you want to generalize in.
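A small sketch of the binary cross-entropy case in plain Python (the function names are hypothetical): the first and second derivatives of $L$ with respect to $\hat{y}$ take only $\hat{y}$ and $y$ as arguments, so no weights appear anywhere.

```python
def bce_grad(y_hat, y):
    """dL/dy_hat for L(y_hat, y) = -y log(y_hat) - (1 - y) log(1 - y_hat)."""
    return -y / y_hat + (1 - y) / (1 - y_hat)


def bce_hess(y_hat, y):
    """d^2 L / dy_hat^2 for the same loss."""
    return y / y_hat**2 + (1 - y) / (1 - y_hat) ** 2


# Both functions depend only on the prediction and the label, never on the weights.
# For y = 1 the gradient is -1 / y_hat and the curvature is 1 / y_hat^2:
# both shrink in magnitude as y_hat approaches the target.
for y_hat in (0.1, 0.5, 0.9, 0.99):
    print(y_hat, bce_grad(y_hat, 1), bce_hess(y_hat, 1))
```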
Meanwhile, $J_w \hat{y}(w)$ is incredibly strongly related to neural network generalization, because it’s literally a matrix which specifies how the neural network outputs change in response to the weights. In fact, it forms the core of the neural tangent kernel (a standard tool for modelling neural network generalization), because the NTK can be expressed as $J_w \hat{y}(w) (J_w \hat{y}(w))^T$.
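A minimal sketch of the empirical NTK, assuming a hypothetical three-weight network $f_w(x) = w_2 \tanh(w_0 x + w_1)$, squared-error loss, and made-up data. It builds $J_w \hat{y}(w)$ by finite differences, forms $J J^T$, and checks the standard first-order fact that a small gradient step with learning rate $\eta$ changes the outputs by roughly $-\eta \, J J^T \nabla_{\hat{y}} \hat{L}$. This is one way to see why the NTK governs how training on one data point moves the prediction on another.

```python
import numpy as np

# Hypothetical three-weight network f_w(x) = w[2] * tanh(w[0] * x + w[1]) on made-up data.
xs = np.array([-1.0, 0.3, 0.8, 1.5])
ys = np.array([0.0, 1.0, 1.0, 0.0])
n = xs.size


def outputs(w):
    """Stacked outputs y_hat_i(w) = f_w(x_i)."""
    return w[2] * np.tanh(w[0] * xs + w[1])


def jacobian(f, w, eps=1e-5):
    """Central-difference Jacobian, shape (data points, weights)."""
    cols = [(f(w + eps * e) - f(w - eps * e)) / (2 * eps) for e in np.eye(w.size)]
    return np.stack(cols, axis=1)


w = np.array([0.7, -0.2, 1.3])
J = jacobian(outputs, w)
ntk = J @ J.T        # empirical NTK: entry (i, j) = <d y_hat_i / dw, d y_hat_j / dw>

# One small gradient-descent step on the squared-error loss mean((y_hat - y)^2).
lr = 1e-3
grad_outputs = 2 / n * (outputs(w) - ys)   # dL_hat / dy_hat
w_next = w - lr * J.T @ grad_outputs       # chain rule: dL_hat/dw = J^T dL_hat/dy_hat

# To first order in lr, the outputs move by -lr * NTK @ dL_hat/dy_hat, so the NTK
# says how a push on data point j moves the prediction on data point i.
predicted_change = -lr * ntk @ grad_outputs
actual_change = outputs(w_next) - outputs(w)
print(np.max(np.abs(actual_change - predicted_change)))
```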
The “crazy second derivative of $\hat{y}$” can, I guess, be understood separately for each $\hat{y}_i$, as then it’s just the Hessian $H_w \hat{y}_i(w)$, i.e. it reflects how changes in the weights interact with each other when influencing $\hat{y}_i$. I don’t have any strong opinions on how important this matrix is, but because $J_w \hat{y}(w)$ is so obviously important, I haven’t felt like granting $H_w \hat{y}_i(w)$ much attention.
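A sketch contrasting the per-output Hessians $H_w \hat{y}_i(w)$ for a hypothetical tanh network and for a linear network (a weight plus a bias, also assumed), computed by finite differences on made-up inputs. For the linear network every $H_w \hat{y}_i$ is zero, so the second term drops out and the loss Hessian is exactly the Jacobian term; for the nonlinear one it is not.

```python
import numpy as np


def hessian(f, w, eps=1e-4):
    """Central-difference Hessian of a scalar function f at w."""
    k = w.size
    E = np.eye(k) * eps
    H = np.zeros((k, k))
    for a in range(k):
        for b in range(k):
            H[a, b] = (f(w + E[a] + E[b]) - f(w + E[a] - E[b])
                       - f(w - E[a] + E[b]) + f(w - E[a] - E[b])) / (4 * eps**2)
    return H


xs = np.array([-1.0, 0.3, 0.8, 1.5])   # made-up inputs


def tanh_net(w, x):
    """Hypothetical nonlinear network f_w(x) = w[2] * tanh(w[0] * x + w[1])."""
    return w[2] * np.tanh(w[0] * x + w[1])


def linear_net(w, x):
    """Linear network f_w(x) = w^T (x, 1), i.e. a weight and a bias."""
    return w[0] * x + w[1]


w_tanh = np.array([0.7, -0.2, 1.3])
w_lin = np.array([0.7, -0.2])

# One Hessian H_w y_hat_i(w) per data point: how weight changes interact for output i.
H_tanh = [hessian(lambda w, x=x: tanh_net(w, x), w_tanh) for x in xs]
H_lin = [hessian(lambda w, x=x: linear_net(w, x), w_lin) for x in xs]

for x, Ht, Hl in zip(xs, H_tanh, H_lin):
    print(x, np.abs(Ht).max(), np.abs(Hl).max())
```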
The NTK as the network activations?
Epistemic status: speculative; I really should get around to verifying it. Really, the prior part is speculative too, but I think those speculations are more theoretically well-grounded. But if I’m wrong about either, please call me a dummy in the comments so I can correct it.
Let’s take the simplest case of a linear network, $f_w(x) = w^T x$. In this case, $J_w \hat{y}(w) = x^T$, i.e. the Jacobian is literally just the inputs to the network. If you work out a bunch of other toy examples, the takeaway is qualitatively similar (the Jacobian is closely related to the neuron activations), though not exactly the same.
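A quick numerical check of the linear case, assuming random placeholder inputs: stacking the outputs $\hat{y}_i = w^T x_i$ and differentiating with respect to $w$ gives a Jacobian whose rows are exactly the inputs $x_i^T$, so the NTK $J J^T$ is just the Gram matrix of the inputs.

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(5, 3))     # 5 placeholder inputs x_i in R^3, one per row
w = rng.normal(size=3)


def outputs(w):
    """Linear network y_hat_i = w^T x_i, stacked over the dataset."""
    return X @ w


def jacobian(f, w, eps=1e-6):
    """Central-difference Jacobian, shape (data points, weights)."""
    cols = [(f(w + eps * e) - f(w - eps * e)) / (2 * eps) for e in np.eye(w.size)]
    return np.stack(cols, axis=1)


J = jacobian(outputs, w)
# Row i of J is x_i^T: the Jacobian "features" are the inputs themselves.
print(np.max(np.abs(J - X)))
```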
There are of course some exceptions, e.g. $f_{a,b}(x) = abx$ at $a = b = 0$ just has a zero Jacobian. Exceptions this extreme are probably rare, but more commonly you could have some softmax in the network (e.g. in an attention layer) which saturates such that no gradient goes through. In that case, e.g. for interpretability, it seems like you’d often still really want to “count” this, so arguably the activations would be better than the NTK for this case. (I’ve been working on a modification to the NTK to better handle this case.)
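Both exceptions can be seen in a few lines (helper names are hypothetical). For $f_{a,b}(x) = abx$ the gradient with respect to $(a, b)$ is $(bx, ax)$, which vanishes at $a = b = 0$. For a softmax whose largest logit dominates, the Jacobian $\operatorname{diag}(p) - p p^T$ is nearly zero even though the softmax is still deciding what gets passed on.

```python
import numpy as np


def f(params, x):
    """The example network f_{a,b}(x) = a * b * x."""
    a, b = params
    return a * b * x


def param_gradient(params, x, eps=1e-6):
    """Central-difference gradient of f_{a,b}(x) with respect to (a, b)."""
    params = np.asarray(params, dtype=float)
    return np.array([(f(params + eps * e, x) - f(params - eps * e, x)) / (2 * eps)
                     for e in np.eye(2)])


x = 3.0
grad_at_zero = param_gradient([0.0, 0.0], x)    # no gradient reaches either weight
grad_elsewhere = param_gradient([1.0, 2.0], x)  # analytically (b * x, a * x) = (6, 3)


def softmax(z):
    """Numerically stable softmax."""
    e = np.exp(z - z.max())
    return e / e.sum()


# A saturated softmax: one logit dominates, so the Jacobian diag(p) - p p^T
# is nearly zero even though the softmax still decides what gets passed on.
p = softmax(np.array([20.0, 0.0, 0.0]))
softmax_jacobian = np.diag(p) - np.outer(p, p)
print(grad_at_zero, grad_elsewhere, np.abs(softmax_jacobian).max())
```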
The NTK and the network activations have somewhat different properties, so which one I consider most relevant varies. However, my choice tends to be driven more by analytical convenience (e.g. the NTK and the network activations lie in different vector spaces) than by anything else.
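To make the “different vector spaces” point concrete, here is a sketch assuming a hypothetical one-hidden-layer network $f(x) = v^T \tanh(W_1 x)$ with random placeholder weights and data. The activations form a (data × neurons) matrix, the parameter Jacobian is (data × parameters), and the NTK is (data × data). The readout block of the Jacobian coincides with the hidden activations, so that block's share of the NTK is the Gram matrix of the activations, which is one place where the two views meet.

```python
import numpy as np

# Hypothetical one-hidden-layer network f(x) = v^T tanh(W1 x) with random placeholder values.
rng = np.random.default_rng(2)
n, d, h = 6, 4, 5                    # data points, input dim, hidden neurons
X = rng.normal(size=(n, d))
W1 = rng.normal(size=(h, d))
v = rng.normal(size=h)


def hidden(W1):
    """Hidden activations tanh(W1 x_i), one row per data point: shape (n, h)."""
    return np.tanh(X @ W1.T)


def outputs_flat(theta):
    """Network outputs as a function of all parameters flattened as [W1.ravel(), v]."""
    W1_, v_ = theta[: h * d].reshape(h, d), theta[h * d:]
    return hidden(W1_) @ v_


def jacobian(f, w, eps=1e-6):
    """Central-difference Jacobian, shape (data points, parameters)."""
    cols = [(f(w + eps * e) - f(w - eps * e)) / (2 * eps) for e in np.eye(w.size)]
    return np.stack(cols, axis=1)


theta = np.concatenate([W1.ravel(), v])
J = jacobian(outputs_flat, theta)     # (n, h*d + h): data x parameters
activations = hidden(W1)              # (n, h): data x neurons
ntk = J @ J.T                         # (n, n): data x data

# The readout-weight columns of J are exactly the hidden activations, so that
# block's contribution to the NTK is the Gram matrix of the activations.
J_readout = J[:, h * d:]
print(J.shape, activations.shape, ntk.shape)
```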