That’s useful to know, thanks. Is anything else known about the properties of the noise covariance beyond “it’s not constant”?
Some comments on the paper itself: if the problem is that SGD with homoskedastic Gaussian noise fails to converge to a stationary distribution, why don’t they define SGD over a torus instead? It seems like that would fix the problem they are talking about, and if it doesn’t change the behavior, then their explanation of what’s going on is incorrect.
If the only problem is that, with homoskedastic Gaussian noise, convergence to a stationary distribution is slow (when one exists), I could believe that. Similar algorithms such as Metropolis-Hastings also have pretty abysmal convergence rates in practice when applied to any kind of complicated problem. It’s possible that SGD with batch noise has better regularization properties and therefore converges faster, but I don’t think that changes the basic qualitative picture I present in the post.
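A minimal sketch of the torus idea, using a hypothetical 1-D toy potential $U(\theta) = -\cos\theta$ instead of a real training loss: Langevin-style updates with constant (homoskedastic) Gaussian noise, with the parameter wrapped mod $2\pi$. Because the wrapped state space is compact, the chains settle into the Gibbs distribution $p(\theta) \propto \exp(-\beta U(\theta))$. On the unwrapped real line, the same periodic potential would let the chains keep diffusing, with no normalizable stationary distribution. The potential, $\beta$, step size and chain counts are all illustrative choices, and this says nothing about convergence speed on a real network.

```python
import numpy as np


def simulate_torus_langevin(beta=1.0, dt=0.01, n_steps=3000, n_chains=5000, seed=0):
    """Euler-Maruyama for overdamped Langevin dynamics on the circle [0, 2*pi).

    Toy potential U(theta) = -cos(theta), so U'(theta) = sin(theta).
    The noise scale is the same everywhere (homoskedastic).
    """
    rng = np.random.default_rng(seed)
    # Initialize all chains uniformly on the circle.
    theta = rng.uniform(0.0, 2.0 * np.pi, size=n_chains)
    for _ in range(n_steps):
        grad_u = np.sin(theta)
        noise = rng.standard_normal(n_chains)
        theta = theta - grad_u * dt + np.sqrt(2.0 * dt / beta) * noise
        # Wrap back onto the torus; this is what makes the state space compact.
        theta = np.mod(theta, 2.0 * np.pi)
    return theta


def gibbs_mean_cos(beta=1.0, n_grid=100_000):
    """E[cos(theta)] under p(theta) proportional to exp(-beta * U) = exp(beta * cos(theta))."""
    grid = np.linspace(0.0, 2.0 * np.pi, n_grid, endpoint=False)
    weights = np.exp(beta * np.cos(grid))
    return float(np.sum(np.cos(grid) * weights) / np.sum(weights))


samples = simulate_torus_langevin()
print(f"empirical E[cos theta] = {np.mean(np.cos(samples)):.3f}")
print(f"Gibbs     E[cos theta] = {gibbs_mean_cos():.3f}")
```

The two printed values should agree up to sampling error and a small step-size bias.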
Some comments on the paper itself: if the problem is that SGD with homoskedastic Gaussian noise fails to converge to a stationary distribution, why don’t they define SGD over a torus instead?
Good question. I imagine that would work, but it would converge more slowly. I think a more important issue is that the homoskedastic and heteroskedastic noise cases would have different equilibrium distributions even if both existed (they don’t say this, but it seems intuitively obvious, since in the heteroskedastic case there would be a pressure away from points with higher noise). I guess on the torus this would correspond to there being a large number of bad minima that dominate the equilibrium in the homoskedastic case.
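The pressure away from high-noise points can be made concrete in a one-dimensional toy model on the real line, assuming Itô dynamics $d\theta = -U'(\theta)\,dt + \sqrt{2D(\theta)}\,dW_t$ with a confining potential $U$ and a state-dependent noise level $D(\theta)$ (a hypothetical stand-in for heteroskedastic SGD noise). Setting the probability flux in the Fokker–Planck equation to zero gives the stationary density:

$$
0 = -U'(\theta)\,p(\theta) - \frac{d}{d\theta}\big[D(\theta)\,p(\theta)\big]
\quad\Longrightarrow\quad
p(\theta) \propto \frac{1}{D(\theta)}\exp\!\left(-\int^{\theta}\frac{U'(s)}{D(s)}\,ds\right)
$$

With constant $D = 1/\beta$ this reduces to the Gibbs form $p \propto e^{-\beta U}$. Otherwise the $1/D$ prefactor suppresses the density where the noise is large, and minima sitting in high-noise regions become effectively shallower, so the two equilibria differ even when both exist.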
Generally speaking, SGD noise seems to provide a regularizing effect toward ‘flatter’ solutions. The beginning of this paper has a good overview.
As an aside, I’ve tried to work out what the optimal learning rate for a large language model should be, based on the theory in the post, and if I’m doing the calculations correctly (which is a pretty big if), it doesn’t match actual practice very well, suggesting that something important is missing from this picture.
Essentially, the coefficient $\beta$ should be $2/\sigma^2$, where $\sigma^2$ is the variance of the per-parameter noise in SGD. If you have a learning rate $\eta$, you scale the objective you’re optimizing by a factor of $\eta$ and the noise variance by a factor of $\eta^2$. Likewise, a batch size of $n_b$ divides the noise variance by $n_b$. So the equilibrium distribution ends up proportional to
$$\exp\left(-\frac{2L}{\sigma^2}\cdot\frac{n_b}{\eta}\right)$$
where $L$ is the per-token average loss and $\sigma^2$ should equal the mean square of the partial derivative of the per-token loss with respect to one of the network parameters. If the network uses some decent batch or layer normalization, this should probably be $O(1/N)$, where $N$ is the model size (number of parameters).
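The step from those scalings to the exponent can be written out, assuming the small-step (Langevin) approximation in which updates with drift $-\nabla V$ and per-step noise variance $s^2$ equilibrate to $\exp(-2V/s^2)$; with $V = L$ and $s^2 = \sigma^2$ this is the $\beta = 2/\sigma^2$ rule above:

$$
V = \eta L, \qquad s^2 = \frac{\eta^2 \sigma^2}{n_b}, \qquad
p \propto \exp\!\left(-\frac{2V}{s^2}\right)
= \exp\!\left(-\frac{2\eta L\, n_b}{\eta^2 \sigma^2}\right)
= \exp\!\left(-\frac{2L}{\sigma^2}\cdot\frac{n_b}{\eta}\right)
$$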
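We want the exponent to be just $-L \cdot D$, where $D$ is the number of training tokens, because we want learning to be equivalent to doing a Bayesian update on the whole dataset. Setting $2 L n_b / (\sigma^2 \eta) = L D$ gives $\eta = 2 n_b / (\sigma^2 D)$, and with $\sigma^2 = O(1/N)$ this suggests we should pick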
$$\eta = O\left(\frac{2 N n_b}{D}\right)$$
which is a pretty bad prediction. So there’s probably something important being left out of this model. My guess is that a smaller learning rate just means you end up conditioning on minimum loss, which is all you need in practice, and that larger learning rates cause convergence problems.
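To see how far off the estimate is, here is a quick calculation with hypothetical round numbers (a 1-billion-parameter model, one-million-token batches, 20 billion training tokens; not figures from any specific model), taking the constant in $\sigma^2 = O(1/N)$ to be exactly 1. The helper names are made up for illustration.

```python
def predicted_learning_rate(n_params: float, batch_tokens: float, dataset_tokens: float) -> float:
    """eta = 2 * N * n_b / D: the value that makes 2 * n_b / (sigma^2 * eta) equal D
    when sigma^2 = 1 / N (the O(1/N) estimate with its constant set to 1)."""
    return 2.0 * n_params * batch_tokens / dataset_tokens


def exponent_coefficient(eta: float, sigma_sq: float, batch_tokens: float) -> float:
    """Coefficient multiplying the per-token loss L in the exponent: 2 * n_b / (sigma^2 * eta)."""
    return 2.0 * batch_tokens / (sigma_sq * eta)


# Hypothetical round numbers for illustration only.
N = 1e9     # parameters
n_b = 1e6   # tokens per batch
D = 2e10    # training tokens

eta = predicted_learning_rate(N, n_b, D)
print(f"predicted eta = {eta:.3g}")

# Sanity check: with sigma^2 = 1/N, the coefficient on L should come back out as D.
coef = exponent_coefficient(eta, 1.0 / N, n_b)
print(f"coefficient on L = {coef:.3g} (D = {D:.3g})")
```

This gives $\eta = 10^5$, many orders of magnitude above the learning rates (well below 1) used to train language models in practice.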