Paper club: He et al. on modular arithmetic (part I)
In this post we’ll be looking at the recent paper “Learning to grok: Emergence of in-context learning and skill composition in modular arithmetic tasks” by He et al. This post is partially a sequel to my earlier post on grammars and subgrammars, though it can be read independently. There will be a more technical part II.
I really like this paper. I tend to be pretty picky about papers, and find something to complain about in most of them (this will probably come up in future). I don’t have nitpicks about this paper. Every question that came up as I was reading and understanding this paper (other than questions that would require a significantly different or larger experiment, or a different slant of analysis) turned out to be answered in this paper. But more importantly, the paper does two things that characterize it as a great paper in my opinion:
It operationalizes and investigates (from the point of view of interpretability) the first interesting mechanistic model of a new behavior: namely a transition between “in-distribution” vs. “out-of-distribution” learning. This is a distinction that transposes to the “in-context learning” setting the difference between memorization and generalization.
It finds an excellent “minimal interesting” toy model for this phenomenon, where the behavior exhibits the full range of sophistication we expect but is (at least potentially) fully understandable from an interpretability point of view. More specifically, this model is analogous to (and again, in some sense an “in-context” transposition of) the revolutionary interpretability paper Nanda et al., which does a detailed mechanistic interpretability analysis of “usual” memorization vs. generalization in neural nets[1].
The setup in He et al. is that of a small transformer trying to solve a particular mathematical (and deterministic) sequence completion problem consisting of residues mod a prime p (equal to 29 in the paper), and related to the modular addition task in Nanda et al. Each task depends on a secret “context parameter,” or latent, given by a pair of residues (a,b). The latents are never openly given, but can be deduced from context[2] via a modular arithmetic computation. The paper essentially studies how (and to what extent) the transformer learns to deduce the context from examples in a text.
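To make the setup concrete, here is a minimal sketch. It assumes the task has the linear form z = a·x + b·y mod p, with the context consisting of triples (x, y, z); the function names `make_examples` and `infer_latent` are made up for illustration and are not from the paper. It shows the kind of modular arithmetic computation that recovers the latent (a, b): two examples whose 2×2 system is invertible mod p pin it down via Cramer's rule.

```python
import random

P = 29  # the prime modulus used in the paper


def make_examples(a, b, n, rng):
    """Sample n in-context examples (x, y, z) with z = a*x + b*y mod P (assumed task form)."""
    examples = []
    for _ in range(n):
        x, y = rng.randrange(P), rng.randrange(P)
        examples.append((x, y, (a * x + b * y) % P))
    return examples


def infer_latent(examples):
    """Recover (a, b) from the first pair of examples whose 2x2 system is invertible mod P."""
    for i, (x1, y1, z1) in enumerate(examples):
        for x2, y2, z2 in examples[i + 1:]:
            det = (x1 * y2 - x2 * y1) % P
            if det == 0:
                continue  # these two examples are linearly dependent mod P
            det_inv = pow(det, -1, P)  # modular inverse (Python 3.8+)
            a = ((z1 * y2 - z2 * y1) * det_inv) % P  # Cramer's rule
            b = ((x1 * z2 - x2 * z1) * det_inv) % P
            return a, b
    return None  # the context does not determine the latent yet


rng = random.Random(0)
context = make_examples(a=5, b=17, n=8, rng=rng)
print(infer_latent(context))
```

This is the “general formula” route to the latent. Whether a trained transformer does anything like this internally is exactly the question the paper investigates.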
Latent context parameters are a common phenomenon in essentially all existing language models. Namely, part of a language model’s task in completing a text is determining
what is the language of the text (since the continuation should be in the same language)?
what is the format: nonfiction, fiction, news, academic writing, etc.?
what is the mood of the writing: happy, sad, etc.?
The standard way of modelling such behaviors is to assume that the language model has something like a “selector” variable that learns the different possible contexts (e.g. the “set of possible languages”) and then factorizes the problem of in-context learning as:
(1) Solve the “inference” problem of figuring out the latent parameter “from context”: e.g. if your text discusses a unicorn, it is probably fiction rather than nonfiction.
(2) Solve the “prediction” problem of completing text given knowledge (or at least a best guess) about the latent context: e.g. “now that we know the text is fiction, let’s use fiction convention to generate further text (turn off fact-checking subroutines, etc.)”
Most paradigms of in-context learning in interpretability conceptualize the first problem as interesting and the second as boring: while the first involves potentially nontrivial inference (fictional text doesn’t tend to have a disclaimer in every paragraph that says “note that this text is fictional!”), the second is just introducing an extra modifiable parameter in the “business as usual” of generating text. Here the simplest conceptualization is of the text generation processes under different latents as “fully disjoint” subprocesses that are separately encoded in the weights and executed as a simple “if-then” declaration. Already here, because of the probabilistic nature of the “inference” problem (1) above, these can combine stochastically in the text generation procedure (e.g. “if the text is fictional, the next word should be ‘dragon’, but if it is nonfictional, the next word should be ‘dinosaur’. I currently predict a 50-50 chance of the text being fiction vs. nonfiction, so will predict a 50-50 probability distribution split between ‘dragon’ and ‘dinosaur’”). Even in this setting, with step (2) reduced to a stochastic combination of disjoint processes, the question of getting the probabilities right involves potentially quite sophisticated Bayesian inference problems; a toy model for this is studied in the beautiful paper Shai et al., where the inference problem is understood in the language of computational mechanics.
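The “stochastic combination of disjoint processes” can be sketched in a few lines. This is only an illustration of the factorization into inference and prediction: the probabilities are hypothetical numbers chosen to reproduce the 50-50 example, and `posterior` and `mix_predictions` are names I made up.

```python
def posterior(prior, likelihoods):
    """Bayes' rule: P(latent | text) is proportional to P(latent) * P(text | latent)."""
    unnorm = {k: prior[k] * likelihoods[k] for k in prior}
    total = sum(unnorm.values())
    return {k: v / total for k, v in unnorm.items()}


def mix_predictions(post, next_token_given_latent):
    """Next-token distribution as a posterior-weighted mixture of per-latent predictors."""
    mixed = {}
    for latent, weight in post.items():
        for token, prob in next_token_given_latent[latent].items():
            mixed[token] = mixed.get(token, 0.0) + weight * prob
    return mixed


# Hypothetical numbers, chosen only to reproduce the 50-50 example in the text.
prior = {"fiction": 0.5, "nonfiction": 0.5}
likelihoods = {"fiction": 0.2, "nonfiction": 0.2}  # text so far is equally likely under both
next_token = {
    "fiction": {"dragon": 1.0},      # the disjoint "fiction" subprocess
    "nonfiction": {"dinosaur": 1.0},  # the disjoint "nonfiction" subprocess
}

post = posterior(prior, likelihoods)
print(mix_predictions(post, next_token))
```

Seeing a unicorn earlier in the text would enter through `likelihoods` (say 0.9 for fiction against 0.1 for nonfiction), shifting the mixture toward “dragon” without touching either subprocess.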
Of course in LLMs, the context dependence in problem (2) is likely more complicated (I discussed this in the “combining subgrammars” discussion in the previous post). For example the rulesets of fiction and nonfiction likely have significant overlap (both use the same grammar, similar conceptions of philosophy and a physical world, etc.), and the effect of changing the latent parameter (e.g. fiction vs. nonfiction) is probably better conceptualized as only partially modifying the internal structures: i.e. from a programming point of view, the different contexts are modifying some of the global variables of the program being run, but not running a fully separate program. Nevertheless, the standard way to conceptualize in-context learning in LLMs is to think of some stable collection of “possible contexts”, with associated case-by-case modifications of the “conventional rules” of text generation.
This paradigm leaves unaddressed the possibility of out-of-distribution generalization. For example, if a language model trained on text that has only “conventional” human languages suddenly encounters a large text in Klingon, will it be able to continue this text in passably grammatical Klingon? At least in principle, this is plausible. If you show a large enough Klingon text to a human linguist, she will (from her experience studying other languages) be able to infer enough of its structure and dictionary to “create” a new linguistic context in her head, namely Klingon, and generate text in this new context. In terms of interpreting the brain of the linguist, we might say that instead of the “boring” picture where the linguist’s brain has a discrete set of possible variable “settings” corresponding to known languages, she is also able to tune these settings to new and previously unencountered contexts.
In fact language models show some limited indications of being able to do this as well: for example this classic paper shows that BERT can learn some structures while processing text in a previously unseen language[3]. Relatedly, SOTA LLMs seem able to interpolate between existing contexts reasonably well: they are good at tasks with the type signature of “write an ML paper abstract in the style of a Shakespeare sonnet”.
Thus it’s meaningful to ask the question “how do LLMs generalize to new out-of-distribution contexts”. Since LLMs are massive and complicated, this doesn’t have a clean answer. But there are several interesting and somewhat mutually exclusive possible answers:
(A) LLMs internally learn to perform more efficiently on the training set by finding a compact functional form for the inference problem from “seeing some text” to “finding correct internal variables that generated this text” (i.e., they learn a general formula rather than a case-by-case “switchboard” of existing presets).
(B) What an LLM actually learns is not just any “nice function” from observed text to latents, but rather a universal notion of “how to learn”: more specifically, it has a bunch of latent parameters and learns to perform a gradient-descent-like search over these to recover the context.
(C) What looks like an LLM learning different contexts is just a collection of internally consistent heuristics (e.g. things like “copy” or “induction heads”), which apply across contexts and provide sufficient “scaffold” to correctly complete text in a new context. (Under an extreme form of this model, the only things that an LLM tracks are rules like subject-verb agreement and various forms of semantic consistency, and the idea is then that with a sufficiently large input text, enforcing these consistency structures leads to text that, incidentally, gets all the “latent-dependent” information – such as how verbs are conjugated in Klingon – right).
All of these behaviors seem to happen in models to some extent. However, it’s tricky to produce good experimental settings which cleanly separate them and analyze them in a formal and fully interpretable setting. One paper that does this, in the context of (B) above, is the paper Ahn et al., where the authors set up a certain iterated linear regression toy model task with latent variables. They observe that by solving instances of the problem with random latent parameters, this model learns a token-to-token gradient descent procedure that correctly fits the latents on out-of-distribution examples (i.e., examples with latent parameters not encountered in training)[4].
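For intuition about what “fitting the latents by gradient descent on in-context examples” means, here is a sketch of the target computation written out explicitly. This is plain gradient descent on a linear regression problem, not the transformer construction from Ahn et al.; the dimensions, step size, and the name `gd_in_context` are my own illustrative choices.

```python
import random


def gd_in_context(examples, steps=500, lr=0.3):
    """Fit w in y = w . x by gradient descent on the mean squared error over the examples."""
    dim = len(examples[0][0])
    n = len(examples)
    w = [0.0] * dim  # start from a latent-independent initial guess
    for _ in range(steps):
        grad = [0.0] * dim
        for x, y in examples:
            err = sum(wi * xi for wi, xi in zip(w, x)) - y
            for j in range(dim):
                grad[j] += err * x[j] / n
        w = [wi - lr * gj for wi, gj in zip(w, grad)]
    return w


rng = random.Random(0)
w_true = [rng.gauss(0, 1) for _ in range(3)]  # a fresh latent, not from any fixed list of presets
examples = []
for _ in range(20):
    x = [rng.gauss(0, 1) for _ in range(3)]
    examples.append((x, sum(a * b for a, b in zip(w_true, x))))

w_hat = gd_in_context(examples)
print(max(abs(a - b) for a, b in zip(w_true, w_hat)))  # worst-case error in the recovered latent
```

Nothing in this procedure depends on which latents were seen before, which is why a model implementing something like it would handle out-of-distribution latents for free.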
While this result is exciting, the hypothesis (B) examined in Ahn et al. seems unlikely to accurately describe semantic contexts in realistic LLMs, where the contexts in question are much more complex and compositional than this nice continuous linear context. This difference between the simplified model of Ahn et al. and the weirder phenomena that exist in language is analogous to the distinction between thinking of ordinary linear regression vs. modular addition as toy models for conventional (“0-shot”) learning: in particular, the linear regression model doesn’t exhibit a spontaneous “grokking” transition where out-of-distribution learning becomes possible only above a certain threshold of training time or sample diversity.
Following this chain of thought, the paper He et al. looks at a modular addition variant of in-context learning, where the relationships between the hidden latents and the text generation task are discrete and deterministic, but require representing nontrivial modular arithmetic and linear algebra notions to solve in a general way. In this setting, their results seem to show, interestingly, that while models tend to learn in-distribution latents by memorizing a “switchboard” mapping from observed examples to known contextual presets (i.e., following (A), but via a memorizing rather than generalizing function from observations to latents), the circuits that do generalize out-of-distribution tend to look more like (C) above: i.e., they learn heuristics that get refined as more text is seen (and in particular, only work well in the “many-shot” context), rather than directly learning a general mapping from observations to latents as (A) would predict[5]. As this picture is very nonlinear and discontinuous, the resulting process does not (or at least not obviously) look like (B), i.e., does not seem to learn a variant of gradient descent.
In order to see what “needs to go right” to get good out-of-distribution generalization, He et al. vary the depth of the model (thus varying the possible degree of compositionality), and they vary the relationship between the diversity of contexts vs. the diversity of examples in a single context in the training data. A very nice thing they do, that vibes with the “separation of phenomena at different complexities” prescription for interpretability research I suggest here, is to look at differences between out-of-distribution examples learned at higher complexity (and in particular “with better out-of-distribution learning”) and examples learned at lower complexity, which are shown to exhibit only certain partial out-of-distribution behaviors called “ratio learning” (a special sub-case of method (C) above).
Upshots
Now that I’ve explained a zoomed-out point of view on the paper, what are the upshots to take away?
Well, the very fact that OOD generalization is possible in a mathematically sophisticated (albeit multishot) context is already interesting and suggestive of similar behaviors in more language-like contexts. The idea that the specific method learned looks more like combining context-independent heuristics into effectively context-dependent behaviors (rather than solving the context inference problem and then the prediction problem separately) is some partial confirmation of this being likely in other contexts (something that I’m particularly interested in in the context of work with Louis Jaburi); note though that this particular phenomenon may ultimately be a consequence of the details of the mathematical setting rather than a universal phenomenon. Finally, while this is less exciting than the “OOD generalization” insights produced by this paper, the paper provides a first interesting and mathematically sophisticated setting where, depending on some tunable presets, one sees a depth-2 “memorization-generalization” spectrum, with phenomena ranging from memorization of examples and contexts → generalization of examples and memorization of contexts (“in-distribution”) → generalization of both examples and contexts (“out-of-distribution”). Associated with this we should see new interesting “double-grokking” phenomena (that the paper already starts to see, with some interesting complications), we should see different learning coefficient measurements at characteristic temperatures associated to the two contexts, and finally this may be a good testbed to look for tools (necessarily beyond the naive test-train distinction) to quantify and analyze phenomena at different points in the memorization-generalization spectrum and their relationships with each other.
Acknowledgments
Much of my understanding of ICL and related phenomena comes out of discussing a deep dive by Lauren Greenspan into the subject. Discussions with Jake Mendel and Kaarel Hänni have also been useful here.
Stay tuned for next time:
Originally this was the introduction to a longer post that also explains in a bit more detail the behaviors found in the paper, and in particular reinterprets them in terms of the discussion in the previous post about ways of compositionally forming them from simpler rules, and the “spectrum between memorization and generalization”.
Writing up the details has taken longer than expected, and I decided to separate out this “introductory” material on the paper into its own post (it also provides a post for my planned daily posting schedule).
[1] I want to stress the last point because I think its importance often goes unnoticed: when designing papers like this, choosing a good toy model is perhaps the most important thing to get right. You want the toy model to be expressive enough to exhibit the behavior you want to identify, and to distinguish it from other, simpler explanations, but simple and elegant enough that the behavior is findable, cleanly interpretable, and distinguishable from noise or alternative mechanisms. You want your “small” numbers to not be so small as to trivialize your problem, and your “large” numbers to be large enough to cleanly exhibit asymptotic behaviors but not so large as to be computationally intractable. Finally, you want the model to be a potentially representative example of an interesting behavior that is in the same reference class as behaviors that occur in state-of-the-art systems. He et al. balances all of these factors perfectly – as I said, it is a very good paper.
[2] Meaning in this case the “previously given part of the sequence” – the overloading of the term “context” will be a perennial issue in these discussions.
[3] Note though that the language in the paper is a pidgin of previously encountered languages, thus not quite as out-of-distribution as the Klingon example.
[4] Note that I haven’t carefully read this paper.
[5] Note that I am slightly reading between the lines of the results in the paper here, and am not 100% sure the authors would endorse the same point of view.