I have not finetuned GPT-3, but I have done a lot of finetuning with GPT-J 6.1B, which is similar in scale and performance to GPT-3 “Curie.”
In my experience, doing more than a single epoch is always harmful when finetuning GPT-J.
I initially thought it was beneficial on one specific dataset, but that apparent exception turned out to be an artifact. I inspected per-token validation loss on that dataset over the course of training and discovered that the train/val split was imperfect. Training beyond the first epoch only helped on text that had been accidentally duplicated between train and val, and was harmful elsewhere. In other words, it was “helpful” for exact memorization, but harmful for generalization.
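For concreteness, here is a minimal sketch of the kind of check involved: partition the validation documents by exact overlap with the training set, then track loss on each partition separately. (Illustrative only, not my actual code; it assumes Hugging Face-style causal LM and tokenizer objects, and the helper names are made up.)

```python
import torch

def split_val_by_leakage(val_docs, train_docs):
    """Partition validation documents into those duplicated verbatim in the
    training set ("leaked") and those that are not ("clean")."""
    train_set = set(train_docs)
    leaked = [d for d in val_docs if d in train_set]
    clean = [d for d in val_docs if d not in train_set]
    return leaked, clean

@torch.no_grad()
def mean_token_loss(model, tokenizer, docs, device="cuda"):
    """Average per-token cross-entropy of a causal LM over a list of documents."""
    total_loss, total_tokens = 0.0, 0
    for text in docs:
        ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
        out = model(ids, labels=ids)  # HF causal LMs return the mean loss over predicted tokens
        n_pred = ids.numel() - 1      # number of tokens actually being predicted
        total_loss += out.loss.item() * n_pred
        total_tokens += n_pred
    return total_loss / max(total_tokens, 1)

# Logged at each eval step: if loss keeps dropping on `leaked` but rises on `clean`
# after the first epoch, the apparent benefit is just memorization.
# leaked, clean = split_val_by_leakage(val_docs, train_docs)
# print(mean_token_loss(model, tokenizer, leaked), mean_token_loss(model, tokenizer, clean))
```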
I have a wandb report here with some plots of this phenomenon. I’m still not sure whether it’s an indication of the sample efficiency associated with the ~6B scale, a quirk of GPT-J specifically, or (less plausibly) a quirk or bug in the codebase used to tune it.
I did this work before OpenAI released their finetuning feature, and was surprised to find them defaulting to 4 epochs, especially given that the feature has a relatively tiny maximum dataset size. My gut feeling is that 4 epochs is far too many for a model that large on only ~2.5M tokens.
If 4 is not simply a bad default, maybe they had in mind data at a high inferential distance from ordinary English (foreign languages, non-natural/formal languages), which might require more epochs?
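In any case, if you want fewer passes than the default, the fine-tuning endpoint exposes an epochs setting. A minimal sketch, assuming the legacy openai-python FineTune interface (the file ID below is a placeholder):

```python
import openai  # legacy openai-python (< 1.0) interface

# "file-abc123" is a placeholder training-file ID, not a real one.
openai.FineTune.create(
    training_file="file-abc123",
    model="curie",
    n_epochs=1,  # override the 4-epoch default with a single pass
)
```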
I cannot access your wandb, btw. It seems to be private.
Whoops, fixed.