A starting point for making sense of task structure (in machine learning)
ML models can perform a range of tasks and subtasks, some of which are more closely related to one another than are others. In this post, we set out two very initial starting points. First, we motivate reverse engineering models’ task decompositions. We think this can be helpful for interpretability and for understanding generalization. Second, we provide a (potentially non-exhaustive, initial) list of techniques that could be used to quantify the ‘distance’ between two tasks or inputs. We hope these distances might help us identify the task decomposition of a particular model. We close by briefly considering analogues in humans and by suggesting a toy model.
Epistemic status: We didn’t spend much time writing this post. Please let us know in the comments if you have other ideas for measuring task distance or if we are replicating work.
Introduction
It might be useful to think about computation in neural networks (and in LMs specifically) on sufficiently complex tasks as a combination of (a) simple algorithms or circuits for specific tasks[1] and (b) a classifier, or family of classifiers, that determine which simple circuits are to be run on a given input. (Think: an algorithm that captures (some of) how GPT-2 identifies indirect objects in certain cases combined with a method of identifying that indirect object identification is a thing that should be done.[2]) More concretely, some pairs of tasks might overlap in that they are computed together much more than are other pairs, and we might want to build a taxonomic tree of tasks performed by the model in which tree distance between tasks is a measure of how much computation they share.[3] For example, a particularly simple (but unlikely) task structure could be a tree of depth 1: the neural network has one algorithm for classifying tasks which is run on all inputs, and then a single simple task is identified and the corresponding algorithm is run.
Why understanding task structure could be useful
Interpretability
We might hope to interpret a model by 1) identifying the task decomposition, and 2) reverse-engineering both what circuit is implemented in the model for each task individually, and how the model computes this task decomposition. Crucially, (1) is valuable for understanding the internals and behavior of neural networks even without (2), and techniques for making progress at it could look quite different from standard interpretability methods. It could directly make the rest of mechanistic interpretability easier by giving us access to some ground truth about the model’s computation: we might insist that the reverse engineering of the computation respects the task decomposition, or we might be able to use task distance metrics to identify tasks that we want to understand mechanistically. Further, by arranging tasks into a hierarchy, we might be able to choose different levels of resolution on which to attempt to understand the behavior of a model for different applications.
Learning the abstractions
Task decomposition can give direct access to the abstractions learned by the model. Ambitiously, it may even turn out that task decomposition is ‘all you need’: that the hard part of language modeling is learning which atomic concepts to keep track of and how they are related to each other. In this case, it might be possible to achieve many of the benefits of full reverse engineering, in the sense of understanding how to implement an algorithm similar to GPT-4’s, without needing good methods for identifying the particular way circuits are implemented in any particular language model. Realistically, a good method for measuring task similarity won’t be sufficient for this, but it could be a helpful step.
Unlearning capabilities
We’d like to be able to train models which have certain capabilities but not others. For example, we might want to train a model that can make advancements in vaccine design but is incapable of designing a bioweapon, perhaps via unlearning bioweapons capabilities. Clearly, if two tasks are similar enough, it is not possible to destroy performance at one without affecting the other. Access to the task hierarchy would allow us to understand which capability combinations are feasible.[4] In addition to helping technical researchers, this would be useful for helping policymakers understand the tradeoffs that must be made for good AI regulation. It may also be helpful for doing capability evaluations in models (perhaps after we attempt to remove a capability with unlearning) when we are worried they are sandbagging the eval for deceptive reasons: we can study the model’s performance on similar tasks and become suspicious if the model’s capabilities seem not to respect the task decomposition.
Quantifying generalization
Having a way to quantify the distance between tasks could lead to a way to measure the ability of a model to generalize[5] by providing a standard unit of ‘generalization distance’ that transfers across tasks and types of intelligent systems (e.g., humans and neural networks).[6] Other than being of object-level interest, this is helpful for evaluating and predicting capabilities. Indeed, the ability to generalize (which is intuitively related to the ability to learn quickly) is often cited as a key limitation of present ML models compared to humans. We think it’d be interesting to compare generalization distance in models and humans, e.g., for forecasting when model generalization performance will beat human generalization performance (i.e., maybe, when we’ll have AGI). It may also be possible to use distance metrics to track model generalization in more fine-grained ways, e.g., by comparing the input clusterings of different Pythia checkpoints to see when certain inputs first come to be seen as similar by the model, or comparing the subtask clusters of GPT-3 to those of GPT-4, potentially seeing certain clusters merge or split.
Learning how the world works
More generally, science is about identifying the structure and patterns in the world; the task taxonomy learned by powerful language models may be very convergent and could be a useful map for understanding the territory of the world we are in. What’s more, such a decomposition would itself be of scientifico-philosophical interest — it would tell us something about thinking.
Some Subtleties
What is a task?
For defining distance metrics between tasks, it is useful to have an operationalization of a ‘task’. In this post, when we speak about task similarity, a task is specified by providing a dataset (of inputs, or of input-output pairs).[7] For example, the Indirect Object Identification task is specified by providing a dataset of text snippets and their completions. There are some concerns to keep in mind here:
A dataset that we provide to specify a task might not form a natural cluster for the model. We could well be missing a kind of input that the model treats similarly, or be including multiple classes of inputs which the model treats very differently. Indeed, we could be carving reality quite orthogonally to how the model carves it — and plausibly, this is the default.
We’d like to have a way to decompose inputs into tasks in an unsupervised way so that we can discover the tasks rather than guess them. This is particularly important if we want to have a method of task decomposition that can scale.
Some distance metrics can be applied to measure the task distance between individual data points rather than data sets, which could allow us to create a weighted graph between data points[8]. Clustering on this graph (or perhaps fuzzy clustering which puts nodes into multiple clusters, a bit analogous to sparse autoencoding, or hierarchical clustering which arranges clusters in a hierarchy) may allow us to identify tasks. Unfortunately, some distance metrics only work on datasets of several inputs.[9]
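To make the clustering step concrete, here is a minimal pure-Python sketch, assuming we already have a distance between individual inputs (from one of the metrics discussed below) collected into a symmetric matrix. `dataset_distance` lifts an input-level distance to datasets by averaging over all pairs; `single_linkage` is plain agglomerative clustering with the single-linkage rule, chosen only for brevity rather than as a recommendation. Both function names and the toy distances are our own.

```python
from __future__ import annotations

import itertools
from typing import Callable, Sequence


def dataset_distance(xs: Sequence, ys: Sequence, d: Callable) -> float:
    """Mean distance between a random element of xs and a random element of ys."""
    return sum(d(x, y) for x in xs for y in ys) / (len(xs) * len(ys))


def single_linkage(dist: list[list[float]], n_clusters: int) -> list[set[int]]:
    """Agglomerative clustering on a symmetric distance matrix.

    Repeatedly merges the two clusters whose closest members are nearest,
    until n_clusters clusters remain. Recording the merge order instead of
    stopping early would give a hierarchy (dendrogram) of candidate tasks.
    """
    clusters = [{i} for i in range(len(dist))]
    while len(clusters) > n_clusters:
        _, a, b = min(
            (min(dist[i][j] for i in clusters[a] for j in clusters[b]), a, b)
            for a, b in itertools.combinations(range(len(clusters)), 2)
        )
        clusters[a] |= clusters[b]
        del clusters[b]  # safe: a < b, so index a is unaffected
    return clusters


# Toy distances between four inputs: {0, 1} and {2, 3} are treated alike.
dist = [
    [0.0, 0.1, 0.9, 0.8],
    [0.1, 0.0, 0.85, 0.95],
    [0.9, 0.85, 0.0, 0.2],
    [0.8, 0.95, 0.2, 0.0],
]
print(single_linkage(dist, n_clusters=2))  # [{0, 1}, {2, 3}]
print(dataset_distance([0, 1], [3], d=lambda x, y: abs(x - y)))  # 2.5
```

Single linkage is only one option; fuzzy or hierarchical variants, as mentioned above, would replace the merge rule.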
Task decomposition in the dataset vs a particular system’s task decomposition
We will sometimes talk about the task decomposition (of, e.g., natural language) without referring to a particular reference system that is attempting to do the tasks[10]; and sometimes about how some particular model or another (or a human) (implicitly) decomposes natural language into tasks. Here are some ways in which these are related: (1) each can provide a helpful guess for the other; (2) alternatively, one could argue that the former really only makes sense as a case of the latter with the observer left implicit (though we think there’s more to the former than that); (3) uniformity across observers of the latter (is worth investigating and) could help establish that the former is a sensible thing to consider. But we won’t track this distinction.
Absolute vs relative metrics vs clusterings
Most of the metrics provided below are not intended to output individually meaningful or interesting numbers; indeed, most are not actually metrics in the precise mathematical sense. However, the numbers can become meaningful when compared to other outputs. For example, it’s hardly meaningful to say that a pair of inputs has some particular similarity score $s$, but it could begin to be meaningful in a context where the similarities of other pairs are much smaller than $s$. And even if these similarities are also somehow wrongly normalized, so that their ordering sometimes differs from that of the ‘true similarities’, clusterings could still be meaningful. More generally, we won’t be mathematically careful (that said, we will try not to get anything ‘wrong’). For example, we will not discuss which clustering algorithm is most appropriate in a particular context. To be clear: we consider it obviously valuable to be mathematically careful; it’s just outside the scope for now.
Methods for gauging task structure in ML
In this section, we specify a number of ways one can try to measure task similarity.[11]
Inspecting activations
The activation-based metrics below are trying to get at task similarity by measuring whether the representations computed on two tasks (or two inputs) are similar or otherwise related.
Activation distance: Let $A$ and $B$ be sets of activation vectors (perhaps from a middle layer[12]) obtained by taking instantiations of tasks $T_A$ and $T_B$ respectively and passing them through the model. A distance between the two tasks is, for instance, $d(T_A, T_B) = \left\lVert \frac{1}{|A|}\sum_{a \in A} a - \frac{1}{|B|}\sum_{b \in B} b \right\rVert_2$. Instead of using the $L_2$ norm here, we could also use cosine similarity, or maybe run the data points through an SAE and compute the correlation between SAE coefficients, or use some other more principled norm (maybe from Park et al.?). To compare activations on particular inputs, replace the averages by the activation vectors on those particular inputs. Instead of comparing activations that live in a particular activation space, we could of course also compare the activations that, e.g., a particular OV part outputs (this is related to Variengien, 2023). Instead of expecting activation vectors to be close for similar tasks, we could alternatively expect them to vary along similar directions; in this case, perhaps we’d want to compare the sample covariance matrices of the two activation vector data sets instead.
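A minimal NumPy sketch of the activation-distance variants just described. We assume the activations have already been extracted from the model as arrays of shape `(n_inputs, d_model)`; the tiny hand-made arrays in the usage lines are stand-ins, and the function names are our own. To compare two single inputs, pass one-row arrays to the first two functions (the covariance variant needs several inputs per task).

```python
import numpy as np


def mean_activation_distance(acts_a: np.ndarray, acts_b: np.ndarray) -> float:
    """L2 distance between the mean activation vectors of two tasks."""
    return float(np.linalg.norm(acts_a.mean(axis=0) - acts_b.mean(axis=0)))


def mean_activation_cosine(acts_a: np.ndarray, acts_b: np.ndarray) -> float:
    """Cosine similarity between the two mean activation vectors."""
    mu_a, mu_b = acts_a.mean(axis=0), acts_b.mean(axis=0)
    return float(mu_a @ mu_b / (np.linalg.norm(mu_a) * np.linalg.norm(mu_b)))


def covariance_distance(acts_a: np.ndarray, acts_b: np.ndarray) -> float:
    """Frobenius distance between sample covariance matrices, for the variant
    where similar tasks should vary along similar directions."""
    cov_a = np.cov(acts_a, rowvar=False)  # rows are inputs, columns are features
    cov_b = np.cov(acts_b, rowvar=False)
    return float(np.linalg.norm(cov_a - cov_b, ord="fro"))


acts_a = np.array([[1.0, 0.0], [3.0, 0.0]])  # mean (2, 0)
acts_b = np.array([[2.0, 2.0], [2.0, 4.0]])  # mean (2, 3)
print(mean_activation_distance(acts_a, acts_b))  # 3.0
```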
Probe transfer: Take two tasks $T_1$ and $T_2$ in the context of both of which some intuitive concept makes sense. For example, task $T_1$ might be a set of English sentences labeled ‘true’ and ‘false’, and task $T_2$ might be a set of Spanish sentences similarly labeled. Then, train a logistic regression probe $p$ to predict from middle-layer activations whether an English sentence is true or false, and test how well that probe transfers to the Spanish sentences, for instance via the accuracy drop $d(T_1, T_2) = \mathrm{acc}_{T_1}(p) - \mathrm{acc}_{T_2}(p)$. This is supposed to give a way to assess whether a model sees the two tasks in terms of the same concepts. One could also do something similar with concept ablations, i.e., checking whether they transfer, either in terms of changing behavior in a reasonable way or in terms of making the same concept inaccessible in the other context as well (see Belrose et al.).
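A minimal sketch of probe transfer, assuming activations and binary labels for both tasks have already been extracted. The probe is plain gradient-descent logistic regression in NumPy (in practice one would likely use an off-the-shelf implementation), and the score is the probe’s accuracy on its training task minus its accuracy on the transfer task. The synthetic data, in which the label is written along either a shared or a different direction, is invented purely for illustration.

```python
import numpy as np


def train_probe(acts, labels, lr=0.1, steps=500):
    """Logistic-regression probe fit by full-batch gradient descent."""
    w, b = np.zeros(acts.shape[1]), 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(acts @ w + b)))  # predicted P(label = 1)
        err = p - labels                           # gradient of log loss w.r.t. logits
        w -= lr * acts.T @ err / len(labels)
        b -= lr * err.mean()
    return w, b


def probe_accuracy(w, b, acts, labels):
    return float(((acts @ w + b > 0) == labels.astype(bool)).mean())


def probe_transfer_gap(acts_1, labels_1, acts_2, labels_2):
    """Accuracy on the training task minus accuracy on the transfer task."""
    w, b = train_probe(acts_1, labels_1)
    return probe_accuracy(w, b, acts_1, labels_1) - probe_accuracy(w, b, acts_2, labels_2)


# Synthetic stand-in activations: the label is encoded along one direction.
rng = np.random.default_rng(0)

def make_task(n, truth_dim, d=4):
    labels = rng.integers(0, 2, n)
    acts = rng.normal(0.0, 0.3, (n, d))
    acts[:, truth_dim] += 2 * labels - 1
    return acts, labels

english = make_task(200, truth_dim=0)
spanish_same = make_task(200, truth_dim=0)   # same concept direction
spanish_other = make_task(200, truth_dim=1)  # concept stored elsewhere
print(probe_transfer_gap(*english, *spanish_same))   # close to 0
print(probe_transfer_gap(*english, *spanish_other))  # roughly 0.5 (chance on transfer)
```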
Mixture of experts overlap: Take two inputs $x$ and $y$, pass them through the model, and see if they go through the same expert. This is perhaps particularly interesting in switch transformers, where we can get a more fine-grained view of expert similarity: whether or not the two inputs go through the same expert isn’t a binary, since one can examine whether the choices made in the many switch layers line up. Formally, in the switch transformer case, we can let the distance be the fraction of switch layer choices that do not match, $d(x, y) = \frac{1}{L}\sum_{\ell=1}^{L} \mathbf{1}\left[e_\ell(x) \neq e_\ell(y)\right]$, where $e_\ell(x)$ is the expert that input $x$ is routed to at switch layer $\ell$ and $L$ is the number of switch layers.
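A minimal sketch of the switch-layer mismatch fraction, assuming we have already recorded, for each input, which expert it was routed to at each switch layer (for a single token position; how to aggregate over the tokens of a longer input is left open here). The function name is hypothetical.

```python
from typing import Sequence


def routing_distance(route_x: Sequence[int], route_y: Sequence[int]) -> float:
    """Fraction of switch layers at which x and y are sent to different experts.

    route_x[l] is the index of the expert chosen for input x at switch layer l.
    """
    if len(route_x) != len(route_y):
        raise ValueError("both inputs must be routed through the same switch layers")
    mismatches = sum(ex != ey for ex, ey in zip(route_x, route_y))
    return mismatches / len(route_x)


print(routing_distance([3, 0, 7, 7], [3, 1, 7, 2]))  # 0.5
```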
Activation steering transfer: Take some pair of tasks $T_1$ and $T_2$. First, obtain some steering vector $v$ representing a meaningful disposition on task $T_1$ (e.g., a disposition to play ‘cooperate’ in a prisoner’s dilemma). Then, steer the model on task $T_2$ with the steering vector $v$. Compare the resulting behavior on $T_2$ to baseline behavior on $T_2$. So, the metric is, say, $s(T_1, T_2) = P(\text{cooperative behavior on } T_2 \mid \text{steered with } v) - P(\text{cooperative behavior on } T_2)$.[13]
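A toy sketch of the steering-transfer score. Everything here is a stand-in assumption: the ‘model’ is a set of hidden states whose probability of the cooperative action is a sigmoid of a linear readout, and the steering vector is taken as a difference of mean activations between cooperative and non-cooperative examples from the source task (one common way to build steering vectors, not necessarily the one intended above).

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def difference_of_means(acts_pos: np.ndarray, acts_neg: np.ndarray) -> np.ndarray:
    """Steering vector: mean 'cooperate' activation minus mean 'defect' activation."""
    return acts_pos.mean(axis=0) - acts_neg.mean(axis=0)


def steering_transfer(hidden_t2, readout, steer_vec, alpha=1.0):
    """Change in mean P(cooperative action) on the target task when alpha * steer_vec
    is added to every hidden state. Toy assumption: P = sigmoid(hidden @ readout)."""
    baseline = sigmoid(hidden_t2 @ readout).mean()
    steered = sigmoid((hidden_t2 + alpha * steer_vec) @ readout).mean()
    return float(steered - baseline)


readout = np.array([1.0, 0.0])   # toy: dimension 0 drives cooperation
hidden_t2 = np.zeros((5, 2))     # toy hidden states on the target task
v_aligned = difference_of_means(np.array([[2.0, 0.0]]), np.array([[0.0, 0.0]]))
v_orthogonal = np.array([0.0, 2.0])
print(steering_transfer(hidden_t2, readout, v_aligned))     # about 0.38
print(steering_transfer(hidden_t2, readout, v_orthogonal))  # 0.0
```

In a real model the readout would be the rest of the network, and the score would be estimated from sampled behavior rather than read off a formula.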
Inspecting learning
Here, we discuss methods that gauge task structure via examining a model’s learning. The first three similarity metrics below are supposed to track whether two tasks benefit from [the same things being learned] / [the same existing internal structures being reinforced]. The last metric below tries to get at whether two behaviors were learned from similar sets of examples. One way these might differ from the metrics above is that it is possible some of these metrics would already begin to be meaningful before the model has interesting fully-formed internal structures.
- Fine-tuning transfer (roughly used in Gritsevskiy & Popper): Take some pair of tasks $T_1$ and $T_2$. First fine-tune on task $T_1$, then briefly fine-tune on task $T_2$, and compare performance on task $T_2$ to the case where we fine-tune (for equally many steps) on $T_2$ throughout.[14] So let’s say $s(T_1, T_2) = \frac{\mathrm{acc}(T_2 \mid T_1, T_2')}{\mathrm{acc}(T_2 \mid T_2)}$, where the conditional notation is meant to specify what the model has been fine-tuned on, with $T_2'$ representing a small fine-tuning set of task $T_2$.
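- Few-shot prompting transfer: Repeat the above point but with in-context learning. That is, few-shot prompt the model with examples of a task $T_1$ vs examples of a task $T_2$ (maybe with a smaller number of examples of $T_2$ at the end again), then ask it to do $T_2$. So, let’s say, $s(T_1, T_2) = \frac{\mathrm{acc}(T_2 \mid \text{prompted with } T_1, T_2')}{\mathrm{acc}(T_2 \mid \text{prompted with } T_2)}$, where $T_2'$ denotes the smaller number of $T_2$ examples at the end.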
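- Pre-training transfer: Repeat the above point but with more thorough training, perhaps seeing how much cutting a task $T_1$ from pretraining hurts performance on a task $T_2$, or how much cutting both $T_1$ and $T_2$ hurts performance on $T_2$ more than just cutting $T_2$. So, let’s say, $s(T_1, T_2) = \mathcal{L}(T_2 \mid \text{pretrained without } T_1, T_2) - \mathcal{L}(T_2 \mid \text{pretrained without } T_2)$, where $\mathcal{L}$ denotes loss. Or, analogously to the fine-tuning transfer metric above, we could compare accuracy on $T_2$ [when one trains solely on $T_1$ and fine-tunes on $T_2$] to that when one just trains on $T_2$ for equally many steps.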
- Influence overlap: Compute the influence (for instance, formalized as in Grosse et al.) that each training data point in some reference set would have (or had) on the outputs the model gives on a task $T_1$; store these influences in a vector $v_1$ (indexed by the reference set). Define $v_2$ similarly for a task $T_2$. Then we could see whether $T_1$ and $T_2$ are influenced by the same data points in the reference set. We could pick $d(T_1, T_2)$ to be some canonical distance between $v_1$ and $v_2$ (e.g., the cosine distance $1 - \frac{v_1 \cdot v_2}{\lVert v_1 \rVert \lVert v_2 \rVert}$).
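The fine-tuning transfer score can be tried out end to end on a toy problem. In this sketch (our own construction, not taken from the cited work), each ‘task’ is a noiseless linear regression over a shared input sample, ‘fine-tuning’ is full-batch gradient descent on mean squared error, and the score is the target-task loss after training on the target throughout, divided by the target-task loss after training on the source task and then briefly on the target, with equal total steps. We use loss rather than accuracy only because the toy tasks are regression problems.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(256, 8))  # inputs shared by all toy tasks


def task_loss(w, w_task):
    """Mean squared error of the linear model w on the toy task y = X @ w_task."""
    return float(np.mean((X @ w - X @ w_task) ** 2))


def finetune(w, w_task, steps, lr=0.05):
    """Full-batch gradient descent on task_loss."""
    for _ in range(steps):
        grad = 2 * X.T @ (X @ w - X @ w_task) / len(X)
        w = w - lr * grad
    return w


def finetune_transfer(w_source, w_target, source_steps=20, target_steps=5):
    """Baseline target loss divided by target loss after source-then-target training.

    1.0 means training on the source task was as useful as training on the
    target itself; values near 0 mean it helped very little (or hurt).
    """
    w0 = np.zeros(X.shape[1])
    transferred = finetune(finetune(w0, w_source, source_steps), w_target, target_steps)
    baseline = finetune(w0, w_target, source_steps + target_steps)
    return task_loss(baseline, w_target) / task_loss(transferred, w_target)


w1 = rng.normal(size=8)
w_similar = w1 + 0.1 * rng.normal(size=8)
print(finetune_transfer(w1, w1))          # 1.0 (identical tasks)
print(finetune_transfer(w1, w_similar))   # well above the next line
print(finetune_transfer(w1, -w1))         # close to 0 (opposite tasks)
```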
Inspecting weights
Here, we discuss methods that gauge task structure via inspecting/changing a model’s weights. Note that the first three metrics in this section would have fit equally well under the above subsection on learning.
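- Gradient similarity (from Michaud et al.): Take two tasks (or single input-output pairs) $T_1$ and $T_2$. Compute the gradients of the loss with respect to the parameters $\theta$ on the two tasks, $g_1 = \nabla_\theta \mathcal{L}(T_1)$ and $g_2 = \nabla_\theta \mathcal{L}(T_2)$. We could then let $s(T_1, T_2) = \frac{g_1 \cdot g_2}{\lVert g_1 \rVert \lVert g_2 \rVert}$.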
- Update relevance: Alternatively, we could see how much a gradient update on one input affects the computation on another, perhaps measured in terms of the distance between activation vectors (or just the KL divergence of the output distribution) on the other input before and after the update, perhaps additively normalized by the same quantity if the update were made on the other input itself.
- Optimal grafting overlap (from Panigrahi et al.): Fine-tune a model on some task $T$. Then go back to the original model, and carefully pick a tiny set of parameters to replace with parameters of the fine-tuned model, making the performance on the task as good as possible (for a precise specification, see Panigrahi et al.). One can then measure the distance between two tasks by how different their sets of grafting parameters are. Formally, letting the grafting index sets for tasks $T_1$ and $T_2$ be $S_1$ and $S_2$, we could let $d(T_1, T_2) = 1 - \frac{|S_1 \cap S_2|}{|S_1 \cup S_2|}$ (the Jaccard distance). (See also Bayazit et al. to specify an ablation variant.)
- Weight ablation transfer/similarity: Take two tasks $T_1$ and $T_2$. For each model component (or weight in a more fine-grained sense, e.g., the edge between two neurons), (zero/resample) ablate it and measure the change in performance on $T_1$, as well as the change in performance on $T_2$. Storing these KL divergences (or, alternatively, absolute values of the changes in the loss) in two vectors $v_1$ and $v_2$, the metric could then be, say, $s(T_1, T_2) = \frac{v_1 \cdot v_2}{\lVert v_1 \rVert \lVert v_2 \rVert}$.
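Gradient similarity, ablation similarity, and the grafting-overlap distance from the list above can be sketched on a toy linear model, assuming (purely for illustration) that each ‘task’ is a dataset `(X, y)` scored by mean squared error and that the model is a single weight vector. ‘Ablation’ here means zeroing one weight at a time and recording the absolute change in loss. The task construction in the usage lines, where two of the tasks read from disjoint input coordinates, is our own.

```python
import numpy as np


def cosine(u, v):
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))


def mse_loss(w, X, y):
    return float(np.mean((X @ w - y) ** 2))


def mse_grad(w, X, y):
    return 2 * X.T @ (X @ w - y) / len(y)


def gradient_similarity(w, task_1, task_2):
    """Cosine similarity between the loss gradients of two tasks at weights w."""
    return cosine(mse_grad(w, *task_1), mse_grad(w, *task_2))


def ablation_vector(w, X, y):
    """|change in loss| from zero-ablating each weight in turn."""
    base = mse_loss(w, X, y)
    deltas = []
    for i in range(len(w)):
        w_ablated = w.copy()
        w_ablated[i] = 0.0
        deltas.append(abs(mse_loss(w_ablated, X, y) - base))
    return np.array(deltas)


def ablation_similarity(w, task_1, task_2):
    return cosine(ablation_vector(w, *task_1), ablation_vector(w, *task_2))


def grafting_distance(s_1: set, s_2: set) -> float:
    """Jaccard distance between two grafting index sets."""
    return 1 - len(s_1 & s_2) / len(s_1 | s_2)


rng = np.random.default_rng(0)
w_true = np.array([1.0, -1.0, 2.0, 0.5])

def make_task(active_dims, n=100):
    X = np.zeros((n, 4))
    X[:, active_dims] = rng.normal(size=(n, len(active_dims)))
    return X, X @ w_true

task_a = make_task([0, 1])
task_a2 = make_task([0, 1])  # same coordinates, fresh samples
task_b = make_task([2, 3])   # disjoint coordinates
print(gradient_similarity(np.zeros(4), task_a, task_b))  # zero: disjoint weights
print(ablation_similarity(w_true, task_a, task_a2))      # close to 1
print(ablation_similarity(w_true, task_a, task_b))       # zero: disjoint weights
print(grafting_distance({1, 2, 3}, {2, 3, 4}))           # 0.5
```

Gradients are evaluated away from the solution (at zero weights) because at the exact solution they vanish and the cosine is undefined.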
Analogues in humans
- Many of the above metrics transfer, somewhat gracefully, to humans. For example, it is easy to imagine metaphorically fine-tuning a mathematician to be good at, say, physics. There would likely be a reasonably large transfer, and the same would likely hold in an LLM trained on mathematics and fine-tuned on physics. It is intellectually interesting whether our intuitive, human task clusterings transfer into LLM ontologies. Is there (pretty please) a universal ontology?
- The existence of a somewhat universal task similarity metric would be good news for an avenue of research that aims to compare the generalization ability of ML models to that of humans. Given such a metric, one could then perhaps compare [the task similarity distance that humans can generalize across] to [the task similarity distance that GPT-k can generalize across]. In fact, perhaps one could produce a scaling law in generalization distance, and use it to estimate the time to AGI along the default path.
A toy model for testing task decomposition techniques
We briefly propose a family of toy data sets with custom chosen ‘ground truth’ task decompositions, in the sense that for each data set, there is a particular task decomposition [a model which ends up getting low loss when trained on the data set] would plausibly learn.
For the toy model, we create an artificial set of tasks with relationships that we choose. We build on the toy multitask sparse parity (MSP) task from Michaud et al. In the MSP task, each input bitstring is split into control bits and task bits. The number of control bits is equal to the number of subtasks; the control bits are always set to 0s except for having a 1 in one token position, identifying which subtask is to be performed. The task bits can be 0 or 1 freely, and each subtask is to calculate the output of a particular boolean function on the task bits, with the particular choice of boolean function/subtask specified by which control bit is present (let’s say that the assignment of boolean functions to control bits is arbitrary). The suitability of the MSP task comes from us having access to the ground truth task decomposition in this case: it is a lookup table (or depth 1 tree) of disjoint subtasks. We can straightforwardly modify the set of subtasks[15] to give them interesting relational structure in a number of ways:
Make each subtask a multilayer boolean circuit, composed by selecting each layer from a pair. For example, let $f_{b_1 b_2 b_3} = g_3^{b_3} \circ g_2^{b_2} \circ g_1^{b_1}$, where each of $g_\ell^0, g_\ell^1$ ($\ell = 1, 2, 3$) is a one-layer small boolean circuit. Then, each subtask can be indexed by a binary sequence $b_1 b_2 b_3$. For example, subtask 3 could be indexed by the sequence 011, corresponding to the function $g_3^1 \circ g_2^1 \circ g_1^0$. In this case, the ground truth task distance could be the Hamming distance between the two binary sequences.
Similar to the above, but instead of using the same pair of layer options for all subtasks, use a binary tree structure, so that sequences like 01101 and 01111 correspond to tasks that agree only up to the first 3 layers. Then, ground truth task similarity would be measured by the number of shared leading layers (equivalently, by tree distance), and the goal would be to reverse engineer the tree.
Pick a collection of index sets (subsets of the task bits), and make each task be computing the AND of the XORs over two of these index sets.[16] The ground truth distance between two tasks could then be small if they share an input subset, and large if they do not.
Make each task be to calculate the XOR of a certain subset of task bits. Ground truth task similarity in this case could be the size of the overlap of the subsets of bits involved.
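A minimal sketch of a data generator for the XOR-subset variant, together with the ground-truth quantities proposed for three of the variants above. The bit layout (one-hot control bits followed by task bits) follows the MSP description; sampling subtasks uniformly reflects our footnoted preference for equally frequent tasks; the function names, subsets, and seed are our own.

```python
import random


def msp_xor_dataset(subsets, n_task_bits, n_samples, seed=0):
    """Multitask sparse parity with one XOR subtask per control bit.

    Each example is (bits, label): bits = one-hot control bits followed by
    random task bits; label = XOR of the task bits in the chosen subset.
    """
    rng = random.Random(seed)
    data = []
    for _ in range(n_samples):
        k = rng.randrange(len(subsets))  # subtasks sampled uniformly
        control = [int(i == k) for i in range(len(subsets))]
        task_bits = [rng.randint(0, 1) for _ in range(n_task_bits)]
        label = sum(task_bits[i] for i in subsets[k]) % 2
        data.append((control + task_bits, label))
    return data


def subset_overlap(s_1, s_2):
    """XOR-subset variant: similarity = number of shared task bits."""
    return len(set(s_1) & set(s_2))


def hamming_distance(code_1, code_2):
    """Layer-pair variant: distance = number of layers with different choices."""
    return sum(a != b for a, b in zip(code_1, code_2))


def shared_prefix_layers(code_1, code_2):
    """Binary-tree variant: similarity = number of shared leading layers."""
    n = 0
    for a, b in zip(code_1, code_2):
        if a != b:
            break
        n += 1
    return n


subsets = [(0, 1, 2), (2, 3, 4), (5, 6, 7)]
data = msp_xor_dataset(subsets, n_task_bits=8, n_samples=1000)
print(subset_overlap(subsets[0], subsets[1]))   # 1
print(hamming_distance("01101", "01111"))       # 1
print(shared_prefix_layers("01101", "01111"))   # 3
```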
A point of studying these toy models is to give us some feedback on how good different distance metrics are and what it is precisely that each one measures[17] (although language models are of course likely to be different from the toy model in important ways).
Acknowledgements
Thanks to Andis Draguns and Lawrence Atkins for helpful discussions, including contributing a couple methods; to Clem von Stengel, Lucius Bushnaq, Nina Rimsky, Robert Avery, Dmitry Vaintrob, Caspar Oesterheld, and Hoagy Cunningham for discussions, comments, and edits; and potentially to people we’ve forgotten (feel free to message us).
[1] What we are calling a ‘task’ is similar to what Arora & Goyal call a ‘skill’. ‘Task’ takes the perspective of asking a model to do something; ‘skill’ takes the perspective of the model. This notion of ‘task’/‘skill’ is also very similar to what Michaud et al. call a ‘quantum’. These authors also make certain assumptions about skills/quanta that we think of as providing interesting concrete cases. So, the picture we present here differs from the pictures proposed by these authors in that we have a looser notion of task decomposition having to do with ‘how much computation is shared’ / ‘how similar the computation is’, which could be made precise in various ways (including using ideas from these authors). It also differs in that, at least in the tree picture of tasks we present, the tasks have internal structure that can be shared with other tasks: they could be composed of shared circuits/skills/quanta. But the picture here is much inspired by Michaud et al.
[2] A possible objection here: wouldn’t the ideal indirect object identification circuit be more like a full description of what a model does to do IOI; i.e., isn’t the task dictionary/classifier part of this decomposition unnecessary? So, couldn’t it be more like this: there’s a bunch of circuits for various tasks that are always running, except perhaps not having some nodes activate because of looking for something that does not exist in the input, or something, and then the final answer is some aggregation of the outputs of the circuits that do activate? Well, maybe it could be like that (or, at least, we won’t get into an extended analysis of whether it could be like that here), but as far as we can tell, Wang et al. is not significant evidence of that: the method used would plausibly not detect computations with outputs shared by everything on the dataset on which its mean ablations are computed. In particular, its method would not find a hypothetical task classifier (which could well be more complex than the circuit found) which always decides that the task is indeed IOI on the reference data set (this is also true for resampling ablations from the same data set). In any case, even if the correct hypothesis to entertain were that models are more like ensembles of unconditional circuits that always ‘try to run’, the present bullet point would still make sense, mutatis mutandis.
[3] Besides conceptual reasons to think a taxonomy is appropriate, work like Saxe et al. provides motivation for a tree-like structure.
[4] A slightly more nuanced model of unlearning is: almost certainly the path of things that must be learned for vaccine design and bioweapon design is very similar, with a fork at the end. Unlearning bioweapon design is not a binary thing, but a spectrum from just superficially not outputting bioweapon advice without a jailbreak to reinitialising all the parameters in the network. One way of quantifying the degree of unlearning is how many steps of fine-tuning are required to reintroduce the capability. If we want to unlearn bioweapon design without unlearning vaccine design, then we can walk the model back up the bioweapons path until we hit the fork: the further the fork is from the end of the path (corresponding to higher task distance), the more deeply we are able to unlearn bioweapons without affecting vaccine capabilities. Equivalently, the more deeply we unlearn bioweapons, the more ‘collateral damage’ we necessarily pick up in terms of unlearning other things by accident, in order of increasing task distance from bioweapons. One problem with this picture is that it might not be a good way of describing the capabilities of a generally intelligent system which has learned how to learn about the world efficiently (e.g., a system capable of making research advances), because it may be impossible to unlearn bioweapon design in this system such that the system could not rediscover bioweapons on the fly without unlearning general reasoning capabilities.
[5] It seems reasonable to operationalize generalization as applying understanding of a task (say, writing English poetry) to other subtasks (e.g. writing French poetry) of a certain natural “metatask” (writing poetry).
[6] Generalization distance clearly depends on things like the amount of allowed fine-tuning, the number of few-shot examples, etc., and some of these things can be hard to compare to a human, but one might hope that we can fix an allowed amount of fine-tuning/prompting and still end up with something that makes sense.
[7] Roughly equivalently, we can alternatively think of a task as being specified by a distribution (of inputs or input-output pairs).
[8] Given a way to measure the similarity between two inputs, one can measure the similarity between two data sets with the expectation of the similarity between a random input from the first and a random input from the second.
[9] Still, we think there are likely reasonable ways to go from certain task-wise metrics to task decompositions (for instance, minimizing the sum of distances of each proposed task to itself minus the sum of distances between different proposed tasks), but we haven’t thought about this carefully.
[10] This is different from topic modeling, though we don’t rule out that approaches from topic modeling could be brought to bear here. When the domain is natural language, we are not looking for a partition of contexts/documents into topics here, nor quite a partition of words into topics (which topic modeling methods provide), at least in the sense of usual topic modeling methods. What we have in mind is more like a classification of which pattern(s)/rule(s) was(/were) used to generate each token, or more correctly, which pattern(s)/rule(s)/skill(s) might most naturally be used to predict each token, and (while admittedly not being very familiar with this literature) we don’t expect that standard topic modeling techniques would get at this with the level of sophistication we’d like.
[11] We note that each method below could well turn out not to measure any reasonable kind of similarity. We also note that the methods would likely end up measuring distinct flavors of task similarity, but we do not provide a detailed analysis of these flavors.
[12] We might similarly want to upweight contributions from middle layers in many scores below.
[13] Here and later, it would also make sense to look at the change in a more fine-grained manner, i.e., to not just track this single parameter.
[14] We add the small amount of fine-tuning on the second task at the end because we want the model to be able to make some connections between what it has learned from the first task and the new domain.
[15] We also will probably want each task to be equally frequent in the dataset, unlike the original MSP task, which was designed with a different purpose in mind.
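[16] That is, if the subsets for that task are $S$ and $S'$, then the task is to compute $\left(\bigoplus_{i \in S} x_i\right) \wedge \left(\bigoplus_{j \in S'} x_j\right)$, where $x_i$ denotes the $i$-th task bit.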
[17] One can make progress here by running an experiment or by just thinking through what each task decomposition method would capture when applied on a plausible NN-implementation of an algorithm solving the task.
I find this focus on task structure and task decomposition to be incredibly important when thinking about what neural networks are doing, what they could be doing in the future, and how they are doing it. The manner in which a system understands/represents/instantiates task structures and puts them in relation to one another is, as far as I can tell, just a more concrete way of asking “what is it that this neural network knows? what cognitive abilities does it have? what abstractions is it making? under what out of distribution inputs will it succeed/fail, etc.”
This comment isn’t saying anything that wasn’t in the post, just wanted to express happiness and solidarity with this framing!
I do wonder if the tree structure of which-task and then task algorithm is what we should expect, in general. I have nothing super concrete to say here; my feeling is just that the manners in which a neural network can represent structures and put them in relation to each other may be instantiated differently than as a tree (with that specific ordering). The onus is probably on me here, though: I should come up with a set of tasks in certain relations that aren’t most naturally described with tree structures.
Another question that comes to mind is, is there a hard distinction between categorizing which sub-task one is in and the algorithm which carries out the computation for a specific subtask. Is it all just tasks all the way down?
I would love to see someone expand on the ways we could use interpretability to learn about the world, or the structure of tasks (or perhaps examples of how we’ve already done this?). Aside from being interesting scientifically, maybe this could also help us build economically valuable systems which are more explicit and predictable?