To make a Chinchilla optimal model smaller while maintaining its capabilities, you need more data. At 15T tokens (the amount of data used in Llama 3), a Chinchilla optimal model has 750b active parameters, and training it invests 7e25 FLOPs (about the scale of Gemini 1.0 Ultra, or 4x the original GPT-4). A larger $1 billion training run, which might be the current scale that’s not yet deployed, would invest 2e27 FP8 FLOPs if using H100s. A Chinchilla optimal run at that compute would need 80T tokens of unique data.
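For concreteness, here is the arithmetic behind these figures as a minimal sketch, assuming the usual Chinchilla rules of thumb (optimal tokens ≈ 20 × parameters, training compute ≈ 6 × parameters × tokens); the helper names are just for illustration:

```python
# Chinchilla rules of thumb: optimal tokens D ~ 20 * N, training compute C ~ 6 * N * D.
def optimal_from_tokens(tokens):
    """Chinchilla optimal parameter count and compute for a given unique-data budget."""
    params = tokens / 20
    flops = 6 * params * tokens
    return params, flops

def optimal_from_flops(flops):
    """Chinchilla optimal parameter and token counts for a given compute budget."""
    # C = 6 * N * 20N = 120 * N^2  =>  N = sqrt(C / 120)
    params = (flops / 120) ** 0.5
    return params, 20 * params

n, c = optimal_from_tokens(15e12)                    # 15T tokens, as in Llama 3
print(f"{n/1e9:.0f}b params, {c:.1e} FLOPs")         # ~750b params, ~6.8e25 FLOPs

n, d = optimal_from_flops(2e27)                      # a ~$1 billion run on H100s
print(f"{n/1e9:.0f}b params, {d/1e12:.0f}T tokens")  # ~4080b params, ~82T tokens
```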
Starting with a Chinchilla optimal model, making it 3x smaller while maintaining performance requires training it on 9x more data, so it needs 3x more compute. That’s already too much data, and we are only talking 3x smaller. So we need ways of stretching the data that is available. By repeating data up to 16 times, it’s possible to make good use of 100x more compute than by only using unique data once. So with, say, 2e26 FP8 FLOPs (a $100 million training run on H100s), we can train a 3x smaller model that matches the performance of the above 7e25 FLOPs Chinchilla optimal model while needing only about 27T tokens of unique data (repeated 5 times) instead of 135T unique tokens, and the model will have about 250b active parameters. That’s still a lot of data, and we are only repeating it 5 times, a regime where repeated data remains about as useful in training as unique data; data repeated 16 times (which lets us make use of 100x more compute from repetition) becomes 2-3 times less valuable per token.
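A minimal sketch of that arithmetic, under the same rules of thumb as above plus the assumption from the paragraph that shrinking a model k-fold at equal loss costs roughly k² as much data:

```python
# Shrinking a Chinchilla optimal model k-fold at equal loss: data scales as k^2, compute as k.
base_params, base_tokens = 750e9, 15e12     # the 7e25 FLOPs model above
k = 3

small_params = base_params / k              # ~250b active parameters
token_passes = base_tokens * k**2           # ~135T token passes
flops = 6 * small_params * token_passes
print(f"{flops:.1e} FLOPs")                 # ~2.0e26 FLOPs

repeats = 5                                 # repetition this mild costs little effectiveness
unique_tokens = token_passes / repeats
print(f"{unique_tokens/1e12:.0f}T unique tokens")   # ~27T
```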
There is also distillation, where a model is trained to predict the distribution generated by another model (Gemma-2-9b was trained this way). But this sort of distillation still happens while training on real data, and it only allows getting similar performance from about 2x less data, so it only slightly pushes back the data wall. And rumors of synthetic data for pre-training (as opposed to post-training) remain rumors. With distillation on 16x repeated 50T tokens of unique data, we get the equivalent of training on 800T tokens of unique data (each token gets 2x less useful through repetition, but 2x more useful through distillation). This enables reducing active parameters 3x (as above, maintaining performance) compared to a Chinchilla optimal model trained on 80T tokens with 2e27 FLOPs (the $1 billion training run for the Chinchilla optimal model). This overtrained model would cost $3 billion (and have 1300b active parameters).
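The bookkeeping for that scenario, as a sketch; the 2x discount for 16x repetition and the 2x bonus for distillation are the rough factors stated above, not measured constants, and the ~4000b baseline is read off the 2e27 FLOPs Chinchilla optimal run:

```python
# Effective unique data from 50T tokens repeated 16 times, with distillation.
unique_tokens = 50e12
repeats = 16
repetition_discount = 0.5      # ~2x less useful per token at 16 repetitions
distillation_bonus = 2.0       # ~2x more useful per token with distillation
effective_unique = unique_tokens * repeats * repetition_discount * distillation_bonus
print(f"{effective_unique/1e12:.0f}T effective unique tokens")     # ~800T

# Chinchilla optimal at 2e27 FLOPs is ~4000b params on ~80T tokens (see above).
# A 3x smaller model needs ~9x the data, i.e. ~720T effective tokens, which the 800T covers.
overtrained_params = 4000e9 / 3                    # ~1300b active parameters
overtrained_flops = 6 * overtrained_params * 9 * 80e12
print(f"{overtrained_flops:.1e} FLOPs")            # ~5.8e27, about 3x of 2e27, hence ~$3 billion
```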
So the prediction is that the trend of getting models that are both cheaper for inference and smarter might continue into the imminent $1 billion training run regime, but will soon sputter out beyond that due to the data wall. Overcoming this requires algorithmic progress that’s not currently publicly in evidence, and visible success in overcoming it in deployed models will be evidence of such algorithmic progress within LLM labs. But Chinchilla optimal models (with corrections for inefficiency of repeated data) can usefully scale to at least 8e28 FLOPs ($40 billion in cost of time, 6 gigawatts) with a mere 50T tokens of unique data.
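As a rough sanity check on the 8e28 figure (a sketch only; the corrected tokens-per-parameter ratio under heavy repetition is assumed to land somewhere in the 20-to-60 range discussed further down, not taken from a published fit at this scale):

```python
# 50T unique tokens at 16 repetitions gives 800T token passes.
token_passes = 50e12 * 16

# Assume the optimal tokens-per-parameter ratio drifts up from ~20 toward ~60
# under heavy repetition (see the discussion of repetition below).
for ratio in (40, 50, 60):
    params = token_passes / ratio
    flops = 6 * params * token_passes
    print(f"ratio {ratio}: {params/1e12:.0f}T params, {flops:.1e} FLOPs")
# ratio 40: 20T params, 9.6e28 FLOPs
# ratio 50: 16T params, 7.7e28 FLOPs
# ratio 60: 13T params, 6.4e28 FLOPs   -- all in the vicinity of 8e28
```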
Edit (20 Jul): These estimates erroneously use the sparse FP8 tensor performance for H100s (4 petaFLOP/s), which is 2 times higher than the far more relevant dense FP8 tensor performance (2 petaFLOP/s). But with a Blackwell GPU, the relevant dense FP8 performance is 5 petaFLOP/s, which is close to 4 petaFLOP/s, and the cost and power per GPU within a rack are also similar. So the estimates approximately work out unchanged when reading “Blackwell GPU” instead of “H100”.
One question: Do you think Chinchilla scaling laws are still correct today, or are they not? I would assume these scaling laws depend on the data set used in training, so that if OpenAI found/created a better data set, this might change scaling laws.
Do you agree with this, or do you think it’s false?
Data varies in the loss it enables, but doesn’t seem to vary greatly in the ratio between the number of tokens and the number of parameters that extracts the best loss out of training with given compute. That is, I’m usually keeping this question in mind and didn’t see evidence to the contrary in the papers, but relevant measurements are very rarely reported, even in model series training report papers where the ablations were probably actually done. So I could be very wrong, generalizing from 2.5 examples. With repetition, there’s a gradual increase in that ratio from about 20 to 60. Probably something similar is there for distillation (in the opposite direction), but I’m not aware of papers that measure this, so that also could be wrong.
One interesting point is the isoFLOP plots in the StripedHyena post (search “Perplexity scaling analysis”). With hybridization where standard attention remains in 8-50% of the blocks, perplexity is quite insensitive to change in model size while keeping compute fixed, while for pure standard attention the penalty for deviating from the optimal ratio to a similar extent is much greater. This suggests that one way out for overtrained models might be hybridization with these attention alternatives. That is, loss for an overtrained model might be closer to Chinchilla optimal loss with a hybrid model than it would be for a similarly overtrained pure standard attention model. Out of the big labs, visible moves in this direction were made by DeepMind with their Griffin Team (Griffin paper, RecurrentGemma). So that’s one way the data wall might get pushed a little further for overtrained models.
New data! The Llama 3 report includes data from a Chinchilla optimality study on their setup. The surprise is that Llama 3 405b was chosen to have the optimal size rather than being 2x overtrained. Their extrapolation for the optimal point is 402b parameters, 16.55T tokens, and 3.8e25 FLOPs.
Fitting this to the tokens per parameter framing gives a ratio of 41 (not 20) around the scale of 4e25 FLOPs. More importantly, their fitted dependence of the optimal number of tokens on compute has exponent 0.53, compared to 0.51 from the Chinchilla paper (which was almost 0.5, hence tokens being proportional to parameters). The data only goes up to 1e22 FLOPs (3e21 FLOPs for Chinchilla), so what actually happens at 4e25 FLOPs (6e23 FLOPs for Chinchilla) is all extrapolation; in both cases, there are no isoFLOP plots at those scales. At least Chinchilla has Gopher as a point of comparison, and there was only a 200x FLOPs gap in that extrapolation, while for Llama 3 405b the gap is 4000x.
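To see what the fitted exponents imply, here is a small extrapolation sketch; the 402b / 16.55T / 3.8e25 anchor and the 0.53 exponent are from the Llama 3 report as quoted above, while anchoring the power laws at that point (and giving parameters the complementary 0.47 exponent) is my assumption:

```python
# Llama 3 fit: optimal tokens scale as ~C^0.53, so optimal params scale as ~C^0.47
# (since C ~ 6*N*D), and the tokens-per-parameter ratio grows as ~C^0.06.
anchor_flops, anchor_params, anchor_tokens = 3.8e25, 402e9, 16.55e12
print(f"{anchor_tokens / anchor_params:.0f}")        # ~41 tokens per parameter at the anchor

def llama3_optimal(flops):
    """Extrapolate from the quoted optimum (power laws anchored there -- an assumption)."""
    tokens = anchor_tokens * (flops / anchor_flops) ** 0.53
    params = anchor_params * (flops / anchor_flops) ** 0.47
    return params, tokens

for c in (1e22, 3.8e25, 2e27, 8e28):
    n, d = llama3_optimal(c)
    print(f"C={c:.0e}: ~{d/n:.0f} tokens/param")
# ~25 at 1e22, ~41 at 3.8e25, ~52 at 2e27, ~65 at 8e28: the ratio creeps up with compute.
```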
So data needs grow faster than parameters with more compute. This looks bad for the data wall, though the more relevant question is what would happen after 16 repetitions, or how this dependence really works with more FLOPs (with the optimal ratio of tokens to parameters changing with scale).