Deconfusing In-Context Learning
I see people use “in-context learning” in different ways.
Take the opening to “In-Context Learning Creates Task Vectors”:
In-context learning (ICL) in Large Language Models (LLMs) has emerged as a powerful new learning paradigm. However, its underlying mechanism is still not well understood. In particular, it is challenging to map it to the “standard” machine learning framework, where one uses a training set to find a best-fitting function in some hypothesis class.
In one Bayesian sense, training data and prompts are both just evidence. From a given model, prior (architecture + initial weight distribution), and evidence (training data), you get new model weights. From the new model weights and some more evidence (prompt input), you get a distribution of output text. But the “training step” and “inference step” could be collapsed into a single function that maps the model, prior, training data, and prompt directly to a distribution of output text. An LLM trained on a distribution of text that always starts with “Once upon a time” is essentially similar to an LLM trained on the Internet but prompted to continue after “Once upon a time.” If the second model performs better (e.g. because it generalizes information from the other text), this is explained by training data limitations or by the availability of more forward passes, and therefore more computation steps and more space to store latent state.
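One way to write this collapse down, as a sketch rather than anything the quoted papers state: let θ be the model weights, p(θ) the prior, D the training data, x the prompt, and y the output text. All of these symbols are notation assumed here for illustration.

```latex
p(\theta \mid D) \propto p(D \mid \theta)\, p(\theta)
\qquad
p(y \mid x, D) = \int p(y \mid x, \theta)\, p(\theta \mid D)\, d\theta
```

In practice training returns a single weight setting rather than a full posterior, so the integral reduces to p(y | x, θ̂(D)). Either way, the output distribution is a function of the prior, the training data, and the prompt together, and the split into “training” and “inference” is a choice about where to cache intermediate results.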
A few days ago, “How Transformers Learn Causal Structure with Gradient Descent” defined in-context learning as
the ability to learn from information present in the input context without needing to update the model parameters. For example, given a prompt of input-output pairs, in-context learning is the ability to predict the output corresponding to a new input.
Using this interpretation, ICL is simply updating the state of latent variables based on the context and conditioning on this when predicting the next output. In this case, there’s no clear distinction between standard input conditioning and ICL.
However, it’s still nice to know the level of abstraction at which the in-context “learning” (conditioning) mechanism operates. We can distinguish “task recognition” (identifying known mappings even with unpaired input and label distributions) from “task learning” (capturing new mappings not present in pre-training data). At least some tasks can be associated with function vectors representing the associated mapping (see also: “task vectors”). Outside of simple toy settings, it’s usually hard for models to predict which features in preceding tokens will be useful to reference when predicting future tokens. This incentivizes generic representations that enable many useful functions of preceding tokens to be employed depending on which future tokens follow. How these representations work is interesting in its own right.
A stronger claim is that models’ method of conditioning on the context has a computational structure akin to searching over an implicit parameter space to optimize an objective function. We know that attention mechanisms can implement a latent space operation equivalent to a single step of gradient descent on toy linear-regression tasks by using previous tokens’ states to minimize mean squared error in predicting the next token. However, it’s not guaranteed that non-toy models work the same way. Moreover, one gradient-descent step on a linear-regression problem with MSE loss is simply a linear transformation of the previous tokens, so it’s hard to build a powerful internal learner with this construction.
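The toy equivalence can be checked numerically. The sketch below assumes a linear model initialized at zero weights, a mean-squared-error loss over the context examples, and softmax-free (linear) attention with keys x_i, values y_i, and the query input as the attention query. The function names and the learning rate are illustrative choices, not taken from any of the cited papers.

```python
import numpy as np

def gd_one_step_prediction(X, y, x_query, lr):
    """Predict with a linear model after one gradient step from w = 0."""
    n, d = X.shape
    w0 = np.zeros(d)
    # Gradient of the loss (1 / 2n) * sum_i (w . x_i - y_i)^2, evaluated at w0.
    grad = X.T @ (X @ w0 - y) / n
    w1 = w0 - lr * grad
    return w1 @ x_query

def linear_attention_prediction(X, y, x_query, lr):
    """Softmax-free attention readout: keys = x_i, values = y_i, query = x_query."""
    n = X.shape[0]
    scores = X @ x_query            # unnormalized key-query dot products
    return (lr / n) * (scores @ y)  # score-weighted sum of the values

rng = np.random.default_rng(0)
X = rng.normal(size=(16, 4))        # 16 context inputs of dimension 4
w_true = rng.normal(size=4)
y = X @ w_true                      # noise-free context labels
x_query = rng.normal(size=4)

pred_gd = gd_one_step_prediction(X, y, x_query, lr=0.1)
pred_attn = linear_attention_prediction(X, y, x_query, lr=0.1)
print(np.isclose(pred_gd, pred_attn))
```

Both functions compute (lr / n) · Σ y_i (x_i · x_query), which is why they agree. It also makes the limitation concrete: the labels enter only linearly, so a single such step cannot fit anything richer than this fixed readout.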
But an intuitive defense of this strong form of in-context learning is that models that learn generic ways to update on input context will generalize and predict better. Consider a model trained to learn many different tasks, where the pretraining data consists of sequences of task demonstrations. We can end up with three scenarios:
1. The model learns representations of all demonstrated tasks. At inference time, after a few examples, it picks up on which of the known tasks is being demonstrated and then does that task. It can’t handle unseen tasks.
2. The model learns generic representations that allow it to encode a wide range of tasks. At inference time, it handles unseen task demonstrations by composing those generic representations into a representation of the unseen task. However, the generic representations have limited flexibility and won’t be expressive enough to handle all unseen tasks.
3. The model develops internal learning machinery that can handle learning a much wider range of unseen tasks by searching over an implicit parameter space to optimize some function (not the model’s own loss function). This can be seen as implementing the most general version of conditioning on the context.
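To make the contrast between the first and third scenarios concrete, here is a toy sketch on linear regression. It assumes tasks are weight vectors, “task recognition” is a Bayesian posterior over a finite set of known weight vectors, and “task learning” is a ridge-regression fit to the context. These are stand-ins chosen for illustration, not a claim about what any trained transformer computes, and the middle scenario is not modeled.

```python
import numpy as np

def task_recognition_predict(X, y, X_query, known_tasks, noise_std=0.1):
    """Scenario 1 stand-in: posterior over a finite set of known weight vectors."""
    residuals = y[None, :] - known_tasks @ X.T           # shape (num_tasks, n)
    log_lik = -0.5 * np.sum(residuals**2, axis=1) / noise_std**2
    posterior = np.exp(log_lik - log_lik.max())          # subtract max for stability
    posterior /= posterior.sum()
    w_hat = posterior @ known_tasks                      # posterior-mean weights
    return X_query @ w_hat

def task_learning_predict(X, y, X_query, ridge=1e-3):
    """Scenario 3 stand-in: fit weights to the context by minimizing a ridge objective."""
    d = X.shape[1]
    w_hat = np.linalg.solve(X.T @ X + ridge * np.eye(d), X.T @ y)
    return X_query @ w_hat

rng = np.random.default_rng(0)
known_tasks = rng.normal(size=(8, 4))    # 8 hypothetical "pretraining" tasks
X = rng.normal(size=(32, 4))             # context inputs
X_query = rng.normal(size=(64, 4))       # held-out query inputs

def mean_abs_error(predict, w_true, **kwargs):
    y = X @ w_true                       # noise-free demonstrations of the task
    return np.mean(np.abs(predict(X, y, X_query, **kwargs) - X_query @ w_true))

w_seen = known_tasks[3]
w_unseen = rng.normal(size=4)

err_recog_seen = mean_abs_error(task_recognition_predict, w_seen, known_tasks=known_tasks)
err_recog_unseen = mean_abs_error(task_recognition_predict, w_unseen, known_tasks=known_tasks)
err_learn_unseen = mean_abs_error(task_learning_predict, w_unseen)
print(err_recog_seen, err_recog_unseen, err_learn_unseen)
```

In this setup the recognizer is essentially exact on a task from its known set but, on an unseen task, falls back to whichever known task best explains the context. The ridge learner handles the unseen task because it optimizes over the whole weight space.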
But some people elide these distinctions.
Here’s an interesting paper related to the (potential order of emergence of the) three scenarios: “Pretraining task diversity and the emergence of non-Bayesian in-context learning for regression.”