Some comments on the paper itself: if the problem is that SGD with homoskedastic Gaussian noise fails to converge to a stationary distribution, why don’t they define SGD over a torus instead?
Good question. I imagine that would work, but it would converge more slowly. I think a more important issue is that the homoskedastic and heteroskedastic noise cases would have different equilibrium distributions even if both existed (they don’t say this, but it seems intuitively obvious, since there would be pressure away from points with higher noise in the heteroskedastic case). I guess on the torus this would correspond to there being a large number of bad minima which dominate the equilibrium in the homoskedastic case.
Generally speaking, the SGD noise seems to provide a regularizing effect towards ‘flatter’ solutions. The beginning of this paper has a good overview.
As an aside, I’ve tried to work out what the optimal learning rate for a large language model should be based on the theory in the post, and if I’m doing the calculations correctly (which is a pretty big if), it doesn’t match actual practice very well, suggesting there is actually something important missing from this picture.
Essentially, the coefficient $\beta$ should be $2/\sigma^2$, where $\sigma^2$ is the variance of the per-parameter noise in SGD. If you have a learning rate $\eta$, you scale the objective you’re optimizing by a factor of $\eta$ and the noise variance by a factor of $\eta^2$. Likewise, a bigger batch size $n_b$ lowers the noise variance by a linear factor. So the equilibrium distribution ends up proportional to
$$\exp\!\left(-\frac{2L}{\sigma^2}\cdot\frac{n_b}{\eta}\right)$$
where $L$ is the per-token average loss and $\sigma^2$ should be equal to the mean square of the partial derivative of the per-token loss with respect to a single neural network parameter. If the network is using some decent batch or layer normalization, this should probably be $O(1/N)$, where $N$ is the model size.
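To spell out where that exponent comes from (a sketch, assuming the usual continuous-time Langevin approximation of SGD with isotropic, parameter-independent Gaussian noise; the per-step notation here is mine): model one SGD step as the true gradient plus per-parameter noise of variance $\sigma^2/n_b$,

$$\theta_{t+1} = \theta_t - \eta\,\nabla L(\theta_t) - \eta\,\xi_t, \qquad \operatorname{Var}(\xi_t) = \frac{\sigma^2}{n_b}.$$

Treating each step as a time increment of $\eta$ gives the SDE $d\theta = -\nabla L\,dt + \sqrt{\eta\sigma^2/n_b}\,dW_t$. Matching this against Langevin dynamics $d\theta = -\nabla L\,dt + \sqrt{2T}\,dW_t$, whose stationary density is $\propto e^{-L/T}$, gives an effective temperature $T = \eta\sigma^2/(2 n_b)$ and hence the equilibrium density above.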
We want what’s inside the exponential to just be $L \cdot D$, where $D$ is the number of training tokens, because we want the learning to be equivalent to doing a Bayesian update over the whole data. Setting $2 n_b / (\sigma^2 \eta) = D$ and plugging in $\sigma^2 = O(1/N)$ suggests we should pick
$$\eta = O\!\left(\frac{2 N n_b}{D}\right)$$
which is a pretty bad prediction. So there’s probably something important being left out of this model. I’m guessing that a smaller learning rate just means you end up conditioning on minimum loss, and that’s all you need to do in practice, while larger learning rates cause problems with convergence.
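To get a feel for how far off this is, here’s a quick back-of-the-envelope check (the numbers below are purely illustrative placeholders, not measurements from any particular model or training run):

```python
# Back-of-the-envelope check of the predicted learning rate eta = O(2 * N * n_b / D).
# All numbers are illustrative placeholders, not measured values.

N = 1e9      # model size (number of parameters), hypothetical
n_b = 1e6    # batch size in tokens, hypothetical
D = 1e12     # total number of training tokens, hypothetical

predicted_eta = 2 * N * n_b / D
typical_eta = 3e-4  # rough ballpark of learning rates actually used in practice

print(f"predicted eta: {predicted_eta:.2g}")                # -> 2e+03
print(f"typical eta:   {typical_eta:.2g}")                  # -> 0.0003
print(f"ratio:         {predicted_eta / typical_eta:.2g}")  # -> 6.7e+06
```

With any remotely realistic choices for $N$, $n_b$ and $D$, the predicted $\eta$ comes out many orders of magnitude larger than the learning rates actually used, which is what I mean by a bad prediction.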