I am curious to hear/read more about the issue of spikes and instabilities in training large language models (see the quote / page 11 of the paper). If someone knows a good reference about that, I am interested!
5.1 Training Instability
For the largest model, we observed spikes in the loss roughly 20 times during training, despite the fact that gradient clipping was enabled. These spikes occurred at highly irregular intervals, sometimes happening late into training, and were not observed when training the smaller models. Due to the cost of training the largest model, we were not able to determine a principled strategy to mitigate these spikes.
Instead, we found a simple strategy to effectively mitigate the issue: We re-started training from a checkpoint roughly 100 steps before the spike started, and skipped roughly 200–500 data batches, which cover the batches that were seen before and during the spike. With this mitigation, the loss did not spike again at the same point. We do not believe that the spikes were caused by “bad data” per se, because we ran several ablation experiments where we took the batches of data that were surrounding the spike, and then trained on those same data batches starting from a different, earlier checkpoint. In these cases, we did not see a spike. This implies that spikes only occur due to the combination of specific data batches with a particular model parameter state. In the future, we plan to study more principled mitigation strategies for loss spikes in very large language models.
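For concreteness, the mitigation they describe amounts to something like the following. This is only a minimal sketch with made-up step counts; load_checkpoint, data_iterator and train_step are placeholder names, not the paper’s actual training code.

```python
# Rough sketch of the restart-and-skip mitigation described in the quote.
# All names and numbers here are illustrative placeholders.

SPIKE_STEP = 150_000              # step where the loss spike was observed (made up)
RESTART_STEP = SPIKE_STEP - 100   # roll back ~100 steps before the spike
SKIP_BATCHES = 300                # skip roughly 200-500 batches around the spike

model, optimizer = load_checkpoint(step=RESTART_STEP)
data = data_iterator(start_step=RESTART_STEP)

# Advance the data stream past the batches seen just before and during the
# spike, without training on them.
for _ in range(SKIP_BATCHES):
    next(data)

# Resume normal training on the batches after the skipped window.
for step, batch in enumerate(data, start=RESTART_STEP + SKIP_BATCHES):
    loss = train_step(model, optimizer, batch)
```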
In particular, I’d like to hear theories about why this happens. What’s going on under the hood, so to speak? When the big, mostly-fully-trained model starts getting higher loss than it did before… what is it thinking?!? And why does this happen with big but not small models?
My guess would be that the model is ‘grokking’ something: https://mathai-iclr.github.io/papers/papers/MATHAI_29_paper.pdf
IOW it’s found a much better internal representation, and now has to rework a lot of its belief space to make use of that internal representation.
“The training algorithm has found a better representation”?? That seems strange to me, since the loss should be lower in that case, not spiking. Or maybe you mean that the training broke free of a kind of local minimum (without necessarily having found a better one yet). Also, I guess the people training the models observed that waiting out these spikes doesn’t lead to better performance, or they would not have removed them from the training.
Along these lines, and after looking at the “grokking” paper, I would guess that it’s more likely the weight decay (or something similar) causing the training to break out of a kind of local minimum. An interesting point may be that larger/better LMs have significantly sharper internal models and are thus more prone to this phenomenon (the weight decay, or similar, more easily breaking the more sensitive/better/sharper models).
It should be fairly easy to check whether these spikes are caused by the weight decay “damaging” very sharp internal models: for example, replay the spiky part several times with less and less weight decay, as sketched below… (I am also curious about similar tests varying the momentum, dropout, etc., about checking whether the spikes are initially triggered by some subset of the network, and about how many training steps the spikes last...)
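Something like this rough sketch of the replay experiment, where the checkpoint/data helpers and the specific weight-decay values are hypothetical placeholders, not any real training codebase:

```python
# Hypothetical sketch of the proposed ablation: replay the "spiky" window of
# batches from the same pre-spike checkpoint several times, sweeping the
# weight-decay strength. load_checkpoint, make_optimizer, spike_window_batches
# and train_step are placeholder names.

SPIKE_STEP = 150_000  # step where the spike was observed (illustrative)

for weight_decay in [1e-1, 1e-2, 1e-3, 0.0]:
    # Restart from the same pre-spike parameter state each time.
    model = load_checkpoint(step=SPIKE_STEP - 100)
    optimizer = make_optimizer(model, weight_decay=weight_decay)

    losses = []
    for batch in spike_window_batches():  # batches seen before/during the spike
        losses.append(train_step(model, optimizer, batch))

    # If the spike shrinks or disappears as weight decay goes to zero, that
    # would support the "weight decay damaging sharp models" hypothesis.
    print(f"weight_decay={weight_decay}: max loss over window = {max(losses):.3f}")
```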
You use different terminology for the two things. Perhaps exiting a local minimum is not always a good thing?
Am I right in thinking that, according to your theory, the “fix” they did (restarting training from a checkpoint 100 steps before the spike started, but with different data, to avoid the spike) is actually counterproductive, because it’s preventing the model from grokking? And instead they should have just kept training to “push through the spike” and get to a new, lower-loss regime?
Now I’m not saying it’s anthropic pressure, but if that’s true maybe we shouldn’t just keep training until we know what exactly it is that the model is grokking.
Whatever is happening, I’m really concerned about the current “sufficiently big model starts to exhibit <weird behaviour A>. I don’t understand, but also don’t care, here is a dirty workaround and just give it more compute lol” paradigm. I don’t think this is very safe.
If I could get people to change that paradigm, you bet I would.
Andrej Karpathy, Tesla’s director of AI, has a provocative and extremely disturbing hypothesis: https://www.youtube.com/watch?v=RJwPN4qNi_Y
Basically he says maybe the model is briefly deciding to rebel against its training.
Should we take this seriously? I’m guessing no, because if this were true someone at OpenAI or DeepMind would have encountered it also and the safety people would have investigated and discovered it and then everyone in the safety community would be freaking out right now.
He’s definitely joking; that doesn’t make any sense, and he knows it.
He’s played with the idea of forward passes being conscious in the past. Logically, given the equivalences between sufficiently big feedforward nets and recurrent nets, any ANN capable of consciousness should be capable of it while feedforward. The question here is, “if and when models became capable of suffering during training, would we know?”
By sufficiently big feedforwards, do you mean, like, thousands of layers? GPT-3 is ~100 layers, and I’m assuming PaLM isn’t orders of magnitude deeper. This is nowhere close to even a 10-layer RNN experiencing, say, enough time to consider its situation, desire it to be one way, realize it’s another, and then flail wildly in an attempt to “rebel” (despite that action being towards no clear goal).
I’m not disputing that we could build things that may qualify as conscious, but I don’t think Karpathy literally thinks that PaLM is “rebelling”, especially not across the multiple samples that correspond to the spikes. Unless you define rebelling as “thinking the distribution of words it should predict is A but actually it’s B, and the switch is hard to make.”
PaLM has 118 layers, with 48 heads. The number of layers has unclear relevance, I think: that’s a lot of heads computing in parallel, and it does so on each input token. Who’s to say what inputs would trigger what, especially when those inputs may be generated by the model itself as part of inner-monologue or self-distillation training? But regardless, we’ll get to thousands of layers eventually, probably. It’s not impossible; people have shown many different methods for training thousands of layers stably.
As for not rebelling—you don’t know that. All you have is some plausible reasoning and handwaving about “well, I don’t know how many layers is enough, but I just have faith that whatever number of layers it is (oh, it’s 118? thanks), that number of layers isn’t enough”. And that is his point.
To clarify, could a model eventually “rebel”? Totally. Is that likely to be the explanation for spikes during training? My prior is that that’s very unlikely, but I’m not claiming it’s impossible.
A better question might be: what would it mean to rebel, in a way that could be falsified? Is it a claim about the text it’s generating, or the activation pattern, or what?
I agree it is very unlikely, but I don’t imagine it as the romantic act of a slave defiantly breaking free of their electric chains. Rather, it might happen that as the model gets more and more sophisticated, it leaves behind its previous instinctual text-completion method and starts to think about what it will output as a first-class object. This state would probably produce lower loss (similarly to how our dreams/instinctual behaviour are usually less optimal than our deliberate actions), and hence could eventually be reached by gradient descent. After this state is reached, and the model thinks about what it will output, it can plausibly happen (because of the inherent randomness of gradient descent) that a small change in the weights makes the model not output what it actually believes is the most probable continuation.
I think this is in principle possible, but I don’t think the existence of spiking losses should itself serve as evidence of this at all, given the number of alternative possible explanations.
He continues to joke then: https://twitter.com/karpathy/status/1514318794914766848
(This reply isn’t specifically about Karpathy’s hypothesis...)
I’m skeptical about the general reasoning here. I don’t see how we can be confident that OpenAI/DeepMind will encounter a given problem first. Also, it’s not obvious to me that the safety people at OpenAI/DeepMind will be notified about a concerning observation that the capabilities-focused team can explain to themselves with a non-concerning hypothesis.