First some more specific/detailed comments: Regarding the relationship with the loss and with the Hessian of the loss, my concern sort of stems from the fact that the domains/codomains are different, and so I think it deserves to be spelled out. The loss of a model with parameters $\theta \in \Theta$ can be described by introducing the actual function that maps the behavior to the real numbers, right? I.e. given some actual function $l : O^k \to \mathbb{R}$, we have:
$$L : \Theta \xrightarrow{\;f\;} O^k \xrightarrow{\;l\;} \mathbb{R}$$
I.e. it’s $l$ that might be something like MSE, but the function $L$ is of course more mysterious, because it includes the way that parameters are actually mapped to a working model. Anyway, to perform some computations with this, we are looking at an expression like
$$L(\theta) = l(f(\theta))$$
We want to differentiate this twice with respect to θ essentially. Firstly, we have
$$\nabla L(\theta) = \nabla l(f(\theta))\, J_f(\theta)$$
where—just to keep track of this—we’ve got:
$$(1 \times N)\text{ vector} = [(1 \times k)\text{ vector}]\,[(k \times N)\text{ matrix}]$$
Or, using ‘coordinates’ to make it explicit:
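$$\frac{\partial}{\partial \theta_i} L(\theta) = \nabla l(f(\theta)) \cdot \frac{\partial f}{\partial \theta_i} = \sum_{p=1}^{k} \nabla_p l(f(\theta)) \cdot \frac{\partial f_p}{\partial \theta_i}$$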
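for $i = 1, \dots, N$. Then for $j = 1, \dots, N$ we differentiate again:
$$\frac{\partial^2}{\partial \theta_j \partial \theta_i} L(\theta) = \sum_{p=1}^{k} \sum_{q=1}^{k} \nabla_q \nabla_p l(f(\theta)) \frac{\partial f_q}{\partial \theta_j} \frac{\partial f_p}{\partial \theta_i} + \sum_{p=1}^{k} \nabla_p l(f(\theta)) \frac{\partial^2 f_p}{\partial \theta_j \partial \theta_i}$$
Or,
$$\operatorname{Hess}(L)(\theta) = J_f(\theta)^T \,[\operatorname{Hess}(l)(f(\theta))]\, J_f(\theta) + \nabla l(f(\theta))\, D^2 f(\theta)$$
This is now at the level of $(N \times N)$ matrices. Avoiding getting into any depth about tensors and indices, the $D^2 f$ term is basically an $(N \times N \times k)$ tensor-type object, and it’s paired with $\nabla l$, which is a $(1 \times k)$ vector, to give something that is $(N \times N)$.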
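A minimal numerical sanity check of the two chain-rule identities, $\nabla L = \nabla l(f(\theta))\,J_f$ and $\operatorname{Hess}(L) = J_f^T[\operatorname{Hess}(l)]J_f + \nabla l\,D^2 f$, might look like the sketch below. Everything in it is hypothetical: a toy model `f` with $N = 2$ parameters and $k = 3$ behavior components, a made-up target, and a loss `ell` (standing in for $l$) with a cross term so that its Hessian is not a multiple of the identity. It only compares the hand-derived chain-rule expressions against central finite differences of `total_loss` (standing in for $L$).

```python
import math

# Hypothetical toy setup: N = 2 parameters, k = 3 behavior components.
TARGET = [0.5, 0.2, 1.0]

def f(theta):
    a, b = theta
    return [a * b, math.sin(a), a + b * b]

def jac_f(theta):
    # (k x N) Jacobian: jac_f(theta)[p][i] = d f_p / d theta_i
    a, b = theta
    return [[b, a], [math.cos(a), 0.0], [1.0, 2.0 * b]]

def d2_f(theta):
    # D^2 f as k stacked (N x N) matrices: d2_f(theta)[p][i][j] = d^2 f_p / d theta_i d theta_j
    a, _ = theta
    return [[[0.0, 1.0], [1.0, 0.0]],
            [[-math.sin(a), 0.0], [0.0, 0.0]],
            [[0.0, 0.0], [0.0, 2.0]]]

def ell(y):
    # Stand-in for l: squared error plus a cross term, so Hess(l) is not a multiple of I.
    return sum((yp - tp) ** 2 for yp, tp in zip(y, TARGET)) + y[0] * y[2]

def grad_ell(y):
    return [2 * (y[0] - TARGET[0]) + y[2],
            2 * (y[1] - TARGET[1]),
            2 * (y[2] - TARGET[2]) + y[0]]

def hess_ell(y):
    # Constant for this quadratic loss; y is accepted only to mirror Hess(l)(f(theta)).
    return [[2.0, 0.0, 1.0], [0.0, 2.0, 0.0], [1.0, 0.0, 2.0]]

def total_loss(theta):
    # Stand-in for L = l o f.
    return ell(f(theta))

def grad_L_chain(theta):
    # (1 x N) row = (1 x k) gradient of l at f(theta) times (k x N) Jacobian of f
    g, J = grad_ell(f(theta)), jac_f(theta)
    return [sum(g[p] * J[p][i] for p in range(len(g))) for i in range(len(theta))]

def hess_L_chain(theta):
    # J^T [Hess(l)] J  +  sum_p (grad_p l) * D^2 f_p
    y = f(theta)
    J, H, g, D2 = jac_f(theta), hess_ell(y), grad_ell(y), d2_f(theta)
    k, n = len(g), len(theta)
    return [[sum(J[q][j] * H[q][p] * J[p][i] for p in range(k) for q in range(k))
             + sum(g[p] * D2[p][j][i] for p in range(k))
             for i in range(n)] for j in range(n)]

def shifted(theta, i, di, j=None, dj=0.0):
    # Copy of theta with di added to slot i and (optionally) dj added to slot j.
    t = list(theta)
    t[i] += di
    if j is not None:
        t[j] += dj
    return t

def grad_L_fd(theta, h=1e-6):
    # Central finite differences of the total loss.
    return [(total_loss(shifted(theta, i, h)) - total_loss(shifted(theta, i, -h))) / (2 * h)
            for i in range(len(theta))]

def hess_L_fd(theta, h=1e-4):
    # Four-point central difference for each mixed second partial.
    n = len(theta)
    return [[(total_loss(shifted(theta, i, h, j, h)) - total_loss(shifted(theta, i, h, j, -h))
              - total_loss(shifted(theta, i, -h, j, h)) + total_loss(shifted(theta, i, -h, j, -h)))
             / (4 * h * h)
             for j in range(n)] for i in range(n)]

theta = [0.7, -1.3]
print(grad_L_chain(theta), grad_L_fd(theta))
print(hess_L_chain(theta), hess_L_fd(theta))
```

The two printed rows in each pair should agree up to finite-difference error; note that neither term of the Hessian vanishes here, because this $\theta$ is not a minimum of $l$.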
So what I think you are saying now is that if we are at a local minimum for $l$, then the second term on the right-hand side vanishes (because the term includes the first derivatives of $l$, which are zero at a minimum). You can see, however, that if the Hessian of $l$ is not a multiple of the identity (as it would be for MSE), then the claimed relationship does not hold, i.e. it is not the case that in general, at a minimum of $l$, the Hessian of the loss is equal to a constant times $J_f^T J_f$. So maybe you really do want to explicitly assume something like MSE.
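To make this concrete: where $\nabla l = 0$, the Hessian of the loss reduces to $J_f^T [\operatorname{Hess}(l)] J_f$. The sketch below uses a hypothetical $3 \times 2$ Jacobian and exact integer arithmetic to check whether that matrix is a scalar multiple of $J_f^T J_f$, first for the MSE Hessian $2I$ and then for a diagonal Hessian that is not a multiple of the identity.

```python
from fractions import Fraction

def gauss_newton(J, H):
    """J^T H J for a (k x N) Jacobian J and a symmetric (k x k) matrix H."""
    k, n = len(J), len(J[0])
    return [[sum(J[q][j] * H[q][p] * J[p][i] for p in range(k) for q in range(k))
             for i in range(n)] for j in range(n)]

def is_scalar_multiple(A, B):
    """True if A == c * B for one scalar c (assumes B has a nonzero entry)."""
    pairs = [(a, b) for ra, rb in zip(A, B) for a, b in zip(ra, rb)]
    ratios = {Fraction(a) / Fraction(b) for a, b in pairs if b != 0}
    zeros_match = all(a == 0 for a, b in pairs if b == 0)
    return len(ratios) == 1 and zeros_match

J = [[1, 0], [0, 1], [1, 1]]                     # hypothetical (3 x 2) Jacobian
I3 = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
H_mse = [[2 * x for x in row] for row in I3]     # Hessian of a sum of squared errors
H_weighted = [[1, 0, 0], [0, 3, 0], [0, 0, 1]]   # diagonal, but not a multiple of I

JtJ = gauss_newton(J, I3)
print(gauss_newton(J, H_mse), is_scalar_multiple(gauss_newton(J, H_mse), JtJ))
# [[4, 2], [2, 4]] True
print(gauss_newton(J, H_weighted), is_scalar_multiple(gauss_newton(J, H_weighted), JtJ))
# [[2, 1], [1, 4]] False
```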
I agree that assuming MSE, and looking at a local minimum, you have $\operatorname{rank}(\operatorname{Hess}(L)) = \operatorname{rank}(J_f)$.
(In case it’s of interest to anyone, googling turned up this recent paper https://openreview.net/forum?id=otDgw7LM7Nn which studies pretty much exactly the problem of bounding the rank of the Hessian of the loss. They say: “Flatness: A growing number of works [59–61] correlate the choice of regularizers, optimizers, or hyperparameters, with the additional flatness brought about by them at the minimum. However, the significant rank degeneracy of the Hessian, which we have provably established, also points to another source of flatness — that exists as a virtue of the compositional model structure —from the initialization itself. Thus, a prospective avenue of future work would be to compare different architectures based on this inherent kind of flatness.”)
Some broader remarks: I think these are nice observations, but unfortunately I’m generally a bit confused/unclear about what else you might get out of going along these lines. I don’t want to sound harsh, just plain: this is mostly because, as we can see, the mathematical part of what you have said is all very simple, well-established facts about smooth functions, and so it would be surprising (to me at least) if some non-trivial observation about deep learning came out of it. In a similar vein, regarding the “cause” of low-rank G, I do think that one could try to bring in a notion of “information loss” in neural networks, but for it to be substantive one needs to be careful that it’s not simply a rephrasing of what it means for the Jacobian to have less-than-full rank.
Being a bit loose/informal now: to illustrate, just imagine for a moment a real-valued function on an interval. I could say it ‘loses information’ where its values cannot distinguish between a subset of points. But this is almost the same as just saying it is constant on some subset... which is of course very close to just saying the derivative vanishes on some subset. Here, if you describe the phenomenon of information loss as concretely as being the situation where some inputs can’t be distinguished, then (particularly given that you have to assume these spaces are actually some kind of smooth/differentiable spaces to do the theoretical analysis), you’ve more or less just built into your description of information loss something that looks a lot like the function being constant along some directions, which means there is a vector in the kernel of the Jacobian. I don’t think it’s somehow incorrect to point to this, but it becomes more like just saying ‘perhaps one useful definition of information loss is low-rank G’ as opposed to linking one phenomenon to the other.
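A toy illustration of that last point, with a hypothetical model and hypothetical inputs: if the behavior cannot distinguish parameter settings along some direction $v$ (here because the two parameters only ever enter through their sum), then $v$ lies in the kernel of the Jacobian, and conversely moving along $v$ leaves the behavior unchanged.

```python
import math

INPUTS = [-1.0, 0.5, 2.0]  # hypothetical inputs x_1, ..., x_k

def f(theta):
    # Hypothetical model: the two parameters only appear through their sum,
    # so f(a, b) == f(a + t, b - t) for every t.
    a, b = theta
    return [math.tanh((a + b) * x) for x in INPUTS]

def jac_f(theta):
    # d/da and d/db of tanh((a + b) x) are both x * (1 - tanh^2((a + b) x)),
    # so the two columns of the (k x 2) Jacobian coincide.
    a, b = theta
    return [[x * (1 - math.tanh((a + b) * x) ** 2)] * 2 for x in INPUTS]

theta = [0.3, -0.1]
v = [1.0, -1.0]  # direction along which the behavior cannot tell parameters apart
J = jac_f(theta)
Jv = [row[0] * v[0] + row[1] * v[1] for row in J]
print(Jv)  # [0.0, 0.0, 0.0]: v is in the kernel of the Jacobian

t = 0.25
moved = [theta[0] + t, theta[1] - t]
# Only floating-point rounding separates the two behaviors.
print(max(abs(p - q) for p, q in zip(f(theta), f(moved))))
```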
Sorry for the very long remarks. Of course this is actually because I found it well worth engaging with. And I have a longer-standing personal interest in zero sets of smooth functions!
I will split this into a math reply, and a reply about the big picture / info loss interpretation.
Math reply:
Thanks for fleshing out the calculus rigorously; admittedly, I had not done this. Rather, I simply assumed MSE loss and proceeded largely through visual intuition.
> I agree that assuming MSE, and looking at a local minimum, you have $\operatorname{rank}(\operatorname{Hess}(L)) = \operatorname{rank}(J_f)$
This is still false! Edit: I am now confused; I don’t know whether it is false or not.
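You are conflating $\nabla_f l(f(\theta))$ and $\nabla_\theta l(f(\theta))$. Adding disambiguation, we have:
$$\nabla_\theta L(\theta) = \big(\nabla_f l(f(\theta))\big)\, J_\theta f(\theta)$$
$$\operatorname{Hess}_\theta(L)(\theta) = J_\theta f(\theta)^T \,[\operatorname{Hess}_f(l)(f(\theta))]\, J_\theta f(\theta) + \nabla_f l(f(\theta))\, D^2_\theta f(\theta)$$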
So we see that the second term disappears if $\nabla_f l(f(\theta)) = 0$. But the critical point condition is $\nabla_\theta l(f(\theta)) = 0$. From the chain rule, we have:
$$\nabla_\theta l(f(\theta)) = \big(\nabla_f l(f(\theta))\big)\, J_\theta f(\theta)$$
So it is possible to have a local minimum where $\nabla_f l(f(\theta)) \neq 0$, if $\nabla_f l(f(\theta))$ is in the left null space of $J_\theta f(\theta)$. There is a nice qualitative interpretation as well, but I don’t have the energy/time to explain it.
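One hypothetical instance of this: confine the behavior to the unit circle, $f(\theta) = (\cos\theta, \sin\theta)$, use $l(y) = \lVert y - t \rVert^2$, and put the target $t = (2, 0)$ off the circle. Then $L(\theta) = 5 - 4\cos\theta$ is minimized at $\theta = 0$, yet $\nabla_f l = (-2, 0) \neq 0$; it is orthogonal to the single column $(0, 1)^T$ of $J_\theta f(0)$, and the second Hessian term does not vanish. The sketch below just evaluates each piece.

```python
import math

# Hypothetical setup: behavior confined to the unit circle, target off the circle,
# so no parameter value fits the target perfectly.
TARGET = (2.0, 0.0)

def f(theta):
    return (math.cos(theta), math.sin(theta))

def grad_ell(y):
    # Gradient of ell(y) = |y - TARGET|^2 with respect to the behavior y.
    return (2 * (y[0] - TARGET[0]), 2 * (y[1] - TARGET[1]))

def jac_f(theta):
    # The (2 x 1) Jacobian of f, stored as its single column.
    return (-math.sin(theta), math.cos(theta))

def d2_f(theta):
    # Second derivative of each behavior component with respect to theta.
    return (-math.cos(theta), -math.sin(theta))

theta_star = 0.0  # ell(f(theta)) = 5 - 4 cos(theta) is minimized here
g_f = grad_ell(f(theta_star))
J = jac_f(theta_star)
D2 = d2_f(theta_star)

grad_theta = g_f[0] * J[0] + g_f[1] * J[1]        # chain rule: grad_f ell times J
gauss_newton_term = 2 * (J[0] ** 2 + J[1] ** 2)   # J^T (2I) J
second_term = g_f[0] * D2[0] + g_f[1] * D2[1]     # grad_f ell paired with D^2 f

print(g_f)                             # (-2.0, 0.0): nonzero
print(grad_theta)                      # 0.0: theta_star is still a critical point
print(gauss_newton_term, second_term)  # 2.0 2.0, summing to L''(0) = 4
```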
However, if we are at a perfect-behavior global minimum of a regression task, then $\nabla_f l(f(\theta))$ is definitely zero.
A few points about rank equality at a perfect-behavior global min:
- $\operatorname{rank}(\operatorname{Hess}(L)) = \operatorname{rank}(J_f)$ holds as long as $\operatorname{Hess}(l)(f(\theta))$ is a diagonal matrix. It need not be a multiple of the identity.
- Hence, rank equality holds any time the loss is a sum of functions such that each function only looks at a single component of the behavior.
- If the network output is 1d (as assumed in the post), this just means that the loss is a sum over losses on individual inputs.
- We can extend to larger outputs by having the behavior $f$ be the flattened concatenation of outputs. The rank equality condition is still satisfied for MSE, Binary Cross Entropy, and Cross Entropy over a probability vector. It is not satisfied if we consider the behavior to be raw logits (before the softmax) and softmax + Cross Entropy as the loss function. But we can easily fix that by considering probabilities (after softmax) as the behavior instead of raw logits.
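A small exact-arithmetic sketch of the rank claim at a perfect-behavior minimum, where $\nabla_f l = 0$ and so $\operatorname{Hess}(L) = J^T \operatorname{Hess}(l) J$. The Jacobian and diagonal entries are made up; the main check assumes the diagonal entries are positive (as for MSE, where each is 2). A contrasting case with zero diagonal entries is included, since the rank can drop there.

```python
from fractions import Fraction

def rank(M):
    """Rank of a matrix via exact Gauss-Jordan elimination over the rationals."""
    A = [[Fraction(x) for x in row] for row in M]
    n_rows, n_cols = len(A), len(A[0])
    r = 0
    for c in range(n_cols):
        pivot = next((i for i in range(r, n_rows) if A[i][c] != 0), None)
        if pivot is None:
            continue
        A[r], A[pivot] = A[pivot], A[r]
        for i in range(n_rows):
            if i != r and A[i][c] != 0:
                factor = A[i][c] / A[r][c]
                A[i] = [a - factor * b for a, b in zip(A[i], A[r])]
        r += 1
        if r == n_rows:
            break
    return r

def jt_d_j(J, diag):
    """J^T diag(d) J for a (k x N) matrix J and a length-k list of diagonal entries d."""
    k, n = len(J), len(J[0])
    return [[sum(J[p][i] * diag[p] * J[p][j] for p in range(k)) for j in range(n)]
            for i in range(n)]

# Hypothetical (k x N) Jacobian with k = 4, N = 3 and rank 2
# (row 3 = row 1 + row 2, row 4 = 2 * row 1 + row 2).
J = [[1, 2, 3],
     [0, 1, 1],
     [1, 3, 4],
     [2, 5, 7]]

hess_L = jt_d_j(J, [2, 6, 1, 3])             # positive diagonal, not a multiple of I
hess_L_degenerate = jt_d_j(J, [2, 0, 0, 0])  # zero diagonal entries: rank can drop

print(rank(J), rank(hess_L), rank(hess_L_degenerate))  # 2 2 1
```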
Thanks again for the reply.
In my notation, things like $\nabla l$ or $J_f$ are functions in and of themselves. The function $\nabla l$ evaluates to zero at local minima of $l$.
In my notation, there isn’t any such thing as $\nabla_f l$.
But look, I think that this is perhaps getting a little too bogged down for me to want to try to neatly resolve in the comment section, and I expect to be away from work for the next few days so may not check back for a while. Personally, I would just recommend going back and slowly going through the mathematical details again, checking every step at the lowest level of detail that you can and using the notation that makes most sense to you.
Thanks for the substantive reply.