[Short version] Information Loss --> Basin flatness
This is an overview for advanced readers. Main post: Information Loss --> Basin flatness
Summary:
Inductive bias is related to, among other things:
Basin flatness
Which solution manifolds (manifolds of zero loss) are higher dimensional than others. This is closely related to “basin flatness”, since each dimension of the manifold is a direction of zero curvature.
In relation to basin flatness and manifold dimension:
It is useful to consider the “behavioral gradients” $\nabla_\theta f(\theta, x_i)$ for each input $x_i$.
Let $G$ be the matrix of behavioral gradients (the $i$th column of $G$ is $g_i = \nabla_\theta f(\theta, x_i)$).[1] We can show that $\dim(\text{manifold}) \le N - \operatorname{rank}(G)$.[2]
$\operatorname{rank}(\text{Hessian}) = \operatorname{rank}(G)$.[3][4]
Flat basin ≈ low-rank Hessian = low-rank $G$ ≈ high manifold dimension
High manifold dimension ≈ low-rank $G$ = linear dependence of behavioral gradients (see the first sketch after this summary)
A case study in a very small neural network shows that “information loss” is a good qualitative interpretation of this linear dependence.
Models that throw away enough information about the input in early layers are guaranteed to live on particularly high-dimensional manifolds (see the second sketch after this summary). Precise bounds seem easily derivable and might be given in a future post.
See the main post for details.
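As a concrete companion to the summary above, here is a minimal numerical sketch (my own construction, not from the post): it builds the behavioral-gradient matrix $G$ for a tiny MLP, picks targets so that the chosen parameter vector is an exact zero-loss point of an MSE loss, and checks that $\operatorname{rank}(\text{Hessian}) = \operatorname{rank}(G)$ and that the bound $N - \operatorname{rank}(G)$ comes out as expected. The architecture, sizes, and use of JAX are illustrative assumptions.

```python
# A minimal numerical sketch (my construction, not from the post): a tiny MLP,
# its behavioral-gradient matrix G, and a check that rank(Hessian) = rank(G)
# at a point of exactly zero MSE loss. Sizes and architecture are arbitrary.
import jax
import jax.numpy as jnp

D_IN, D_HID = 2, 3                        # illustrative sizes
N = D_IN * D_HID + D_HID + D_HID + 1      # total number of parameters (13)

def f(theta, x):
    """Scalar-output MLP; theta is a flat parameter vector of length N."""
    W1 = theta[:D_IN * D_HID].reshape(D_HID, D_IN)
    b1 = theta[D_IN * D_HID : D_IN * D_HID + D_HID]
    w2 = theta[D_IN * D_HID + D_HID : D_IN * D_HID + 2 * D_HID]
    b2 = theta[-1]
    return w2 @ jnp.tanh(W1 @ x + b1) + b2

theta0 = jax.random.normal(jax.random.PRNGKey(0), (N,))
xs = jax.random.normal(jax.random.PRNGKey(1), (5, D_IN))    # 5 training inputs

# Behavioral gradients g_i = grad_theta f(theta0, x_i), stacked as columns of G.
G = jax.vmap(jax.grad(f), in_axes=(None, 0))(theta0, xs).T  # shape (N, 5)

# Choose targets so that theta0 fits the data exactly (a zero-loss point).
ys = jax.vmap(f, in_axes=(None, 0))(theta0, xs)
loss = lambda theta: jnp.sum((jax.vmap(f, in_axes=(None, 0))(theta, xs) - ys) ** 2)
H = jax.hessian(loss)(theta0)

rank_G = jnp.linalg.matrix_rank(G)
print("rank(G) =", rank_G)
print("rank(Hessian) =", jnp.linalg.matrix_rank(H))
print("manifold-dimension bound N - rank(G) =", N - rank_G)
print("Hessian == 2 G G^T (up to float error):", jnp.allclose(H, 2 * G @ G.T, atol=1e-3))
```

With these sizes one generically expects $\operatorname{rank}(G) = 5$, so both ranks should print as 5 and the bound as $13 - 5 = 8$.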
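The second sketch (again my own toy, not the post's case study) illustrates the information-loss point: if the first layer's ReLUs are dead on every input, the model's output carries no information about $x$, every behavioral gradient collapses onto the output-bias direction, and $\operatorname{rank}(G)$ drops to 1, making the bound $N - \operatorname{rank}(G)$ nearly as large as it can be.

```python
# A second sketch (my own toy example): when the first layer's ReLUs are dead on
# every input, the model discards all information about x, and every behavioral
# gradient collapses onto the output-bias direction, so rank(G) = 1.
import jax
import jax.numpy as jnp

D_IN, D_HID = 2, 3
N = D_IN * D_HID + D_HID + D_HID + 1      # 13 parameters

def f(theta, x):
    W1 = theta[:D_IN * D_HID].reshape(D_HID, D_IN)
    b1 = theta[D_IN * D_HID : D_IN * D_HID + D_HID]
    w2 = theta[D_IN * D_HID + D_HID : D_IN * D_HID + 2 * D_HID]
    b2 = theta[-1]
    return w2 @ jax.nn.relu(W1 @ x + b1) + b2

# Large negative hidden biases make all pre-activations negative for the inputs
# below, so the hidden layer outputs zero regardless of x ("information loss").
theta_dead = jnp.concatenate([
    0.1 * jnp.ones(D_IN * D_HID),   # small first-layer weights
    -10.0 * jnp.ones(D_HID),        # biases that kill every ReLU
    jnp.ones(D_HID),                # second-layer weights (irrelevant: h is always 0)
    jnp.zeros(1),                   # output bias
])
xs = jax.random.normal(jax.random.PRNGKey(0), (20, D_IN))   # 20 distinct inputs

G = jax.vmap(jax.grad(f), in_axes=(None, 0))(theta_dead, xs).T
print("rank(G) =", jnp.linalg.matrix_rank(G))                                    # 1: only df/db2 is nonzero
print("manifold-dimension bound N - rank(G) =", N - jnp.linalg.matrix_rank(G))   # 12
```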
[1] In standard terminology, $G$ is the Jacobian of the concatenation of all outputs with respect to the parameters.
[2] $N$ is the number of parameters in the model. See claims 1 and 2 here for a proof sketch.
[3] Proof sketch for $\operatorname{rank}(\text{Hessian}) = \operatorname{rank}(G)$:
At a local minimum, first-order sensitivity of behavior translates to second-order sensitivity of loss.
So $\operatorname{span}(g_1, \ldots, g_k)^\perp$ is the null space of the Hessian.
So $\operatorname{rank}(\text{Hessian}) = N - (N - \operatorname{rank}(G)) = \operatorname{rank}(G)$.
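To make the first step concrete, here is one way to spell it out (my elaboration, assuming the post's setting of MSE loss summed over $k$ training points and a parameter $\theta_0$ with exactly zero loss; $\delta$ denotes a small parameter perturbation):

```latex
% Taylor-expand the loss around a zero-loss point \theta_0, using f(\theta_0, x_i) = y_i
% and g_i = \nabla_\theta f(\theta_0, x_i):
\begin{align*}
L(\theta_0 + \delta)
  &= \sum_{i=1}^{k} \bigl( f(\theta_0 + \delta, x_i) - y_i \bigr)^2
   = \sum_{i=1}^{k} \bigl( g_i^\top \delta + O(\|\delta\|^2) \bigr)^2 \\
  &= \delta^\top \Bigl( \sum_{i=1}^{k} g_i g_i^\top \Bigr) \delta + O(\|\delta\|^3)
   = \delta^\top G G^\top \delta + O(\|\delta\|^3).
\end{align*}
% The second-order change in loss vanishes exactly when \delta is orthogonal to every g_i,
% so null(Hessian) = span(g_1, ..., g_k)^\perp and hence rank(Hessian) = rank(G).
```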
[4] There is an alternate proof going through the result $\text{Hessian} = 2GG^\top$. (The factor of 2 is specific to MSE loss.)
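For completeness, the direct computation behind this result (again my own write-up, assuming MSE loss summed over the training set and writing $f_i := f(\theta, x_i)$):

```latex
% Differentiate L(\theta) = \sum_i (f_i - y_i)^2 twice with the product rule:
\begin{align*}
\nabla^2_\theta L
  &= \sum_i \Bigl[ 2\, \nabla_\theta f_i \, (\nabla_\theta f_i)^\top
                   + 2\, (f_i - y_i)\, \nabla^2_\theta f_i \Bigr].
\end{align*}
% At a zero-loss point every residual f_i - y_i vanishes, so only the first term survives:
\begin{align*}
\nabla^2_\theta L \;=\; 2 \sum_i g_i g_i^\top \;=\; 2\, G G^\top,
\qquad \operatorname{rank}\bigl(\nabla^2_\theta L\bigr) = \operatorname{rank}(G).
\end{align*}
```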