Information Loss --> Basin flatness
This work was done under the mentorship of Evan Hubinger through the SERI MATS program. Thanks to Lucius Bushnaq, John Wentworth, Quintin Pope, and Peter Barnett for useful feedback and suggestions.
In this theory, the main proximate cause of flat basins is a type of information loss. Its relationship with circuit complexity and Kolmogorov complexity is currently unknown to me.[1] In this post, I will demonstrate that:
High-dimensional solution manifolds are caused by linear dependence between the “behavioral gradients” for different inputs.
This linear dependence usually arises when networks throw away information that distinguishes different training inputs. It is more likely to occur when the information is thrown away early in the network or by a ReLU.
Overview for advanced readers: [Short version] Information Loss --> Basin flatness
Behavior manifolds
Suppose we have a regression task with 1-dimensional labels and k training examples. Let us take an overparameterized network with N parameters. Every model in parameter space is part of a manifold, where every point on that manifold has identical behavior on the training set. These manifolds are usually[2] at least N−k dimensional, but some are higher dimensional than this. I will call these manifolds “behavior manifolds”, since points on the same manifold have the same behavior (on the training set, not on all possible inputs).
We can visualize the existence of “behavior manifolds” by starting with a blank parameter space, then adding contour planes for each training example. Before we add any contour planes, the entire parameter space is a single manifold, with “identical behavior” on the empty set of inputs. First, let us add the contour planes for input 1:
Each plane here is an N−1 dimensional manifold, where every model on that plane has the same output on input 1. They slice parameter space into N−1 dimensional regions. Each of these regions is an equivalence class of functions, which all behave about the same on input 1.
Next, we can add contour planes for input 2:
When we put them together, they look like this:
Together, the contours slice parameter space into N−2 dimensional regions. Each “diamond” in the picture is the cross-section of a tube-like region which extends vertically, in the direction which is parallel to both sets of planes. The manifolds of constant behavior are lines which run vertically through these tubes, parallel to both sets of contours.
In higher dimensions, these “lines” and “tubes” are actually N−2 dimensional hyperplanes, since only two degrees of freedom have been removed, one by each set of contours.
We can continue this with more and more inputs. Each input adds another set of hyperplanes, and subtracts one more dimension from the identical-behavior manifolds. Since each input can only slice off one dimension, the manifolds of constant behavior are at least N−k dimensional, where k is the number of training examples.[3]
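To make this counting concrete, here is a minimal sketch (a toy example of my own, not from the post) for the simplest possible case: a linear model f(θ, x) = θ ⋅ x with N = 3 parameters and k = 2 training inputs. The contour “planes” for input x_i are literal hyperplanes {θ : θ ⋅ x_i = const}, and the behavior manifold through any point θ0 is θ0 plus the null space of the matrix X whose rows are the inputs, which here is N − k = 1 dimensional.

```python
import torch

# Toy linear model f(theta, x) = theta . x, with N = 3 parameters and k = 2 inputs.
# The level sets ("contour planes") of f(., x_i) are hyperplanes, and the behavior
# manifold through theta0 is theta0 + null(X), where X stacks the inputs as rows.
X = torch.tensor([[1.0, 0.0,  2.0],
                  [0.0, 1.0, -1.0]])        # k = 2 inputs, N = 3 parameters
theta0 = torch.tensor([0.5, -1.0, 2.0])

_, S, Vh = torch.linalg.svd(X, full_matrices=True)
null_dim = X.shape[1] - int((S > 1e-6).sum())
v = Vh[-1]                                   # direction with X @ v = 0

print(null_dim)                              # 1  (= N - k)
print(X @ theta0)                            # outputs on the two training inputs
print(X @ (theta0 + 3.7 * v))                # unchanged anywhere along the manifold
```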
Solution manifolds
Global minima also lie on behavior manifolds, such that every point on the manifold is a global minimum. I will call these “solution manifolds”. These manifolds generally extend out to infinity, so it isn’t really meaningful to talk about literal “basin volume”.[4] We can focus instead on their dimensionality. All else being equal, a higher dimensional solution manifold should drain a larger region of parameter space, and thus be favored by the inductive bias.[5]
Parallel contours allow higher manifold dimension
Suppose we have 3 parameters (one is off-the-page) and 2 inputs. If the contours are perpendicular:
Then the green regions are cross-sections of tubes extending infinitely off-the-page, where each tube contains models that are roughly equivalent on the training set. The behavior manifolds are lines (1d) running in the out-of-page direction. The black dots are the cross-sections of these lines.
However, if the contours are parallel:
Now the behavior manifolds are planes, running parallel to the contours. So we see here that parallel contours allow behavior manifolds to have dimension greater than N−k.
In the next section, I will establish the following fact:
Key result: If a behavior manifold is more than N−k dimensional, then the normal vectors of the contours must be linearly dependent. The degree of linear independence (the dimensionality of their span) controls the allowed manifold dimensionality.
Behavioral gradients
The normal vector of a contour for input x_i is the gradient of the network’s output on that input. If we denote the network output as f(θ, x_i), then the normal vector is g_i = ∇_θ f(θ, x_i). I will call this vector the behavioral gradient, to distinguish it from the gradient of loss.
We can put these behavioral gradients into a matrix G, the matrix of behavioral gradients. The i-th column of G is the i-th behavioral gradient g_i = ∇_θ f(θ, x_i).[6]
The rank of G is the dimension of the span of the behavioral gradients. If they are all parallel, Rank(G) = 1. If they are all linearly independent, Rank(G) = k, where k is the number of inputs.
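As an illustration, here is a minimal sketch of how G could be computed numerically. The tiny one-hidden-layer ReLU network f, the flat parameter layout, and the toy sizes are all assumptions of mine for the sketch; per footnote [6], G^T is just the Jacobian of the stacked outputs with respect to the parameters.

```python
import torch

torch.manual_seed(0)
D_IN, H_UNITS = 2, 8                         # toy sizes: 2-d inputs, 8 hidden units
N = H_UNITS * D_IN + H_UNITS + H_UNITS + 1   # total parameter count (33)
k = 3                                        # number of training inputs

def f(theta, x):
    """Tiny one-hidden-layer ReLU net, written as a function of the flat parameters theta."""
    W1 = theta[:H_UNITS * D_IN].reshape(H_UNITS, D_IN)
    b1 = theta[H_UNITS * D_IN : H_UNITS * D_IN + H_UNITS]
    w2 = theta[H_UNITS * D_IN + H_UNITS : H_UNITS * D_IN + 2 * H_UNITS]
    b2 = theta[-1]
    return torch.relu(W1 @ x + b1) @ w2 + b2

theta = torch.randn(N)
xs = torch.randn(k, D_IN)

# G^T is the Jacobian of the concatenated outputs w.r.t. theta (footnote [6]), so the
# i-th column of G is the behavioral gradient g_i = grad_theta f(theta, x_i).
outputs = lambda th: torch.stack([f(th, x) for x in xs])
G = torch.autograd.functional.jacobian(outputs, theta).T    # shape (N, k)

print(G.shape)                                # torch.Size([33, 3])
print(torch.linalg.matrix_rank(G).item())     # <= k; strictly less means linear dependence
```

At a generic random θ with N much larger than k, the rank typically comes out to k; the interesting cases are the ones where it drops below k.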
Claim 1: The space spanned by the behavioral gradients at a point is perpendicular to the behavior manifold at that point.[7]
Proof sketch:
The behavioral gradients tell you the first-order sensitivity of the outputs to parameter movement
If you move a small distance ds parallel to the manifold, then your distance from the manifold goes as ds^2
So the change in output also goes as ds^2
So the output is only second-order sensitive to movement in this direction
So none of the behavioral gradients have a component in this direction
Therefore the two spaces are perpendicular.
Claim 2: dim(manifold) ≤ N − dim(span(g_1, …, g_k)) = N − Rank(G)
The first part follows trivially from Claim 1, since two orthogonal spaces in R^N cannot have their dimensions sum to more than N. The second part is true by definition of G.
So we have our key result: If dim(manifold) > N−k, then Rank(G) < k, meaning that the behavioral gradients are not linearly independent. The more linearly dependent they are, the lower Rank(G) is, and the higher dim(manifold) is allowed to be.
High manifold dimension ≈ Low-rank G = Linear dependence of behavioral gradients
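Continuing the sketch above (it reuses f, theta, xs, G, and the sizes N and k from the previous block), we can check Claims 1 and 2 numerically: there are at least N − k directions orthogonal to every behavioral gradient, and moving along one of them changes the training outputs only at second order.

```python
import torch

# Columns of U past rank(G) form an orthonormal basis for span(g_1, ..., g_k)^perp,
# i.e. candidate tangent directions of the behavior manifold through theta.
U, S, _ = torch.linalg.svd(G, full_matrices=True)
rank_G = int((S > 1e-6).sum())
print(N - rank_G, ">=", N - k)                    # at least N - k tangent directions

outputs = lambda th: torch.stack([f(th, x) for x in xs])
base = outputs(theta)
eps = 1e-3

tangent = U[:, -1]                     # orthogonal to every behavioral gradient
normal = G[:, 0] / G[:, 0].norm()      # along g_1

# Moving along the manifold changes the outputs only at second order (~eps^2),
# while moving along a behavioral gradient changes them at first order (~eps).
print((outputs(theta + eps * tangent) - base).abs().max().item())
print((outputs(theta + eps * normal) - base).abs().max().item())
```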
Claim 3: At a local minimum, Rank(Hessian(Loss))=Rank(G).
The purpose of this claim is to connect our formalism with the Hessian of loss, which is used as a measure of basin sharpness. In qualitative terms:
Flat basin ≈ Low-rank Hessian = Low-rank G ≈ High manifold dimension
Proof sketch for Claim 3:
span(g_1, …, g_k)^⊥[8] is the set of directions in which the output is not first-order sensitive to parameter change. Its dimensionality is N − rank(G).
At a local minimum, first-order sensitivity of behavior translates to second-order sensitivity of loss.
So span(g_1, …, g_k)^⊥ is the null space of the Hessian.
So rank(Hessian) = N − (N − rank(G)) = rank(G)
See this footnote[9] for a different proof sketch, which includes the result Hessian = 2GG^T.[10]
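Here is a quick numerical check of Claim 3 and the footnote-[9] identity, again continuing the sketch above (reusing f, theta, xs, and G). To put θ at an exact global minimum without training, I simply define the labels to be the network’s current outputs, so every residual is zero; that construction is mine, chosen only so the identity holds exactly under MSE loss.

```python
import torch

# Labels equal to the current outputs => theta is a zero-loss global minimum of MSE,
# so the residual term in the Hessian vanishes and H(L) = 2 G G^T should hold exactly.
ys = torch.stack([f(theta, x) for x in xs]).detach()

def mse_loss(th):
    return sum((f(th, x) - y) ** 2 for x, y in zip(xs, ys))

H_loss = torch.autograd.functional.hessian(mse_loss, theta)       # shape (N, N)

print(torch.allclose(H_loss, 2 * G @ G.T, atol=1e-4))             # True (up to float32 noise)
print(torch.linalg.matrix_rank(H_loss).item(),                    # Rank(Hessian) equals
      torch.linalg.matrix_rank(G).item())                         # Rank(G), both <= k
```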
Low rank G indicates information loss
Brief summary:
A network is said to “throw away” the information distinguishing a set of m inputs if their activations are identical at some intermediate layer L. When this happens, the behavioral gradients for these inputs agree on every parameter after layer L. This greatly increases the chance of linear dependence, since the gradients can now only differ on the parameters before layer L.
If the number of parameters before L is less than m−1, then the gradients are guaranteed to be linearly dependent. Destruction of information by ReLUs often zeros out the gradient components before L as well, making the effect even stronger.
Hence, information loss often leads to linear dependence of behavioral gradients, which in turn causes low Hessian rank, basin flatness, and high manifold dimension.[11]
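The mechanism is easy to see in a toy construction (mine, reusing f, N, H_UNITS, and D_IN from the earlier sketch): force the first layer to ignore the second input coordinate, so that two inputs differing only in that coordinate have identical layer-1 activations. Their behavioral gradients then agree on every parameter after layer 1 and can only differ on the first-layer parameters.

```python
import torch

# Make layer 1 discard input coordinate 1 by zeroing the corresponding column of W1.
torch.manual_seed(1)
W1 = torch.randn(H_UNITS, D_IN)
W1[:, 1] = 0.0
theta_lossy = torch.cat([W1.reshape(-1), torch.randn(N - H_UNITS * D_IN)])

x_a = torch.tensor([1.0,  2.0])
x_b = torch.tensor([1.0, -3.0])            # differs only in the discarded coordinate

g_a = torch.autograd.functional.jacobian(lambda th: f(th, x_a), theta_lossy)
g_b = torch.autograd.functional.jacobian(lambda th: f(th, x_b), theta_lossy)

# The two behavioral gradients agree on every parameter after layer 1 (w2 and b2),
# and can only differ on the first-layer parameters (W1 and b1).
n_before = H_UNITS * D_IN + H_UNITS
print(torch.allclose(g_a[n_before:], g_b[n_before:]))   # True
print(torch.allclose(g_a[:n_before], g_b[:n_before]))   # generally False
```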
For a more detailed explanation, including a case study, see this video presentation:
Follow-up questions and extra stuff:
Empirical results and immediate next steps
I am currently running experiments and further theoretical analysis to understand the following:
Is manifold dimensionality actually a good predictor of which solution will be found?
Are {info loss / Hessian rank} and {manifold dimension} related to circuit complexity? In what way? Which one is more related?
So far the results have been somewhat surprising. Specifically, circuit complexity and manifold dimension do not seem very predictive of which solution will be found in very small networks (~20 params / 4 data pts). I expect my understanding to change a lot over the next week.
Update (5/19): Further experiments on ultra-small datasets indicate that the more overparameterized the network is, the less likely we are to find a solution with non-full rank Hessian. My guess is that this is due to increased availability of less-flat basins. Yet the generalization behavior becomes more consistent across runs, not less, and converges to something which looks very natural but can’t be modeled by any simple circuit. I think this is related to the infinite-width / NTK stuff. I am currently quite confused.
I first thought that circuit simplicity was the direct cause of flat basins. I later thought that it was indirectly associated with flat basins due to a close correlation with information loss.
However, recent experiments have updated me towards seeing circuit complexity as much less predictive than I expected, and with a much looser connection to info loss and basin flatness. I am very uncertain about all this, and expect to have a clearer picture in a week or two.
[2] Technically, there can be lower-dimensional manifolds than this, but they account for 0% of the hypervolume of parameter space, whereas manifold classes with N−k ≤ dim ≤ N can all have non-zero amounts of hypervolume.
[3] Technically, you can also get manifolds with dim < N−k. For instance, suppose that N = 3, that the contours for input 1 are concentric spheres centered at (-1, 0, 0), and that the contours for input 2 are spheres centered at (1, 0, 0). Then all points on the x-axis are unique in behavior, so they are on 0-dimensional manifolds, instead of the expected N−k = 1.
This kind of thing usually occurs somewhere in parameter space, but I will treat it as an edge case. The regions with this phenomenon always have vanishing measure.
[4] This can be repaired by using “initialization-weighted” volumes. L2 and L1 regularization also fix the problem; adding a regularization term to the loss shifts the solution manifolds and can collapse them to points.
[5] Empirically, this effect is not at all dominant in very small networks, for reasons currently unknown to me.
[6] In standard terminology, G^T is the Jacobian of the concatenation of all outputs, w.r.t. the parameters.
Note: This previously incorrectly said G. Thanks to Spencer Becker-Kahn for pointing out that the Jacobian is G^T.
[7] I will assume that the manifold, behavior, and loss are differentiable at the point we are examining. Nothing here makes sense at sharp corners.
[8] The orthogonal complement of span(g_1, …, g_k).
[9] Let H() denote taking the Hessian w.r.t. the parameters θ. Inputs are x_i, labels are y_i, and the network output is f(θ, x_i). L(θ) is the loss over all training inputs, and L(θ, x_i) is the loss on a particular input.
For simplicity, we will center everything at the local minimum, so that θ = 0 there and L(0) = 0. Assume MSE loss.
Let g_i = ∇_θ f(θ, x_i) be the i-th behavioral gradient. Then, at the minimum (where every residual is zero):
H(L(θ, x_i)) v = 2 g_i (g_i ⋅ v) = 2 g_i g_i^T v   (for any vector v)
H(L(θ, x_i)) = 2 g_i g_i^T
H(L(θ)) = Σ_i H(L(θ, x_i)) = 2 G G^T
Rank(H(L(θ))) = Rank(G G^T) = Rank(G)
(since Rank(A A^T) = Rank(A) for any real-valued matrix A)
[10] Assuming MSE loss; the constant will change otherwise.
[11] High manifold dimension does not necessarily follow from the others, since they only bound it on one side, but it often does.