Why Gradients Vanish and Explode
Epistemic status: Confused, but trying to explain a concept that I previously thought I understood. I suspect much of what I wrote below is false.
Without taking proper care of a very deep neural network, gradients tend to suddenly become quite large or quite small. If the gradient is too large, then the network parameters will be thrown completely off, possibly causing them to become NaN. If it is too small, then the network will effectively stop training. This problem is called the vanishing and exploding gradients problem.
When I first learned about the vanishing gradients problem, I ended up getting a vague sense of why it occurs. In my head I visualized the sigmoid function.
I then imagined this being applied element-wise to an affine transformation. If we look at just one element, we can think of it as the dot product of some parameters with the input (plus a bias), and that number is what gets plugged in on the x-axis. On the far left and on the far right, the derivative of this function is very small. This means that if we take the partial derivative with respect to some parameter, it will end up being extremely (perhaps vanishingly) small.
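The Python sketch below makes this picture concrete. It is an illustration added here and does not come from any of the cited sources. It evaluates the sigmoid's derivative, $\sigma'(x) = \sigma(x)(1 - \sigma(x))$, at a few points. It only shows where the slope is small; on its own it says nothing about depth.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic sigmoid, 1 / (1 + e^{-x})."""
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_grad(x: float) -> float:
    """Derivative of the sigmoid: sigma(x) * (1 - sigma(x))."""
    s = sigmoid(x)
    return s * (1.0 - s)


# The derivative peaks at x = 0 and shrinks quickly in both tails.
for x in [-10.0, -5.0, 0.0, 5.0, 10.0]:
    print(f"x = {x:6.1f}  sigmoid'(x) = {sigmoid_grad(x):.2e}")
```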
Now, I know the way that I was visualizing this was very wrong. There are a few mistakes I made:
1. This picture doesn’t tell me anything about why the gradient “vanishes.” It’s just showing me where the gradients get small. But gradients also get small as the parameters approach a local minimum. Does this mean that vanishing gradients are sometimes good?
2. I knew that gradient vanishing had something to do with the depth of a network, but I didn’t see how the network being deep affected why the gradients got small. I had a rudimentary sense that each layer of sigmoid compounds the problem until there’s no gradient left, but this was never presented to me in a precise way, so I just ignored it.
I now think I understand the problem a bit better, but maybe not a whole lot better.
(Note: I have gathered evidence that the vanishing gradient problem is not linked to sigmoids and put it in this comment. I will be glad to see evidence which proves I’m wrong on this one, but I currently believe this is evidence that machine learning professors are teaching it incorrectly).
First, the basics. Without describing the problem in a very general sense, I’ll walk through a brief example. In particular, I’ll show how we can imagine a forward pass in a simple recurrent neural network that enables a feedback effect to occur. We can then immediately see how gradient vanishing can become a problem within this framework (no sigmoids necessary).
Imagine that there is some sequence of vectors $h^{(t)}$ defined via the following recursive definition,

$$h^{(t)} = W^\top h^{(t-1)}$$

This sequence of vectors can be identified as the sequence of hidden states of a (very simplified) recurrent network, with no inputs and no nonlinearity. Let $W$ admit an orthogonal eigendecomposition, $W = Q \Lambda Q^\top$. We can then represent this repeated application of the weight matrix as

$$h^{(t)} = \left(W^t\right)^\top h^{(0)} = Q \Lambda^t Q^\top h^{(0)},$$

where $\Lambda$ is a diagonal matrix containing the eigenvalues of $W$, and $Q$ is an orthogonal matrix whose columns are the corresponding eigenvectors. Look at the eigenvalues, which are the diagonal entries of $\Lambda$. As $t$ grows, those with magnitude less than one decay exponentially towards zero, and those with magnitude greater than one blow up exponentially towards infinity.

Since $Q$ is orthogonal, the transformation $Q^\top h^{(0)}$ can be thought of as a rotation (possibly combined with a reflection) of the vector $h^{(0)}$. Each coordinate of the result is the projection of $h^{(0)}$ onto one eigenvector of $W$. Therefore, when $t$ is very large, as in an unrolled recurrent network, this computation ends up dominated by the components of $h^{(0)}$ that point in the same direction as the eigenvectors with exploding eigenvalues.
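The argument above can be checked numerically. The Python sketch below makes some arbitrary illustrative choices:

* a symmetric 2×2 weight matrix built from a hand-picked rotation $Q$;
* eigenvalues 1.1 and 0.9;
* a starting vector of all ones.

It unrolls $h^{(t)} = W^\top h^{(t-1)}$ for 50 steps. It then checks that the component along each eigenvector has been scaled by $\lambda_i^t$.

```python
import numpy as np

# Illustrative assumption: a symmetric 2x2 weight matrix W = Q diag(lam) Q^T,
# where Q is a rotation matrix (orthogonal) and lam holds the eigenvalues.
theta = np.pi / 6
Q = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
lam = np.array([1.1, 0.9])  # one eigenvalue above 1, one below 1
W = Q @ np.diag(lam) @ Q.T


def unroll(W: np.ndarray, h0: np.ndarray, t: int) -> np.ndarray:
    """Apply h^(t) = W^T h^(t-1) t times (no inputs, no nonlinearity)."""
    h = h0
    for _ in range(t):
        h = W.T @ h
    return h


h0 = np.array([1.0, 1.0])
t = 50
h_t = unroll(W, h0, t)

# Closed form from the eigendecomposition: h^(t) = Q diag(lam**t) Q^T h^(0).
h_closed = Q @ np.diag(lam ** t) @ Q.T @ h0
print("matches closed form:", np.allclose(h_t, h_closed))

# Coordinates along each eigenvector (column of Q) before and after unrolling.
coords_0 = Q.T @ h0
coords_t = Q.T @ h_t
print("growth per eigen-direction:", coords_t / coords_0)  # approximately lam ** t
```

After 50 steps, the first eigen-direction has grown by more than a factor of 100. The second has shrunk below 1% of its starting size. This is the same exponential split described above.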
This is a problem because if an input vector ends up pointing in the direction of one of these eigenvectors, the loss function may be very high and change very steeply, so the gradient in these regions can be enormous. As a result, stochastic gradient descent may massively overshoot. If SGD overshoots, then we end up reversing all of the progress we had previously made towards a local minimum.
As Goodfellow et al. note, this problem is relatively easy to avoid in non-recurrent neural networks, because there the weights aren’t shared between layers. However, in vanilla recurrent neural networks, this problem is almost unavoidable. Bengio et al. found that the probability of successfully training a vanilla RNN with SGD drops rapidly towards zero for sequences as short as 10 or 20 steps.
One way to mitigate the problem is simply to clip the gradients, so that a single step can’t reverse all of the descent progress so far. This treats the symptom of exploding gradients but doesn’t fix the problem entirely, since the issue with exploding or vanishing eigenvalues remains.
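Here is a minimal sketch of clipping. The post doesn't say which variant it means, so this assumes the common clip-by-norm rule. If the gradient's norm exceeds a threshold, the gradient is rescaled to that threshold while keeping its direction.

`clip_by_norm` is a hypothetical helper name. Deep learning frameworks ship their own versions, for example PyTorch's `torch.nn.utils.clip_grad_norm_`.

```python
import math


def clip_by_norm(grad: list[float], max_norm: float) -> list[float]:
    """Rescale grad so its Euclidean norm is at most max_norm.

    Hypothetical helper: the direction is preserved and only the length is capped.
    """
    norm = math.sqrt(sum(g * g for g in grad))
    if norm <= max_norm:
        return list(grad)
    scale = max_norm / norm
    return [g * scale for g in grad]


# A made-up "cliff" gradient whose norm is 5000.
huge_grad = [3000.0, -4000.0]
clipped = clip_by_norm(huge_grad, max_norm=5.0)
print(clipped)
```

Because only the length changes, the clipped step still points in the same direction as the original gradient. What clipping cannot do is change the eigenvalues of $W$ that produced the huge gradient in the first place.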
Therefore, in order to fix this problem, we need to fundamentally re-design the way that the gradients are backpropagated through time, motivating echo state networks, leaky units, skip connections, and LSTMs. I plan to one day go into all of these, but I first need to build up my skills in matrix calculus, which are currently quite poor.
So I intend to make the next post (and maybe a few more) about matrix calculus. Then perhaps I can revisit this topic and gain a deeper understanding.
This may be an idiosyncratic error of mine. See page 105 of these lecture notes for where I first saw the problem of vanishing gradients described.
See section 10.7 in the Deep Learning Book for a fuller discussion of vanishing and exploding gradients.
Yay for learning matrix calculus! I’m eager to read and learn. Personally I’ve done very well in the class where we learned it, but I’d say I didn’t get it at a deep / useful level.
Great! I’ll do my best to keep the post as informative as possible, and I’ll try to get into it on a deep level.
If you’re looking to improve your matrix calculus skills, I specifically recommend practicing tensor index notation and the Einstein summation convention. It will make neural networks much more pleasant, especially recurrent nets. (This may have been obvious already, but it’s sometimes tough to tell what’s useful when learning a subject.)
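Here is a small illustration of that advice, assuming NumPy. Take the hidden-state recurrence $h^{(t)} = W^\top h^{(t-1)}$. In index notation it reads $h^{(t)}_i = W_{ji} h^{(t-1)}_j$, with the repeated index $j$ summed. `np.einsum` takes almost exactly that notation as its subscript string. The random matrices are arbitrary test data.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 4))   # arbitrary recurrent weight matrix
h = rng.normal(size=4)        # one hidden state
H = rng.normal(size=(8, 4))   # a batch of 8 hidden states, one per row

# Index notation: h'_i = W_{ji} h_j, summed over the repeated index j.
# This is h' = W^T h, and einsum's subscripts spell out the same indices.
h_next = np.einsum("ji,j->i", W, h)

# Batched version: H'_{bi} = W_{ji} H_{bj}, i.e. every row is multiplied by W^T.
H_next = np.einsum("ji,bj->bi", W, H)

assert np.allclose(h_next, W.T @ h)
assert np.allclose(H_next, H @ W)
```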
I think the problem with vanishing gradients is usually linked to repeated applications of the sigmoid activation function. The gradient in backpropagation is calculated from the chain rule, where each factor d\sigma/dz in the “chain” will always be less than one (at most 1/4), and close to zero for large or small inputs. So for feed-forward networks, the problem is a little different from the recurrent case you describe.
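The compounding effect described here can be sketched with a toy chain of scalar sigmoid layers. For illustration, assume every weight equals 1, there are no biases, and the first pre-activation is 0. The chain-rule product is then just the product of the $\sigma'(z_k)$ terms, each of which is at most 1/4.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def chain_factor(n_layers: int, z0: float = 0.0) -> float:
    """d a_n / d z_0 for a chain of scalar layers a_k = sigmoid(z_k), z_{k+1} = a_k.

    Toy assumption: every weight is 1 and there are no biases, so by the
    chain rule the derivative is the product of sigmoid'(z_k) over the layers.
    """
    z = z0
    grad = 1.0
    for _ in range(n_layers):
        s = sigmoid(z)
        grad *= s * (1.0 - s)  # sigmoid'(z) = s * (1 - s) <= 1/4
        z = s                  # next layer's pre-activation (weight 1, no bias)
    return grad


for n in [1, 5, 10, 20]:
    print(f"{n:2d} layers: product = {chain_factor(n):.2e}, bound 0.25**n = {0.25 ** n:.2e}")
```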
The usual mitigation is to use ReLU activations, L2 regularization, and/or batch normalization.
A minor point: the gradient doesn’t necessarily tend towards zero as you get closer to a local minimum; that depends on the higher-order derivatives. Imagine a local minimum at the bottom of a funnel or spike, for instance, or a very spiky fractal-like landscape. On the other hand, a local minimum in a region with a small gradient is a desirable property, since it means small perturbations in the input data don’t change the output much. But this point will be difficult to reach, since learning depends on the gradient...
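Here is a one-dimensional illustration of this point, using two textbook functions chosen for the example. Near the minimum of the smooth bowl $f(x) = x^2$, the derivative $2x$ shrinks to zero. For the funnel $f(x) = |x|$, the derivative's magnitude stays at 1 arbitrarily close to the minimum.

```python
def grad_square(x: float) -> float:
    """Derivative of the smooth bowl f(x) = x**2; it shrinks to 0 at the minimum."""
    return 2.0 * x


def grad_abs(x: float) -> float:
    """Derivative of the funnel f(x) = |x|; its magnitude is 1 for every x != 0."""
    if x == 0.0:
        raise ValueError("|x| is not differentiable at 0")
    return 1.0 if x > 0 else -1.0


for x in [1.0, 0.1, 0.001]:
    print(x, grad_square(x), grad_abs(x))
```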
(Thanks for the interesting analysis, I’m happy to discuss this but probably won’t drop by regularly to check comments—feel free to email me at ketil at malde point org)
That’s what I used to think too. :)
If you look at the post above, I even linked to the reason why I thought that. In particular, vanishing gradients was taught as intrinsically related to the sigmoid function in page 105 in these lecture notes, which is where I initially learned about the problem.
However, I no longer think gradient vanishing is fundamentally linked to sigmoids or tanh activations.
I think that there is probably some confusion in terminology, and some people use the words differently than others. If we look in the Deep Learning Book, there are two sections that talk about the problem, namely section 8.2.5 and section 10.7, neither of which brings up sigmoids as being related (though they do bring up deep weight-sharing networks). Goodfellow et al. cite Sepp Hochreiter’s 1991 thesis as the original document describing the issue, but unfortunately it’s in German, so I cannot comment on whether it links the issue to sigmoids.
Currently, when I Ctrl-F “sigmoid” on the Wikipedia page for vanishing gradients, there are no mentions. There is a single subheader which states, “Rectifiers such as ReLU suffer less from the vanishing gradient problem, because they only saturate in one direction.” However, the citation for this statement comes from this paper which mentions vanishing gradients only once and explicitly states,
(Note: I misread the quote above—I’m still confused).
I think this is quite strong evidence that I was not taught the correct usage of vanishing gradients.
Interesting you say that. I actually wrote a post on rethinking batch normalization, and I no longer think it’s justified to say that batch normalization simply mitigates vanishing gradients. The exact way that batch normalization works is a bit different, and it would be inaccurate to describe it as an explicit strategy to reduce vanishing gradients (although it may help; funnily enough, the original batch normalization paper says that with batchnorm they were able to train with sigmoids more easily).
True. I had a sort of smooth loss function in my head.
I’m very confused. The way I’m reading the quote you provided, it says ReLU works better because it doesn’t have the gradient vanishing effect that sigmoid and tanh have.
Interesting. I just re-read it and you are completely right. Well, I wonder how that interacts with what I said above.
That proof of the instability of RNNs is very nice.
The version of the vanishing gradient problem I learned is simply this. Suppose you’re updating weights proportional to the gradient, and your average weight somehow ends up as 0.98. Then, as you increase the number of layers n, your gradient (and therefore your update size) will shrink roughly like (0.98)^n, which is not the behavior you want it to have.
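The arithmetic in this toy model is easy to check. The sketch below assumes, as in the comment, that every layer multiplies the backpropagated gradient by the same average factor. `gradient_scale` is a hypothetical helper. The factor 1.02 is included only to show the exploding side of the same calculation.

```python
def gradient_scale(factor: float, n_layers: int) -> float:
    """Toy model: each layer multiplies the backpropagated gradient by `factor`."""
    return factor ** n_layers


for n in [10, 50, 100, 300]:
    print(f"n = {n:3d}: 0.98**n = {gradient_scale(0.98, n):.3g}, "
          f"1.02**n = {gradient_scale(1.02, n):.3g}")
```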
Great, thanks. It is adapted from Goodfellow et al.’s discussion of the topic, which I cite in the post.
That makes sense. However, Goodfellow et al. argue that this isn’t a big issue for non-RNNs. Their discussion is a bit confusing to me so I’ll just leave it below,