Crafting Polysemantic Transformer Benchmarks with Known Circuits
Notes:
This research was performed as part of Adrià Garriga-Alonso’s MATS 6.0 stream.
If an opinion is stated in this post saying that “we” hold the opinion, assume it’s Evan’s opinion (Adrià is taking a well-deserved vacation at the time of writing).
Evan won’t be able to continue working on this research direction, because he’s going to be offline before starting a new job at Anthropic in September! In that light, please view this post as something between a final research writeup, a peek into a lab notebook of some experiments we tried, and a pedagogical piece explaining some of the areas where Evan got stuck and had to dig in and learn things during MATS. See the end of the post for some thoughts we have about what promising future work would be in this area.
If you’re excited about this research agenda, we highly recommend that you get in touch with Adrià, or apply to work with him in a future MATS stream. We also recommend reaching out to Rohan Gupta and Iván Arcuschin, whose MATS 5.0 InterpBench work this builds on.
Summary
We briefly review a few concepts in interpretability, especially polysemanticity and techniques for testing interpretability hypotheses, with a specific focus on interchange interventions.
We discuss Interchange Intervention Training (IIT) and its Strict Interchange Intervention Training variant (SIIT), and propose some modest changes to training with these techniques that we discovered achieve better performance.
We use SIIT to train benchmark models where multiple known circuits all use the same MLP to perform different tasks. Ideally, the models we created should have polysemantic neurons within their MLPs. Even if they don’t, we hope the framework we created, when properly scaled up, will allow users to create semi-realistic models with polysemantic neurons, known circuits, and training via gradient descent.
We discuss exciting future research directions, including using the benchmarks trained here as a test of Sparse Autoencoders (SAEs), and using this training technique to explore the different kinds of tasks attention heads can perform in superposition.
Contributions
We improved SIIT training; see our description below and some contributions from Evan which are publicly available in this repository (after the PR is sorted).
We defined a few crafted-by-hand benchmark cases, including robust dataset generators and correspondence maps of how those cases can be broken down and mapped into pieces that a transformer is capable of computing with MLPs and attention heads. See these cases here and graphical descriptions of them in the Appendices at the end of the post.
We defined a procedure for combining multiple hand-crafted circuits into a single model, which should be easy for users to expand with new benchmark cases. Code here.
We publicly release our trained (hopefully polysemantic) benchmark models here, and code for loading in and training those models here.
Introduction and Context
It’s currently very difficult for AI researchers to tell if an AI model is really aligned or just a good liar. We think that techniques that check for consistency between a model’s outputs and “thought processes” (e.g., probes catching sleeper agents, Hubinger 2022, Nanda 2022, Roger 2023) will be essential for catching misaligned models[1] before deployment, especially as models become more capable. Mechanistic interpretability (MI) is a developing field aiming to explain model internals (see Bereska+Gavves 2024 for a recent review). Optimistically, MI techniques could allow us to outline and characterize specific pathways (“circuits”) through models related to misaligned behavior like deception or sycophancy and then determine if those circuits are active. Robust circuits have been characterized in simplified transformers (e.g., induction and modular addition), and some have been found in small language models (e.g., IOI in gpt-2-small, Wang+ 2022). Pessimistically, even if current state-of-the-art MI techniques cannot be advanced further and are merely scaled up, they already find safety-relevant model “features” (Templeton+ 2024), and linear probes could be used to flag when those features are active, hopefully correlating with when models are “thinking” in a misaligned way.
In order to determine if an AI model is lying or not by looking at the model activations, we need:
A tool for breaking down model activations into interpretable elements,
for leading labs to adopt those tools, and
to verify that those tools can find and alert us to undesirable model behavior.
Sparse Autoencoders (SAEs) effectively break down model activations into interpretable elements, and leading labs are training SAEs on frontier models like GPT-4 and Claude-3 Sonnet. The interpretability community has fully embraced SAEs over the past year (see the lesswrong Sparse Autoencoder tag). A lot of the thrust of SAE research over the past year has been on training and distributing better SAEs (gated, top-k, jump-relu, switch, etc.). We would love to see more work focused on the validity of SAEs (e.g., is the linear representation hypothesis right?, are there better metrics than the standard ones reported?, what’s going on with the SAE error term?) which could help determine how much better we can do at explaining model internals just by scaling SAEs.
To this end, over the past few weeks, we’ve done some work to create a framework for training transformer models which we hope can be used to test interpretability techniques for disentangling polysemantic activation spaces (like SAEs). The broad idea is that we define simple tasks that transformers are definitely capable of implementing as a circuit (checking if a string of parentheses is balanced, for example), make a map of how that circuit could map onto a transformer’s components (attention heads, MLPs), and then train those circuits into a transformer. The key contribution here is that we train models with overlapping circuits, so that single MLPs are used in multiple circuits, hopefully leading to polysemantic neurons. We’ve decided not to have circuits overlap in attention heads, though see the future work section for some interesting ways we think this could be done.
Our hope is that we can create a benchmark that is “realistic enough” (moreso than toy models) to be useful for testing SAEs, but still has known circuits. We hope that this will supplement ongoing work that uses SAEs to find and evaluate known circuits in pretrained language models (e.g., the IOI circuit in gpt-2-small and various circuits in pythia-70m).
Related Work
Polysemanticity and Superposition: These terms are not the same, as nicely explained in this post. Polysemanticity, or neuron polysemanticity, is the phenomenon that neurons fire in response to unrelated stimuli (a classic example is cat faces and the fronts of cars in Olah+ 2020). Superposition occurs when a network linearly represents more features than it has dimensions / neurons (like the classic 5-features-in-2D of Elhage+ 2022). The superposition hypothesis is an extension of the linear representation hypothesis (e.g., Park+ 2023) stating that polysemantic neurons arise because features are stored as linear directions in activation space, but there are more features than there are dimensions, so a single neuron activates due to projections from many features. Our goal in this work is to create models with known circuits that exhibit neuron polysemanticity.
Training a model with known, sparse circuits is not an easy task. Lindner+ 2023 developed Tracr, a technique for compiling RASP programs into simple, human-readable decoder-only transformer models with monosemantic neurons. Geiger+ 2022 introduced the Interchange Intervention Training (IIT) technique, whose goal is to induce known causal structure (known circuits) into machine learning models. In short, IIT aims to ensure that operations in a known high-level computation graph directly map onto specific nodes of a low level model. Gupta+ 2024 found that models trained with IIT could use unintended nodes in their computations and extended this technique to Strict Interchange Intervention Training (SIIT). See the “Training with SIIT” section below for a walkthrough of IIT, SIIT, and their loss function. We use the framework of SIIT to train our benchmarks in this work.
Toy models. Toy models have been used to understand how models store features in superposition (Elhage+ 2022) and to study how memorization works on small input datasets (Henighan+ 2023). Toy models have also been used as a test bed for engineering monosemantic neurons in neural networks (Jermyn+ 2022). Scherlis+ 2022 introduces the concept of capacity as a measurement of polysemanticity. All of these works seek to learn fundamental truths about how neural networks behave in the hopes that these truths can be applied to more complex models like transformer language models, but their simplicity makes it difficult to determine how universal the truths discovered by studying them are. Our hope is that the benchmarks developed here (and by extension by works like Gupta+ 2024) help bridge the gap between toy models and real models by providing an intermediate test bed with some of the complexity of real models and some of the simplicity and ease of study of toy models.
Benchmark models. Thurnherr and Scheurer 2024 recently released a large set of Tracr models which can be used as benchmarks of interpretability techniques. As noted by Gupta+ 2024, the weights learned by a Tracr model are synthetic and have a very different distribution from naturally-occurring weights in models trained by gradient descent. Models trained by IIT and SIIT have weights with similar distributions to classically pretrained models. Gupta+ 2024 released 17 benchmark models, which are trained on the IOI task as well as tasks that are simple enough for Tracr to compile but trained using SIIT[2]. Reviewers of that paper noted that particularly promising future directions included examining SAEs and polysemanticity, the latter of which we explore here.
Connecting language model circuits with high level explanations
High & Low Level Models
Circuits are pathways through models that compute meaningful and interpretable algorithms (Olah+ 2020). To describe a circuit in a machine learning model such as a pretrained language model, we need to identify specific tasks or computations that the model performs, and we need to identify precisely where those computations occur in the transformer. We will call the transformer the “low level model”, and we will create a “high level model” which is a hypothesis for how a circuit works in the low level model. Note that high level models are really very simple: they just describe a valid method of how a task or an algorithm can be performed.
As an example, imagine the task “add together three numbers,” and we’ll call the numbers ‘a’, ‘b’, and ‘c’, and the result ‘d’, so d = a + b + c. This operation can be performed in a few different ways, all of which return the same result. You could have an addition operation that takes two inputs and returns one output, and use that addition operation twice, or you could have an addition operation that takes three inputs and returns one output and then use that operation once. Both are valid ways to do addition, and so we can define multiple different valid “high level” models of adding together three numbers:
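As a deliberately trivial sketch (plain Python, not the actual high-level model classes in our repo), two such decompositions might look like:

```python
# Two equally valid high-level models of d = a + b + c.

def add_three_v1(a: int, b: int, c: int) -> int:
    # Decomposition 1: a two-input adder, applied twice.
    intermediate = a + b
    return intermediate + c

def add_three_v2(a: int, b: int, c: int) -> int:
    # Decomposition 2: a single three-input adder.
    return a + b + c

assert add_three_v1(3, 4, 5) == add_three_v2(3, 4, 5) == 12
```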
However, if we trained a transformer to perform this addition task, these would not all be equally valid descriptions of what the transformer does. Training via gradient descent picks one of these implementations (or a different one that we haven’t imagined here – models are very creative). Given a trained model, all we know a priori from examining its outputs is that it’s capable of adding three numbers, so we can speculate that there’s a three-number addition circuit inside of that model. Our tasks as interpretability researchers are to:
Determine which pieces of the model are used for computation.
Determine what task each of those pieces performs.
This information will allow you to invalidate incorrect high level models, and you’ll be left with one high level model which is a true description of the algorithm implemented by the low level model.
I want to briefly mention that in practice, language models and transformers don’t make this easy for us. There’s not always just one node in a low level model that performs a specific task. In the gpt2-small IOI circuit (Wang+ 2022), for example, multiple attention heads double up to perform the same task (see fig. 2, e.g., heads 2.2 and 4.11 are both previous token heads). Furthermore, even if you find a circuit like the IOI circuit, it typically doesn’t explain 100% of the IOI task being performed by the model! And, more troubling, models can exhibit complex phenomena like backup behavior (e.g., Rushing and Nanda 2024, where model components activate when another component is intervened on) or interpretability illusions (e.g., Bolukbasi+ 2021, Makelov+ 2023).
Interchange Interventions
Many approaches have been developed to test interpretability hypotheses of the form “does this high level model describe what my transformer is doing?” We won’t go through them in detail here, but see this blog post for a nice introduction[3]. We also want to quickly note that whether or not a high level model counts as a good description depends upon how the hypothesis is tested, so a description that looks good under one approach could look bad under a different one.
All this being said, some forms of testing hypotheses–including the ones we’ll use in this work–use interchange interventions[4]. In an interchange intervention, you take two different inputs x1 and x2. You run a forward pass on x2 and cache activations from its forward pass. Then you start running a forward pass on x1, but at one or more nodes in the model you replace the activations with the cached activations from x2 before finishing the forward pass. This corrupts the output of the forward pass of x1. If you have a high level model that perfectly explains the circuit, you should be able to do exactly the same interchange intervention on the high level model (forward pass x2, then forward pass x1 while replacing nodes in the high level model that correspond to replaced nodes in the transformer). If your high level output and low level output under this intervention are the same, then the node that was intervened on in the high level model is a good description of what’s happening in the low level model.
See below for an idealized graphic walkthrough of interchange interventions for the d = a + b + c example we described above. See below the image for more exposition.
In the top panels we show one high level model of three-number addition; in the bottom panels we show a low level model (a 2-layer transformer with one attention head per layer, and we’re assuming all of the action happens in the attention heads and not in the MLPs here). The left panels show the results of standard forward passes that produce the expected output. In the right panels, we show a forward pass for x1 that has undergone an interchange intervention using activations from the x2 forward pass.
In creating this image, I’ve assumed that the high level model is a good description of the transformer circuit, so by intervening on the same node in the high level and low level model, we get the same downstream effects and output. If the high level model were not a good description of the low level circuit, then the low-level model would have output something other than 12, and that would invalidate our high level hypothesis.
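In code, an interchange intervention on a single low-level node looks roughly like the following TransformerLens sketch (the model, prompts, hook name, and head index are placeholders for illustration, not the benchmark models from this post):

```python
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")  # stand-in low-level model

# x1 and x2 should tokenize to the same length for a position-aligned patch.
x1 = model.to_tokens("The cat sat on the bed")
x2 = model.to_tokens("The dog ran to the park")

# 1. Run x2 and cache all of its activations.
_, x2_cache = model.run_with_cache(x2)

# 2. Re-run x1, overwriting one node (here: head 0 of layer 1) with the
#    activation cached from the x2 forward pass.
hook_name = "blocks.1.attn.hook_z"  # illustrative choice of node

def interchange_hook(z, hook):
    z[:, :, 0, :] = x2_cache[hook.name][:, :, 0, :]  # head index 0
    return z

patched_logits = model.run_with_hooks(x1, fwd_hooks=[(hook_name, interchange_hook)])
```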
Training with SIIT
Strict Interchange Intervention Training (SIIT) trains models to minimize a loss function with three terms. Given a dataset of input-label pairs (x, y), a high-level model M_HL, and given two samples from that dataset (x1, y1) and (x2, y2), the loss on those samples is calculated as:
L_total = d_baseline * L_baseline(x1, y1) + d_IIT * L_IIT(x1, M_HL(x1); x2) + d_SIIT * L_SIIT(x1, y1; x2).
Each of the loss terms performs the same type of calculation (they are all a cross-entropy loss for a categorical task or an MSE loss for a regression task). In words, these loss terms measure:
L_baseline(x1, y1): the ability of the model to output y1 when given x1 as an input (standard training term, nothing fancy).
L_IIT(x1, M_HL(x1); x2): the ability of the model to output the same output as the high level model M_HL when both the model and M_HL undergo a forward pass of x1 with an interchange intervention using cached activations from a forward pass of x2.
L_SIIT(x1, y1; x2): the ability of the model to output y1 when performing a forward pass with x1 with an interchange intervention (from an x2 forward pass) on a low-level node that has no function in the high level model.
d_baseline, d_IIT, and d_SIIT are weights that determine the relative importance of these terms.
For a more detailed description of these loss terms, we refer the reader to section 3 of the IIT paper (Geiger+ 2022) and section 3 of the SIIT paper (Gupta+ 2024).
In training, we try to produce models that achieve 100% accuracy under three metrics:
Baseline Accuracy: The model can produce the proper output, y, given an input, x.
Interchange Intervention Accuracy (IIA), described in section 3 of Geiger+ 2022. The model can produce the same output as the high-level model under an L_IIT-style intervention given a pair of inputs x1 and x2.
Strict Interchange Intervention Accuracy (SIIA), described briefly in section 4 of Gupta+ 2024. The model can produce the proper output, y1, given an input, x1, and given an L_SIIT-style interchange using the input x2.
Modifications to SIIT training in this work
We’ve made a few changes to the SIIT training presented in Gupta+ 2024 and available in their IIT repo, laid out below.
One vs Three steps per training batch (Adam beta parameters):
Algorithm 1 of Gupta+ 2024 updates model weights three times per batch of training data. Backpropagation is performed and a gradient step is taken after calculating each of the three terms of the loss function (L_baseline, L_IIT, L_SIIT). They found worse performance when taking a single step to update the weights using L_total. I was able to replicate this finding when using the default Adam optimization hyperparameters. We found this troubling: rather than finding a minimum in a single loss landscape, the model was iteratively finding different minima in three different loss landscapes.
The fact that a three-step loss calculation produced better results than a single-step loss calculation suggests that the terms in the Adam optimizer with “memory” (specifically the momentum term controlled by \beta_1 and the adaptive learning rate term controlled by \beta_2) can better optimize models via SIIT when they have a “cool down” between individual steps of the loss terms. To improve the optimizer’s ability to forget past iterations, we can therefore decrease these \beta values (default values of these parameters are \beta_1 = 0.9 and \beta_2 = 0.999) to help the optimizer achieve model convergence using a single loss function and a single parameter update per training batch.
I find that decreasing \beta_2 from 0.999 to 0.9 while keeping the default \beta_1 = 0.9 allows me to achieve model convergence using L_total.
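As a minimal sketch (with a stand-in model and precomputed loss terms; the weights match those discussed below), the single-step update with the modified betas looks like:

```python
import torch

model = torch.nn.Linear(16, 4)  # stand-in for the low-level transformer
optimizer = torch.optim.Adam(
    model.parameters(), lr=1e-3, betas=(0.9, 0.9)  # beta_2 lowered from the default 0.999
)

def training_step(loss_baseline, loss_iit, loss_siit,
                  d_baseline=0.4, d_iit=1.0, d_siit=0.4):
    # One combined backward pass and one parameter update per batch,
    # instead of three separate updates (one per loss term).
    loss_total = d_baseline * loss_baseline + d_iit * loss_iit + d_siit * loss_siit
    optimizer.zero_grad()
    loss_total.backward()
    optimizer.step()
    return loss_total.item()
```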
Changes to SIIT node sampling
In Gupta+ 2024, L_SIIT is calculated by sampling a single node (an attention head or MLP) that isn’t used in the high level model, then performing an interchange intervention on that single node. This gets closer to strictly enforcing the high / low level mapping IIT tries to create, but we speculate that this could have flaws related to backup behavior.
Imagine you have a model with three attention heads and one MLP. One attention head and the MLP are explicitly used in the high level model and undergo IIT-style interventions. The other two attention heads are not used in the circuit, so SIIT works to penalize those nodes. If only one node is sampled each training step, then the two nodes outside of the circuit “take turns” being active or being intervened on. We speculate that the model could find a way to use those heads to help it achieve lower loss (perhaps they both perform the same task as the node in the circuit, and the downstream MLP looks for consistency between the inputs and their outputs and throws out the results of the one intervened-on head).
To combat the ability of models to learn this backup-like behavior through SIIT, we instead sample a random subset of the unused nodes for each SIIT loss evaluation. Each unused node is sampled (or not) with a probability of 50%. So in the above example, sometimes the model will have access to both extra attention heads, sometimes it’ll have access to one or the other, and sometimes it’ll have access to none. Our hope is that by intervening on random subsets of the unused nodes in the SIIT loss, we improve the robustness of the SIIT training method.
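A minimal sketch of the sampling change (the node names are just illustrative labels):

```python
import random

unused_nodes = ["blocks.0.attn.head_1", "blocks.0.attn.head_2"]  # nodes outside the circuit

# Original SIIT: intervene on exactly one unused node per step.
nodes_to_patch_siit = [random.choice(unused_nodes)]

# Our variant: each unused node is independently included with probability 0.5.
nodes_to_patch_ours = [node for node in unused_nodes if random.random() < 0.5]
```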
Other small changes
Gupta+ 2024 used a fixed learning rate of 10^-3, whereas we adopt a linearly decreasing learning rate schedule. We use an initial learning rate of 10^-3, then linearly decrease to a learning rate of 2 x 10^-4 over the course of training. We still finish training early if IIA and SIIA of 100% are achieved on the validation dataset.
We also use different weights in the loss term. Gupta+ 2024 mostly used {d_baseline = 1, d_IIT = 1, d_SIIT = 0.4}, and we use {d_baseline = 0.4, d_IIT = 1, d_SIIT = 0.4}. Our thinking is that the SIIT term and the baseline term in the loss function are working towards the same task (having the model output the right answer), but they work in opposition: the baseline term tries to use the whole model to get better loss overall, while the SIIT term works to get lower loss while discarding the unimportant nodes. Rather than overweighting the baseline, we choose to weight these terms equally. The IIT term, on the other hand, is the term that aims to produce consistency with the operations in the high level model, so we leave its magnitude alone. We choose to have d_baseline + d_SIIT < d_IIT because matching the high-level model is the hardest training task, and we want to weight that a bit more heavily.
Finally, when evaluating if the model achieves 100% IIA on the validation data, we loop over all of the nodes in the correspondence between the high- and low-level models, compute IIA for when each individual node is interchanged, then take the mean result over all nodes. Previously, IIA was only calculated for a single node each validation batch.
Building Polysemantic Benchmark Transformers
Our goal is to create transformer models that perform known circuits and have polysemantic neurons. In theory, we could force the models to have polysemantic neurons by mapping high-level computations onto subspaces of low-level MLPs where those subspaces all project onto a single (or a few) neurons. However, we want polysemanticity to arise naturally as a result of training via SGD, so we choose not to constrain the activation spaces in that manner in this work.
The overall process we use to create models which we hope exhibit polysemantic neurons (we unfortunately didn’t have the chance to test this) is laid out below. In general, the process looks like this:
Construct a high-level model that can perform a simple task.
Train a circuit into a transformer that performs the same task as the high-level model, to ensure that the task is properly defined in terms of units that a transformer can compute.
Construct another high-level model that just performs multiple of these validated high-level tasks.
Train all of the circuits required for the multi-piece-model into a transformer using SIIT. Importantly, the circuits should overlap within nodes of the transformer (multiple circuits should run through the same MLP).
Step 1: Designing monosemantic benchmarks
We design four benchmarks which each perform one specific algorithmic task. See the appendices for details of these four high level tasks and how we map them onto low-level transformers for training with SIIT. Two of these tasks are simple parenthesis-based tasks (inspired by this post) and two of these tasks were benchmarks in the MATS 5.0 circuits benchmark, but we have constructed by-hand circuits for them which differ from the Tracr circuits that were previously trained.
We break each algorithmic task down into successive computations that can be performed by the nonlinearities present in either a single attention head or in a single MLP layer. We design a “correspondence” which links each node in the high level model with a single attention head or MLP in the low level model. Then we train the low level transformers using the training techniques we described above in the “Training with SIIT” section and ensure that the high level model can be trained into a transformer using SIIT.
Phew, we’re through with the (many) preliminaries! Now we can describe the meat of the interesting (or at least novel) work we did here.
Recall that our goal was to train a transformer model which has multiple known circuits running through it. Importantly, the circuits overlap, so e.g., multiple circuits run through and use a single MLP. To achieve this in the framework of IIT/SIIT, we need to construct a high level model that itself is a combination of multiple single high-level benchmarks. We could do this by hand and painstakingly construct a high level model which performs some set of N tasks, but then this method wouldn’t scale and it wouldn’t be very interesting. So: we needed to come up with ways to:
Create a dataset that is a valid combination of datasets for individual tasks.
Create a high level model which is a combination of some set of simple models that the user specifies.
Naturally, both of these come with difficulties.
Combining multiple datasets
Assume we have two datasets which have been curated for two simple algorithmic benchmark tasks.
Dataset 1 is for a parenthesis task and has a vocabulary { “BOS“ : 0, “PAD” : 1, “(” : 2, “)” : 3 }.
Dataset 2 is for a task that looks for repeating tokens and has a vocabulary { “BOS” : 0, “PAD” : 1, “a” : 2, “b” : 3, “c” : 4}.
Furthermore, let’s assume that the context length of samples in Dataset 1 is 10 while the context length of samples in Dataset 2 is 15.
A merged dataset for a polysemantic benchmark can be created by combining these two datasets. We have to make sure that each dataset entry has a flag indicating which dataset (and task) it belongs to, and we need to make sure all entries in the dataset have the same context length. We define the context length of the merged dataset to be the max context length of the individual datasets (15 in the case above) plus 1, so n_ctx for the merged dataset is 16. To make all samples fit into this context, we perform the following operations:
Add a ‘task ID’ token to the beginning of each dataset sample. So: all entries from Dataset 1 will have “0” prepended, and all entries from Dataset 2 will have “1” prepended.
Pad the end of the samples: so, all entries from Dataset 1 will have 5 “PAD” / “1” tokens added to the end.
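A minimal sketch of this merging rule (with shortened, made-up samples using the vocabularies from the example above):

```python
PAD = 1  # both example vocabularies use 1 for "PAD"

def merge_datasets(samples_by_task_id: dict[int, list[list[int]]]) -> list[list[int]]:
    # Merged context length = longest individual sample + 1 (for the task-ID token).
    n_ctx = max(len(s) for samples in samples_by_task_id.values() for s in samples) + 1
    merged = []
    for task_id, samples in samples_by_task_id.items():
        for sample in samples:
            row = [task_id] + sample            # prepend the task-ID token
            row += [PAD] * (n_ctx - len(row))   # pad the end out to n_ctx
            merged.append(row)
    return merged

# Shortened samples for brevity (BOS=0; Dataset 1: "("=2, ")"=3; Dataset 2: a=2, b=3, c=4).
dataset_1 = [[0, 2, 3, 2, 3]]        # "BOS ( ) ( )"
dataset_2 = [[0, 2, 2, 4, 3, 4, 2]]  # "BOS a a c b c a"
merged = merge_datasets({0: dataset_1, 1: dataset_2})
```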
As far as dataset creation goes – that’s it! Not too bad. But once this dataset starts interacting with the high-level model, there are more challenges; see Difficulty 1 below.
Combining multiple individual tasks into one model
We do all of our work with IIT using TransformerLens’ HookedRootModule interface. This interface is great, because it lets us do interchange interventions really easily! Unfortunately, the restrictions of this interface mean that we have a somewhat inefficient means of combining multiple high level models into one, which we’ll describe below.
To initialize a multi-benchmark model, we pass a list of individual benchmarks, each of which contains a HookedTransformerConfig (describing what size of a transformer it should be trained into) and a Correspondence object (defining which TransformerLens HookPoints in the high level model and low level model map onto each other). The model makes its own config by taking the max() of d_model, n_layers, etc. of all of the low level models. The model makes its own correspondence by looping through each attention head and MLP that the eventual low-level model will have and noting which hooks in each of its high level models that node will be responsible for. The model also assigns an unused attention head to the task of looking at the task ID for each batch entry.
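As an illustration, the config merging could look something like this sketch (field names follow TransformerLens’ HookedTransformerConfig; whether n_ctx also gets +1 for the task-ID token is our assumption here):

```python
from transformer_lens import HookedTransformerConfig

def merge_configs(configs: list[HookedTransformerConfig]) -> HookedTransformerConfig:
    # Take the max of each size parameter so that every individual circuit fits.
    return HookedTransformerConfig(
        n_layers=max(cfg.n_layers for cfg in configs),
        d_model=max(cfg.d_model for cfg in configs),
        d_head=max(cfg.d_head for cfg in configs),
        n_heads=max(cfg.n_heads for cfg in configs),
        d_vocab=max(cfg.d_vocab for cfg in configs),
        n_ctx=max(cfg.n_ctx for cfg in configs) + 1,  # room for the task-ID token
        act_fn="relu",
    )
```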
After instantiation, the model is ready to perform a forward pass using a batch of inputs from a combined dataset like we described above. The forward pass follows this algorithm:
Split the task id (first item of each batch entry) from the rest of the input and store both.
Loop through each high level benchmark
Modify the inputs, if needed, so that they can be directly fed into each high-level benchmark. See “Difficulty 1” below. Store these modified inputs.
Loop through each high level benchmark with their modified inputs:
Run a forward pass, caching all activations during this forward pass.
Map cached activations from all of the high level models to nodes in the low level model: loop through all of the attention heads and MLPs that this model will map onto.
Using the correspondence created during instantiation, gather all the data in the cache that corresponds to operations that will be computed in this node. Stack all of this data together into a single torch Tensor.
Pass that stacked data together into a HookPoint for ease of interchange interventions
Unpack results of hooked data back into the caches for each HL model.
Loop through each high level benchmark with modified inputs to calculate the output with interventions:
Add hooks for each HookPoint in the high level model and replace forward pass activations with those from the cache.
Run a forward pass with hooks.
Put model output (probabilities) into the appropriate output indices using a boolean mask created from task IDs.
The above describes how to create the high level model. The low level model is created in exactly the same way as we would for a single-circuit benchmark case: Xavier-initialize a transformer and train it with SIIT. The accuracies of the various low level models that we trained are reported at the end of this post in an appendix.
I want to acknowledge that this is messy, and this procedure is the result of us overcoming challenges that we faced along the way in the first and simplest way that occurred to us. I’d love to see someone come up with a cleaner version of this! To help explain why we landed on this procedure, I want to call attention to a couple of difficulties that we ran into below.
Difficulty 1: Interchanges outside of the task vocabulary
One of the biggest difficulties is that two batches might cause problems under the interchange of a task ID. The example below uses two dataset samples from two tasks with different vocabularies, as we described above in the “Combining multiple datasets” section. Here’s a sample of the problem we face: if the task-ID activation from a task-0 (parenthesis) sample is interchanged into the forward pass of a task-1 sample, then the task-0 high-level model is suddenly asked to process that task-1 sample’s tokens, including token 4 (a “c”).
This is a problem! Token 4 doesn’t exist in task 0’s vocabulary! There are a lot of ways to create a high-level model and dataset to handle this problem:
Always have our batches have the same number of task 1 samples and task 2 samples, and only do interchanges between like samples (unfortunately, this would make it impossible to train a specific node in the low-level model to do the task-id operation, because there would be no interchange information).
Only include tasks with the same d_vocab (restricts the kinds of benchmarks you can build)
Replace out-of-vocab tokens with in-vocab tokens.
We’ve gone with the last option. For now, we’re replacing tokens that are outside of the task vocabulary with randomly sampled tokens from within the task vocab (and we do not include PAD and BOS in the random sampling, so in the above example the ‘4’ would be replaced with a ‘2’ or a ‘3’, representing a ‘(‘ or a ‘)’).
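A sketch of this replacement rule (assuming, as in the example vocabularies above, that BOS = 0, PAD = 1, and the task tokens are the contiguous ids starting at 2):

```python
import torch

def replace_out_of_vocab(tokens: torch.Tensor, d_vocab: int) -> torch.Tensor:
    # Tokens >= d_vocab don't exist for this task; replace them with random
    # in-vocab tokens, excluding BOS (0) and PAD (1).
    random_in_vocab = torch.randint(2, d_vocab, tokens.shape)
    return torch.where(tokens >= d_vocab, random_in_vocab, tokens)

sample = torch.tensor([[0, 2, 2, 4, 3]])  # contains token 4, unknown to task 0
replace_out_of_vocab(sample, d_vocab=4)   # the 4 becomes a random 2 or 3
```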
Difficulty 2: Calculating the proper loss.
Some of the more robust benchmarks in circuits-benchmark are regression tasks where the transformer outputs a float value and is optimized with a mean-squared-error loss; other tasks are optimized with a cross-entropy loss but are not autoregressive. We generally found that we had to group benchmark cases into three categories:
MSE-optimized
CE-optimized but not autoregressive
CE-optimized and autoregressive
We chose to train benchmarks that are CE-optimized and autoregressive to create benchmarks that are closer to being like real language models, but this choice limited the cases that we found in the literature that we were able to reuse in this work.
Furthermore, even after deciding that we wanted to calculate a per-token CE loss, we still ran into some struggles regarding how much each token entry should be weighted:
If you have one task with n_ctx = 10 and another with n_ctx = 15, should you calculate the loss on every token?
We think no. You shouldn’t calculate the loss on the pad tokens at the end of the n_ctx = 10 samples.
If you don’t calculate the loss on the pad tokens and just calculate a simple per-token loss, then in the above example (assuming your batch is evenly comprised of both tasks) the loss function receives 60% of its information from the n_ctx=15 case and only 40% of its information from the n_ctx=10 case. But ideally we want to assign these tasks even weight. So really each batch sample should have a mean over relevant tokens in the n_ctx direction taken before finally taking a mean over the batch to get a scalar loss.
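A sketch of that weighting scheme (per-token cross-entropy, masked over trailing PAD positions, averaged within each sample and then across the batch):

```python
import torch
import torch.nn.functional as F

def masked_ce_loss(logits: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # logits: [batch, n_ctx, d_vocab]; labels, mask: [batch, n_ctx],
    # where mask is 1 for tokens that should count and 0 for trailing PADs.
    per_token = F.cross_entropy(
        logits.flatten(0, 1), labels.flatten(), reduction="none"
    ).view(labels.shape)
    per_sample = (per_token * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)
    return per_sample.mean()  # each sample (and hence each task) gets equal weight
```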
For now, we decided to use a fixed n_ctx = 15 for all of our models to avoid these issues, but it’s something to consider in the future.
Future Research Directions
We had hoped to get to the next step with this project this summer (actually training SAEs on these benchmarks and looking for polysemanticity and features), but setting up and then training these benchmarks ended up being a monumental task! Plus, SAE architectures are developing so quickly (gated, top-k, jump-relu) that we decided that in the small amount of time we had, it would be best to focus on making the framework for models to test SAEs (and hopefully making that framework not-completely-incomprehensible to others!).
All that said, the obvious next step is to train some SAEs on these benchmarks and see if there are interpretable SAE features that correspond to the (known) high-level tasks that are being trained into the nodes of these transformers. If those tasks are being represented as features in SAEs, how are they being represented? Is there a single feature for each task, as we might naively expect? Multiple? If you’re interested in doing something like this, here’s a (probably deprecated) notebook we wrote that trains SAEs using SAELens and these wrapper functions, and it should be a pretty quick project to get a few SAEs trained.
Beyond using these benchmarks as a testbed for SAEs, we think there are a lot of really interesting questions about polysemanticity that can be studied using SIIT and the framework we’ve developed here. Specifically, if this framework works like we hope, it allows us to define all of the operations that occur in a given node in a neural network, which allows us to train models that let us ask questions of the form:
Can a single attention head do multiple completely unrelated tasks? If so, how many completely unrelated tasks can it do?
Can we finetune pretrained language models to enforce and crystallize a known circuit? For example, can we take the gpt2-small IOI circuit and use SIIT or another technique to make it so that ~100% (rather than just most) of the computation is done in the “known” circuit?
We’d also be interested in research that interrogates whether IIT and SIIT actually do what we want them to do[5]. What would be the robust questions to ask to determine if the “unused” nodes in the transformer are actually not contributing to the answer in a meaningful way? Is there a better way to train known circuits into transformers in a less synthetic way than e.g., Tracr compilation? And, what are better ways to construct our high-level polysemantic benchmark other than the stuff that we’ve done to get things working this summer?
Lots of great questions – wish we had time to explore them all. If you’re interested in any of them, please feel free to reach out and dig in!
Acknowledgments
Evan is grateful to lots of people for helping him get his footing in the AI safety field; in particular: Adam Jermyn for being a great friend and mentor, Joseph Bloom for his mentorship, kindness, and great advice, Eoin Farrell and Aaron Scher for lots of great conversations and collaboration, Jason Hoelscher-Obermaier and Clement Neo for their mentorship during an Apart Lab sprint, and Adrià Garriga-Alonso for his mentorship during MATS. Evan’s also grateful to many more people he’s met in the AIS community over the course of the past year for some really excellent conversations – there are too many of you to list, but thanks for taking the time to chat! Evan was both a MATS scholar and KITP postdoctoral scholar during the time that this work was completed, so this work was partially supported by NSF grant PHY-2309135 and Simons Foundation grant (216179, LB) to the Kavli Institute for Theoretical Physics (KITP) and by MATS.
Code & Model Availability
Code used to define high level models/datasets, and notebooks for training and loading trained models are available in this git repository. Trained models are available online in this huggingface repo (see the git repo for a notebook that loads them).
Appendices
Custom High Level Models
Parentheses Balancer
This algorithmic task was inspired by this post. The task: Given a string that consists of open “(“ and closed “)” parentheses, determine if the string is balanced: that is, every open parenthesis is eventually closed, and no closed parenthesis appear when there is no open parenthesis that they correspond to. This can be broken down into a “horizon” test and an “elevation” test: first map “(“ → +1 and “)” → −1. To get the elevation, take the cumulative sum of these values over your string. The horizon test is passed if the string stays “above the horizon” (the current elevation is never less than zero). The elevation test is passed if the elevation is zero at the end of the string.
This task is a bit tricky to train into a transformer, because a good dataset isn’t trivial to create. A random string of “(“ and “)” sampled uniformly will almost always be an imbalanced string, and a transformer can get ~99% accuracy just by always saying that the string is not balanced. To get around this, we created custom training datasets that consisted of:
25% balanced strings
25% strings that pass the horizon test but fail the elevation test
25% strings that pass the elevation test but fail the horizon test
25% strings that fail both tests
Note that we distinguish between different types of failures to ensure that interchange interventions would have a higher probability of changing the downstream output than they would for randomly generated failures (which would mostly fail both tests).
In order to make training this task occur on a token-by-token basis (like with the other tasks below), we evaluate whether the string so far is balanced (or not) at each token position.
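In plain Python (not the high-level model code itself), the per-token labeling rule is just:

```python
def balanced_prefix_labels(s: str) -> list[bool]:
    # Label each position: is the string *so far* balanced?
    elevation, above_horizon, labels = 0, True, []
    for ch in s:
        elevation += 1 if ch == "(" else -1               # "(" -> +1, ")" -> -1
        above_horizon = above_horizon and elevation >= 0  # horizon test
        labels.append(above_horizon and elevation == 0)   # elevation test
    return labels

balanced_prefix_labels("(())")  # [False, False, False, True]
```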
A mapping between the high level model we designed and where we train them into a low level model is shown in the figure below. Orange arrows show where high level model nodes map into the low level model.
Left Greater
The task: Given a string that consists of open “(“ and closed “)” parentheses, determine if there are more left parentheses than right parentheses. This is a simple task, and a dataset of uniformly sampled “(“ and “)” tokens can train this task fine.
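A plain-Python sketch of the per-token labeling rule (assuming, as with the balancer above, that the label is computed for each prefix of the string):

```python
def left_greater_labels(s: str) -> list[bool]:
    # Label each position: does the string so far contain strictly more "(" than ")"?
    balance, labels = 0, []
    for ch in s:
        balance += 1 if ch == "(" else -1
        labels.append(balance > 0)
    return labels

left_greater_labels("()(")  # [True, False, True]
```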
A mapping between the high level model we designed and where we train them into a low level model is shown in the figure below. Orange arrows show where high level model nodes map into the low level model.
Duplicate Remover
This is inspired by case 19 in circuits-benchmark. The task: Given a string that consists of “a”, “b”, and “c” tokens, remove any instances of duplicated tokens, so “a a b c c a” becomes “a PAD b c PAD a”. This is a simple task, and a dataset consisting of randomly-sampled tokens is fine.
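A plain-Python sketch of the task:

```python
def remove_duplicates(tokens: list[str], pad: str = "PAD") -> list[str]:
    # Replace any token that repeats its immediate predecessor with PAD.
    return [
        pad if i > 0 and tok == tokens[i - 1] else tok
        for i, tok in enumerate(tokens)
    ]

remove_duplicates("a a b c c a".split())  # ['a', 'PAD', 'b', 'c', 'PAD', 'a']
```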
A mapping between the high level model we designed and where we train them into a low level model is shown in the figure below. Orange arrows show where high level model nodes map into the low level model.
Unique Extractor
This is inspired by case 21 in circuits-benchmark. The task: Given a string that consists of “a”, “b”, and “c” tokens, remove any instances of tokens that have previously appeared, so “a a b c c a b c” becomes “a PAD b c PAD PAD PAD PAD”. This is a simple task, and a dataset consisting of randomly-sampled tokens is fine.
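A plain-Python sketch of the task:

```python
def extract_unique(tokens: list[str], pad: str = "PAD") -> list[str]:
    # Keep the first occurrence of each token; replace later occurrences with PAD.
    seen: set[str] = set()
    out = []
    for tok in tokens:
        out.append(pad if tok in seen else tok)
        seen.add(tok)
    return out

extract_unique("a a b c c a b c".split())
# ['a', 'PAD', 'b', 'c', 'PAD', 'PAD', 'PAD', 'PAD']
```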
A mapping between the high level model we designed and where we train them into a low level model is shown in the figure below. Orange arrows show where high level model nodes map into the low level model.
Post-training accuracies
Below is a table quoting various accuracies achieved by the pretrained models in this huggingface repo. These accuracies are calculated on the validation set used during model training, which is 2000 samples with n_ctx = 15, times the number of circuits in the model (2000 samples for a monosemantic model, 4000 samples for a polysemantic model with two circuits, etc.).
Footnotes
[2] Upon reviewing these tasks, some of them do not work as intended, and some of them are pure memorization; we found their cases 3, 4, 18, 19, and 21 to be the most robust of their currently public models.
[4] I referred to “interchange interventions” as “activation patching” before I came into MATS; they’re roughly the same, and if you want an intro to how to use a technique like this, the IOI ARENA notebook is great.
[5] Maybe there’s a bit of a chicken-and-egg problem here. What if the SAEs trained on our benchmarks don’t produce features with the expected high-level interpretation? Does that mean that the SAEs are bad, or are the SAEs actually good and they’re telling us SIIT isn’t working as intended? How do we disentangle those outcomes?
Crafting Polysemantic Transformer Benchmarks with Known Circuits
Notes:
This research was performed as part of Adrià Garriga-Alonso’s MATS 6.0 stream.
If an opinion is stated in this post saying that “we” hold the opinion, assume it’s Evan’s opinion (Adrià is taking a well-deserved vacation at the time of writing).
Evan won’t be able to continue working on this research direction, because he’s going to be offline before starting a new job at Anthropic in September! In that light, please view this post as something between a final research writeup, a peek into a lab notebook of some experiments we tried, and a pedagogical piece explaining some of the areas where Evan got stuck and had to dig in and learn things during MATS. See the end of the post for some thoughts we have about what promising future work would be in this area.
If you’re excited about this research agenda, we highly recommend that you get in touch with Adrià, or apply to work with him in a future MATS stream. We also recommend reaching out to Rohan Gupta and Iván Arcuschin, whose MATS 5.0 InterpBench work this builds on.
Summary
We briefly a few concepts in interpretability, especially: polysemanticity and techniques for testing interpretability hypotheses, with a specific focus on interchange interventions.
We discuss Interchange Intervention Training (IIT) and its Strict Interchange Intervention Training variant (SIIT), and propose some modest changes to training with these techniques that we discovered achieve better performance.
We use SIIT to train benchmark models where multiple known circuits all use the same MLP to perform different tasks. Ideally, the models we created should have polysemantic neurons within their MLPs. If not, we hope the framework we created, when properly scaled up, should allow users to create semi-realistic models with polysemantic neurons, known circuits, and training via gradient descent.
We discuss future exciting research directions, including using the benchmarks trained here as a test of Sparse Autoencoders (SAEs), and using this training technique to explore the different kinds of tasks attention heads can perform in superposition.
Contributions
We improved SIIT training; see our description below and some contributions from Evan which are publicly available in this repository (after the PR is sorted).
We defined a few crafted-by-hand benchmark cases, including robust dataset generators and correspondence maps of how those cases can be broken down and mapped into pieces that a transformer is capable of computing with MLPS and attention heads. See these cases here and graphical descriptions of them in the Appendices at the end of the post.
We defined a procedure for combining multiple hand-crafted circuits into a single model, which should be easy for users to expand with new benchmark cases. Code here.
We publicly release our trained (hopefully polysemantic) benchmark models here, and code for loading in and training those models here.
Introduction and Context
It’s currently very difficult for AI researchers to tell if an AI model is really aligned or just a good liar. We think that techniques that check for consistency between a model’s outputs and “thought processes”(e.g., probes catching sleeper agents, Hubinger 2022, Nanda 2022, Roger 2023) will be essential for catching misaligned models[1] before deployment, especially as models become more capable. Mechanistic interpretability (MI) is a developing field aiming to explain model internals (see Bereska+Gavves 2024 for a recent review). Optimistically, MI techniques could allow us to outline and characterize specific pathways (“circuits”) through models related to misaligned behavior like deception or sycophancy and then determine if those circuits are active. Robust circuits have been characterized in simplified transformers (e.g., induction and modular addition), and some have been found in small language models (e.g., IOI in gpt-2-small Wang+2022). Pessimistically, even if the current state-of-the-art techniques in MI cannot be advanced and just scale, they find safety-relevant model “features” (Templeton+2024) and linear probes could be used to flag when those features are active, hopefully correlating to when models are “thinking” in a misaligned way.
In order to determine if an AI model is lying or not by looking at the model activations, we need:
A tool for breaking down model activations into interpretable elements,
for leading labs to adopt those tools, and
to verify that those tools can find and alert us of undesirable model behavior.
Sparse Autoencoders (SAEs) effectively break down model activations into interpretable elements, and leading labs are training SAEs on frontier models like GPT-4 and Claude-3 Sonnet. The interpretability community has fully embraced SAEs over the past year (see the lesswrong Sparse Autoencoder tag). A lot of the thrust of SAE research over the past year has been on training and distributing better SAEs (gated, top-k, jump-relu, switch, etc.). We would love to see more work focused on the validity of SAEs (e.g., is the linear representation hypothesis right?, are there better metrics than the standard ones reported?, what’s going on with the SAE error term?) which could help determine how much better we can do at explaining model internals just by scaling SAEs.
To this end, over the past few weeks, we’ve done some work to create a framework for training transformer models which we hope can be used to test interpretability techniques for disentangling polysemantic activation spaces (like SAEs). The broad idea is that we define simple tasks that transformers are definitely capable of implementing as a circuit (checking if a string of parenthesis is balanced, for example), make a map of how that circuit could map onto a transformer’s components (attention heads, MLPs), and then train those circuits into a transformer. The key contribution here is that we train models with overlapping circuits, so that single MLPs are used in multiple circuits, hopefully leading to polysemantic neurons. We’ve decided not to have circuits overlap in attention heads, thought see the future work section for some ways we think this could be done in some interesting ways.
Our hope is that we can create a benchmark that is “realistic enough” (moreso than toy models) to be useful for testing SAEs, but still has known circuits. We hope that this will supplement currently occurring work that’s using SAEs to find and evaluate known circuits in pretrained language models (e.g., the IOI circuit in gpt-2-small and various circuits in pythia-70m).
Related Work
Polysemanticity and Superposition: These terms are not the same, as nicely explained in this post. Polysemanticity, or neuron polysemanticity, is the phenomenon that neurons fire in response to unrelated stimuli (a classic example is cat faces and the fronts of cars in Olah+ 2020). Superposition occurs when a network linearly represents more networks than it has dimensions / neurons (like the classic 5-features-in-2D of Elhage+2022). The superposition hypothesis is an extension of the linear representation hypothesis (e.g., Park+ 2023) stating that polysemantic neurons arise because features are stored as linear directions in activation space, but there are more features than there are dimensions, so a single neuron activates due to projections from many features. Our goal in this work is to create models with known circuits that exhibit neuron polysemanticity.
Training a model with known, sparse circuits is not an easy task. Lindner+ 2023 developed Tracr, a technique for compiling RASP programs into simple, human-readable decoder-only transformer models with monosemantic neurons. Geiger+ 2022 introduced the Interchange Intervention Training (IIT) technique, whose goal is to induce known causal structure (known circuits) into machine learning models. In short, IIT aims to ensure that operations in a known high-level computation graph directly map onto specific nodes of a low level model. Gupta+ 2024 found that models trained with IIT could use unintended nodes in their computations and extended this technique to Strict Interchange Intervention Training (SIIT). See the “Training with SIIT” section below for a walkthrough of IIT, SIIT, and their loss function. We use the framework of SIIT to train our benchmarks in this work.
Toy models. Toy models have been used to understand how models store features in superposition (Elhage+ 2022) and to study how memorization works on small input datasets (Henighan+ 2023). Toy models have also been used as a test bed for engineering monosemantic neurons in neural networks (Jermyn+ 2022). Scherlis+ 2022 introduces the concept of capacity as a measurement of polysemanticity. All of these works seek to learn fundamental truths about how neural networks behave in the hopes that these truths can be applied to more complex models like transformer language models, but their simplicity makes it difficult to determine how universal the truths discovered by studying them are. Our hope is that the benchmarks developed here (and by extension by works like Gupta+ 2024) help bridge the gap between toy models and real models by providing an intermediate test bed with some of the complexity of real models and some of the simplicity and ease of study of toy models.
Benchmark models. Thurnherr and Scheurer 2024 recently released a large set of Tracr models which can be used as benchmarks of interpretability techniques. As noted by Gupta+ 2024, the weights learned by a Tracr model are synthetic and have a very different distribution from naturally-occurring weights in models trained by gradient descent. Models trained by IIT and SIIT have weights with similar distributions to classically pretrained models. They released 17 benchmark models which are trained on the IOI task as well as tasks that are simple enough for Tracr to compile but trained using SIIT[2]. Reviewers of that paper noted that particularly promising future directions included examining SAEs and polysemanticity, the latter of which we explore here.
Connecting language model circuits with high level explanations
High & Low Level Models
Circuits are pathways through models that compute meaningful and interpretable algorithms (Olah+ 2020). To describe a circuit in a machine learning model such as a pretrained language model, we need to be identify specific tasks or computations that the model performs, and we need to identify precisely where those computations occur in the transformer. We will call the transformer the “low level model”, and we will create a “high level model” which is a hypothesis for how a circuit works in the low level model. Note that high level models are really very simple: they just describe a valid method of how a task or an algorithm can be performed.
As an example, imagine the task “add together three numbers,” and we’ll call the numbers ‘a’, ‘b’, and ‘c’, and the result ‘d’, so d = a + b + c. This operation can be performed in a few different ways, all of which return the same result. You could have an addition operation that takes two inputs and returns one output, and use that addition operation twice, or you could have an addition operation that takes three inputs and returns one output and then use that operation once. Both are valid ways to do addition, and so we can define multiple different valid “high level” models of adding together three numbers:
However, if we trained a transformer to perform this addition task, these would not all be equally valid descriptions of what the transformer does. Training via gradient descent picks one of these implementations (or a different one that we haven’t imagined here – models are very creative). Given a trained model and examining its outputs, all we know a priori is that it’s capable of doing 3-digit addition, so we can speculate that there’s a 3-digit addition circuit inside of that model. Our tasks as interpretability researchers are to:
Determine which pieces of the model are used for computation.
Determine what task each of those pieces performs.
This information will allow you to invalidate incorrect high level models, and you’ll be left with one high level model which is a true description of the algorithm implemented by the low level model.
I want to briefly mention that in practice, language models and transformers don’t make this easy for us. There’s not always just one node in a low level model that performs a specific task. In the gpt2-small IOI circuit (Wang+ 2022), for example, multiple attention heads double up to perform the same task (see fig. 2, e.g., heads 2.2 and 4.11 are both previous token heads). Furthermore, even if you find a circuit like the IOI circuit, it typically doesn’t explain 100% of the IOI task being performed by the model! And, more troubling, models can exhibit complex phenomena like backup behavior (e.g., Rushing and Nanda 2024, where model components activate when another component is intervened with) or interpretability illusions (e.g., Bolukbasi+ 2021, Makelov+ 2023).
Interchange Interventions
Many approaches have been developed to test interpretability hypotheses of the form “does this high level model describe what my transformer is doing?” we won’t go through them in detail here, but see this blog post for a nice introduction[3]. We also want to quickly note that whether or not a high level model is a good description depends upon how the hypothesis is tested, so even a good description could be a bad description according to a different approach.
All this being said, some forms of testing hypotheses–including the ones we’ll use in this work–use interchange interventions[4] . In an interchange intervention, you take two different inputs x1 and x2. You run a forward pass on x2 and cache activations from its forward pass. Then you start running a forward pass on x1, but at one or more nodes in the model you replace the activations with the cached activations from x2 before finishing the forward pass. This corrupts the output of the forward pass of x1. If you have a high level model that perfectly explains the circuit, you should be able to do exactly the same interchange intervention on the high level model (forward pass x2, then forward pass x1 while replacing nodes in the high level model that correspond to replaced nodes in the transformer). If your high level output and low level output under this intervention are the same, then the node that was intervened on in the high level model is a good description of what’s happening in the low level model.
See below for an idealized graphic walkthrough of interchange interventions for the d = a + b + c example we described above. See below the image for more exposition.
In the top panels we show one high level model of three-digit addition; in the bottom panels we show a low level model (a 2-layer transformer with one attention head per layer, and we’re assuming all of the action happens in the attention heads and not in the MLP here). The left panels show the results of standard forward passes that produce the expected output. In the right panels, we show a forward pass for x1 that has undergone an interchange intervention using activations from the x2 forward pass.
In creating this image, I’ve assumed that the high level model is a good description of the transformer circuit, so by intervening on the same node in the high level and low level model, we get the same downstream effects and output. If the high level model were not a good description of the low level circuit, then the low-level model would have output something other than 12, and that would invalidate our high level hypothesis.
Training with SIIT
Strict Interchange Intervention Training (SIIT) trains models to minimize a loss function with three terms. Given a dataset of input-label pairs (x, y), a high-level model MHL, and given two samples from that dataset (x1, y1) and (x2, y2), the loss on those samples is calculated as:
$$L_{\text{total}} = d_{\text{baseline}}\, L_{\text{baseline}}(x_1, y_1) + d_{\text{IIT}}\, L_{\text{IIT}}(x_1, M_{\text{HL}}(x_1); x_2) + d_{\text{SIIT}}\, L_{\text{SIIT}}(x_1, y_1; x_2).$$
Each of the loss terms performs the same type of calculation (they are all a cross-entropy loss for a categorical task or an MSE loss for a regression task). In words, these loss terms measure:
$L_{\text{baseline}}(x_1, y_1)$: the ability of the model to output y1 when given x1 as an input (standard training term, nothing fancy).
$L_{\text{IIT}}(x_1, M_{\text{HL}}(x_1); x_2)$: the ability of the model to produce the same output as the high level model $M_{\text{HL}}$ when both the model and $M_{\text{HL}}$ undergo a forward pass of x1 with an interchange intervention using cached activations from a forward pass of x2.
$L_{\text{SIIT}}(x_1, y_1; x_2)$: the ability of the model to output y1 when performing a forward pass with x1 with an interchange intervention (from an x2 forward pass) on a low-level node that has no function in the high level model.
$d_{\text{baseline}}$, $d_{\text{IIT}}$, and $d_{\text{SIIT}}$ are weights that determine the relative importance of these terms.
For a more detailed description of these loss terms, we refer the reader to section 3 of the IIT paper (Geiger+ 2022) and section 3 of the SIIT paper (Gupta+ 2024).
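As a rough sketch of how these terms combine for a categorical task: the helper run_with_interchange below (which runs a forward pass of x1 while patching in x2’s cached activations at the given nodes) and the node lists are hypothetical stand-ins for the correspondence machinery, not our actual training code.

```python
import torch.nn.functional as F

d_baseline, d_iit, d_siit = 1.0, 1.0, 0.4  # relative weights (see "Other small changes" below)

def total_loss(ll_model, hl_model, x1, y1, x2, hl_nodes, ll_nodes, unused_nodes):
    # Baseline term: ordinary supervised loss on x1.
    logits = ll_model(x1)
    loss_baseline = F.cross_entropy(logits.flatten(0, 1), y1.flatten())

    # IIT term: intervene on a high-level node and its corresponding low-level
    # nodes; the high-level model's intervened output becomes the label.
    hl_labels = run_with_interchange(hl_model, x1, x2, hl_nodes)
    ll_logits = run_with_interchange(ll_model, x1, x2, ll_nodes)
    loss_iit = F.cross_entropy(ll_logits.flatten(0, 1), hl_labels.flatten())

    # SIIT term: intervene on low-level nodes *outside* the circuit; the clean
    # label y1 should still be recovered.
    strict_logits = run_with_interchange(ll_model, x1, x2, unused_nodes)
    loss_siit = F.cross_entropy(strict_logits.flatten(0, 1), y1.flatten())

    return d_baseline * loss_baseline + d_iit * loss_iit + d_siit * loss_siit
```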
In training, we try to produce models that achieve 100% accuracy under three metrics:
Baseline Accuracy: The model can produce the proper output, y, given an input, x.
Interchange Intervention Accuracy (IIA), described in section 3 of Geiger+ 2022. The model can produce the same output as the high-level model under an $L_{\text{IIT}}$-style intervention given a pair of inputs x1 and x2.
Strict Interchange Intervention Accuracy (SIIA), described briefly in section 4 of Gupta+ 2024. The model can produce the proper output, y1, given an input, x1, and given an $L_{\text{SIIT}}$-style interchange using the input x2.
Modifications to SIIT training in this work
We’ve made a few changes to the SIIT training presented in Gupta+ 2024 and available in their IIT repo, laid out below.
One vs Three steps per training batch (Adam beta parameters):
Algorithm 1 of Gupta+ 2024 updates model weights three times per batch of training data. Backpropagation is performed and a gradient step is taken after calculating each of the three terms of the loss function (L_baseline, L_IIT, L_SIIT). They found worse performance when taking a single step to update the weights using L_total. I was able to replicate this finding when using the default Adam optimization hyperparameters. We found this troubling: rather than finding a minimum in a single loss landscape, the model was iteratively finding different minima in three different loss landscapes.
The fact that a three-step loss calculation produced better results than a single-step loss calculation suggests that the terms in the Adam optimizer with “memory” (the first-moment/momentum \beta_1 term and the second-moment \beta_2 term, which adapts the per-parameter learning rate) optimize models better via SIIT when they get a “cool down” between steps on the individual loss terms. To help the optimizer forget past iterations more quickly, we can therefore decrease these \beta values (their default values are \beta_1 = 0.9 and \beta_2 = 0.999) and recover model convergence while using a single loss function and a single parameter update per training batch.
I find that decreasing \beta_2 from 0.999 to 0.9 while keeping the default \beta_1 = 0.9 allows me to achieve model convergence using L_total.
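In code, the single-step version looks something like the following sketch, reusing the hypothetical total_loss helper from the sketch above; the dataloader, models, and node lists are assumed to exist:

```python
import torch

optimizer = torch.optim.Adam(ll_model.parameters(), lr=1e-3, betas=(0.9, 0.9))

for x1, y1, x2 in train_loader:  # assumed dataloader yielding paired samples
    optimizer.zero_grad()
    loss = total_loss(ll_model, hl_model, x1, y1, x2,
                      hl_nodes, ll_nodes, unused_nodes)
    loss.backward()
    optimizer.step()             # a single parameter update per training batch
```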
Changes to SIIT node sampling
In Gupta+ 2024, L_SIIT is calculated by sampling a single node (an attention head or MLP) that isn’t used in the high level model, then performing an interchange intervention on that single node. This gets closer to strictly enforcing the high / low level mapping IIT tries to create, but we speculate that this could have flaws related to backup behavior.
Imagine you have a model with three attention heads and one MLP. One attention head and the MLP are explicitly used in the high level model and undergo IIT-style interventions. The other two attention heads are not used in the circuit, so SIIT works to penalize those nodes. If only one node is sampled each training step, then the two nodes outside of the circuit “take turns” being active or being intervened on. We speculate that the model could find a way to use those heads to help it achieve lower loss (perhaps they both perform the same task as the node in the circuit, and the downstream MLP looks for consistency between the inputs and their outputs and throws out the results of the one intervened-on head).
To combat the ability of models to learn this backup-like behavior through SIIT, we instead sample a subset of the unused nodes for each SIIT loss evaluation: each unused node is included in the intervention (or not) with a probability of 50%. So in the above example, sometimes the model will have access to both extra attention heads, sometimes it’ll have access to one or the other, and sometimes it’ll have access to neither. Our hope is that by intervening on random subsets of the unused nodes in the SIIT loss, we improve the robustness of the SIIT training method.
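A minimal sketch of this sampling scheme (the node objects are whatever the correspondence uses to identify attention heads and MLPs):

```python
import random

def sample_siit_nodes(unused_nodes, p=0.5):
    # Each node outside the circuit is independently intervened on with
    # probability p; occasionally this returns the empty set, which is fine.
    return [node for node in unused_nodes if random.random() < p]
```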
Other small changes
Gupta+ 2024 used a fixed learning rate of $10^{-3}$, whereas we adopt a linearly decreasing learning rate schedule. We use an initial learning rate of $10^{-3}$, then linearly decrease to a learning rate of $2 \times 10^{-4}$ over the course of training. We still finish training early if IIA and SIIA of 100% are achieved on the validation dataset.
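A sketch of this schedule using PyTorch’s built-in LinearLR; the number of training steps and the optimizer setup here are illustrative:

```python
import torch

num_train_steps = 10_000  # illustrative; whatever the training run uses
optimizer = torch.optim.Adam(ll_model.parameters(), lr=1e-3, betas=(0.9, 0.9))
# Decay the learning rate linearly from 1e-3 down to 2e-4 over training.
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1.0, end_factor=0.2, total_iters=num_train_steps
)
# Call scheduler.step() after each optimizer.step() to advance the schedule.
```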
We also use different weights in the loss term. Gupta+ 2024 mostly used {d_baseline = 1, d_IIT = 1, d_SIIT = 0.4}, and we use {d_baseline = 0.4, d_IIT = 1, d_SIIT = 0.4}. Our thinking is that the SIIT term and the baseline term in the loss function are working towards the same task (having the model output the right answer), but they pull in different directions: the baseline term tries to use the whole model to get better loss overall, while the SIIT term works to get lower loss while discarding the contributions of nodes outside the circuit. Rather than overweighting the baseline, we choose to weight these terms equally. The IIT term, on the other hand, is the term that aims to produce consistency with the operations in the high level model, so we leave its magnitude alone. We choose to have d_baseline + d_SIIT < d_IIT because matching the high-level model is the hardest training task, and we want to weight it a bit more heavily.
Finally, when evaluating if the model achieves 100% IIA on the validation data, we loop over all of the nodes in the correspondence between the high- and low-level models, compute IIA for when each individual node is interchanged, then take the mean result over all nodes. Previously, IIA was only calculated for a single node each validation batch.
Building Polysemantic Benchmark Transformers
Our goal is to create transformer models that perform known circuits and have polysemantic neurons. In theory, we could force the models to have polysemantic neurons by mapping high-level computations onto subspaces of low-level MLPs where those subspaces all project onto a single (or a few) neurons. However, we want polysemanticity to arise naturally as a result of training via SGD, so we choose not to constrain the activation spaces in that manner in this work.
The overall process that we use to create models which we hope exhibit polysemantic neurons (we unfortunately didn’t have the chance to test this) is laid out below. In general, it looks like this:
Construct a high-level model that can perform a simple task.
Train a circuit into a transformer that performs the same task as the high-level model, to ensure that the task is properly defined in terms of units that a transformer can perform.
Construct another high-level model that simply performs several of these validated high-level tasks.
Train all of the circuits required for the multi-piece-model into a transformer using SIIT. Importantly, the circuits should overlap within nodes of the transformer (multiple circuits should run through the same MLP).
Step 1: Designing monosemantic benchmarks
We design four benchmarks which each perform one specific algorithmic task. See the appendices for details of these four high level tasks and how we map them onto low-level transformers for training with SIIT. Two of these tasks are simple parenthesis-based tasks (inspired by this post) and two of these tasks were benchmarks in the MATS 5.0 circuits benchmark, but we have constructed by-hand circuits for them which differ from the Tracr circuits that were previously trained.
We break each algorithmic task down into successive computations that can be performed by the nonlinearities present in either a single attention head or a single MLP layer. We design a “correspondence” which links each node in the high level model with a single attention head or MLP in the low level model. Then we train the low level transformers using the training techniques described above in the “Training with SIIT” section, ensuring that each high level model can indeed be trained into a transformer using SIIT.
While we only studied four benchmark cases here, we’d be excited about having a larger number of more complex cases added to the benchmark. Superposition occurs when there are more features than the model has capacity to represent and when those features are sufficiently sparse; to meet the first condition we either need to have a large number of benchmarks in a model or a small d_model (though the former is more interesting!).
Step 2: Designing polysemantic benchmarks
Phew, we’re through with the (many) preliminaries! Now we can describe the meat of the interesting (or at least novel) work we did here.
Recall that our goal was to train a transformer model which has multiple known circuits running through it. Importantly, the circuits overlap, so e.g., multiple circuits run through and use a single MLP. To achieve this in the framework of IIT/SIIT, we need to construct a high level model that itself is a combination of multiple single high-level benchmarks. We could do this by hand and painstakingly construct a high level model which performs some set of N tasks, but then this method wouldn’t scale and it wouldn’t be very interesting. So: we needed to come up with ways to:
Create a dataset that is a valid combination of datasets for individual tasks.
Create a high level model which is a combination of some set of simple models that the user specifies.
Naturally, both of these come with difficulties.
Combining multiple datasets
Assume we have two datasets which have been curated for two simple algorithmic benchmark tasks.
Dataset 1 is for a parenthesis task and has a vocabulary { “BOS“ : 0, “PAD” : 1, “(” : 2, “)” : 3 }.
Dataset 2 is for a task that looks for repeating tokens and has a vocabulary { “BOS” : 0, “PAD” : 1, “a” : 2, “b” : 3, “c” : 4}.
Furthermore, let’s assume that the context length of samples in Dataset 1 is 10 while the context length of samples in Dataset 2 is 15.
A merged dataset for a polysemantic benchmark can be created by combining these two datasets. Each dataset entry needs a flag indicating which dataset (and task) it belongs to, and all entries need to share the same context length. We define the context length of the merged dataset to be the max context length of the individual datasets (15 in the case above) plus 1, so n_ctx for the merged dataset is 16. To make all samples fit into this context, we perform the following operations (a minimal sketch follows the list):
Add a ‘task ID’ token to the beginning of each dataset sample. So: all entries from Dataset 1 will have “0” prepended, and all entries from Dataset 2 will have “1” prepended.
Pad the end of the samples: so, all entries from Dataset 1 will have 5 “PAD” / “1” tokens added to the end.
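Here is the minimal sketch referenced above. The token IDs follow the vocabularies listed earlier; the function name and tensor layout are illustrative, not our exact implementation:

```python
import torch

PAD = 1  # the PAD token id shared by both vocabularies above

def merge_datasets(ds1_tokens, ds2_tokens):
    """ds1_tokens, ds2_tokens: integer tensors of shape [n_samples, n_ctx_i]."""
    n_ctx = max(ds1_tokens.shape[1], ds2_tokens.shape[1]) + 1  # +1 for task ID
    merged = []
    for task_id, ds in enumerate([ds1_tokens, ds2_tokens]):
        task_col = torch.full((ds.shape[0], 1), task_id)
        padding = torch.full((ds.shape[0], n_ctx - 1 - ds.shape[1]), PAD)
        # Prepend the task-ID token and pad out to the shared context length.
        merged.append(torch.cat([task_col, ds, padding], dim=1))
    return torch.cat(merged, dim=0)
```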
As far as dataset creation goes – that’s it! Not too bad. But once this dataset starts interacting with the high-level model, there are more challenges; see Difficulty 1 below.
Combining multiple individual tasks into one model
We do all of our work with IIT using TransformerLens’ HookedRootModule interface. This interface is great, because it lets us do interchange interventions really easily! Unfortunately, the restrictions of this interface mean that we have a somewhat inefficient means of combining multiple high level models into one, which we’ll describe below.
To initialize a multi-benchmark model, we pass a list of individual benchmarks, each of which contains a HookedTransformerConfig (describing what size of transformer it should be trained into) and a Correspondence object (defining which TransformerLens HookPoints in the high level model and low level model map onto each other). The combined model builds its own config by taking the max() of d_model, n_layers, etc. over all of the low level models. It builds its own correspondence by looping through each attention head and MLP that the eventual low-level model will have and noting which hooks in each of its high level models that node will be responsible for. The model also assigns an unused attention head to the task of reading the task ID for each batch entry.
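As a sketch of the config-merging step (merging the correspondences is more involved; the attribute choices below are illustrative and assume each benchmark exposes a HookedTransformerConfig):

```python
from transformer_lens import HookedTransformerConfig

def merge_configs(configs):
    # The combined low-level model must be at least as big as every individual
    # benchmark along each structural dimension.
    return HookedTransformerConfig(
        n_layers=max(c.n_layers for c in configs),
        d_model=max(c.d_model for c in configs),
        n_heads=max(c.n_heads for c in configs),
        d_head=max(c.d_head for c in configs),
        d_mlp=max(c.d_mlp for c in configs),
        d_vocab=max(c.d_vocab for c in configs),
        n_ctx=max(c.n_ctx for c in configs) + 1,  # +1 for the task-ID token
        act_fn="relu",  # illustrative; whatever the benchmarks use
    )
```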
After instantiation, the model is ready to perform a forward pass using a batch of inputs from a combined dataset like we described above. The forward pass follows this algorithm:
Split the task id (first item of each batch entry) from the rest of the input and store both.
Loop through each high level benchmark
Modify the inputs, if needed, so that they can be directly fed into each high-level benchmark. See “Difficulty 1” below. Store these modified inputs.
Loop through each high level benchmark with its modified inputs:
Run a forward pass, caching all activations during this forward pass.
Map cached activations from all of the high level models to nodes in the low level model: loop through all of the attention heads and MLPs that this model will map onto.
Using the correspondence created during instantiation, gather all the data in the cache that corresponds to operations that will be computed in this node. Stack all of this data together into a single torch Tensor.
Pass that stacked data into a HookPoint for ease of interchange interventions.
Unpack results of hooked data back into the caches for each HL model.
Loop through each high level benchmark with modified inputs to calculate the output with interventions:
Add hooks for each HookPoint in the high level model and replace forward pass activations with those from the cache.
Run a forward pass with hooks.
Put model output (probabilities) into the appropriate output indices using a boolean mask created from the task IDs (sketched below).
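A sketch of that final recombination step, with illustrative names and shapes:

```python
import torch

def combine_outputs(per_task_outputs, task_ids, d_vocab_out):
    # per_task_outputs[i]: [batch, n_ctx, d_vocab_out] output of high-level task i,
    # computed on the full (modified) batch; task_ids: [batch] of integer task IDs.
    batch, n_ctx, _ = per_task_outputs[0].shape
    combined = torch.zeros(batch, n_ctx, d_vocab_out)
    for task_id, out in enumerate(per_task_outputs):
        mask = task_ids == task_id      # boolean mask over the batch dimension
        combined[mask] = out[mask]
    return combined
```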
The above describes how to create the high level model. The low level model is created in exactly the same way as we would for a single-circuit benchmark case: Xavier-initialize a transformer and train it with SIIT. The accuracies of the various low level models that we trained are reported at the end of this post in an appendix.
I want to acknowledge that this is messy, and this procedure is the result of us overcoming challenges that we faced along the way in the first and simplest way that occurred to us. I’d love to see someone come up with a cleaner version of this! To help explain why we landed on this procedure, I want to call attention to a couple of difficulties that we ran into below.
Difficulty 1: Interchanges outside of the task vocabulary
One of the biggest difficulties is that two batch entries can cause problems under an interchange of the task ID. Consider two samples from the two tasks with different vocabularies described above in the “Combining multiple datasets” section. If, for example, the task-ID activation of a Dataset 2 sample such as [1, “a”, “c”, “b”, ...] = [1, 2, 4, 3, ...] is replaced with that of a Dataset 1 sample, then task 0’s high-level model is asked to process token 4. This is a problem! Token 4 doesn’t exist in task 0’s vocabulary! There are a lot of ways to create a high-level model and dataset that handle this problem:
Always have our batches have the same number of task 1 samples and task 2 samples, and only do interchanges between like samples (unfortunately, this would make it impossible to train a specific node in the low-level model to do the task-id operation, because there would be no interchange information).
Only include tasks with the same d_vocab (restricts the kinds of benchmarks you can build)
Replace out-of-vocab tokens with in-vocab tokens.
We’ve gone with the third option. For now, we’re replacing tokens that are outside of the task vocabulary with randomly sampled tokens from within the task vocab (and we do not include PAD and BOS in the random sampling, so in the above example the ‘4’ would be replaced with a ‘2’ or a ‘3’, representing a ‘(‘ or a ‘)’). A sketch of this replacement is below.
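A sketch of that replacement; the token IDs follow the vocabularies above, and everything else is illustrative:

```python
import torch

BOS, PAD = 0, 1

def replace_out_of_vocab(tokens, d_vocab_task):
    # Candidate replacement tokens: everything in the target task's vocab
    # except the special BOS and PAD tokens.
    candidates = torch.tensor(
        [t for t in range(d_vocab_task) if t not in (BOS, PAD)]
    )
    random_fill = candidates[torch.randint(len(candidates), tokens.shape)]
    return torch.where(tokens >= d_vocab_task, random_fill, tokens)
```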
Difficulty 2: Calculating the proper loss.
Some of the more robust benchmarks in circuits-benchmark are regression tasks where the transformer outputs a float value and is optimized with a mean-squared-error loss; other tasks are optimized with a cross-entropy loss but are not autoregressive. We generally found that we had to group benchmark cases into three categories:
MSE-optimized
CE-optimized but not autoregressive
CE-optimized and autoregressive
We chose to train benchmarks that are CE-optimized and autoregressive to create benchmarks that are closer to being like real language models, but this choice limited the cases that we found in the literature that we were able to reuse in this work.
Furthermore, even after deciding that we wanted to calculate a per-token CE loss, we still ran into some struggles regarding how much each token entry should be weighted:
If you have one task with n_ctx = 10 and another with n_ctx = 15, should you calculate the loss on every token?
We think no. You shouldn’t calculate the loss on the pad tokens at the end of the n_ctx = 10 samples.
If you don’t calculate the loss on the pad tokens and just calculate a simple per-token loss, then in the above example (assuming your batch is evenly comprised of both tasks) the loss function receives 60% of its information from the n_ctx = 15 case and only 40% from the n_ctx = 10 case. But ideally we want to assign these tasks even weight. So each batch sample should first be averaged over its relevant (non-pad) tokens in the n_ctx direction, and only then should we average over the batch to get a scalar loss.
For now, we decided to use a fixed n_ctx = 15 for all of our models to avoid these issues, but it’s something to consider in the future.
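For reference, if one did want to mix context lengths, a masked per-sample mean along the lines of the following sketch (with the PAD id from the vocabularies above) would implement the equal weighting described above:

```python
import torch
import torch.nn.functional as F

PAD = 1

def masked_ce_loss(logits, labels):
    """logits: [batch, n_ctx, d_vocab]; labels: [batch, n_ctx] of token ids."""
    per_token = F.cross_entropy(
        logits.flatten(0, 1), labels.flatten(), reduction="none"
    ).view(labels.shape)
    mask = (labels != PAD).float()
    # Average over each sample's non-PAD tokens first, then over the batch, so
    # short-context and long-context tasks contribute equally.
    per_sample = (per_token * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    return per_sample.mean()
```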
Future Research Directions
We had hoped to get to the next step with this project this summer (actually training SAEs on these benchmarks and looking for polysemanticity and features), but setting up and then training these benchmarks ended up being a monumental task! Plus, SAE architectures are developing so quickly (gated, top-k, jump-relu) that we decided that in the small amount of time we had, it would be best to focus on making the framework for models to test SAEs (and hopefully making that framework not-completely-incomprehensible to others!).
All that said, the obvious next step is to train some SAEs on these benchmarks and see if there are interpretable SAE features that correspond to the (known) high-level tasks that are trained into the nodes of these transformers. If those tasks are being represented as features in SAEs, how are they represented? Is there a single feature for each task, as we might expect? Multiple? If you’re interested in doing something like this, here’s a (probably deprecated) notebook we wrote that trains SAEs using SAELens and these wrapper functions, and it should be a pretty quick project to get a few SAEs trained.
Beyond using these benchmarks as a testbed for SAEs, we think there are a lot of really interesting questions about polysemanticity that can be studied using SIIT and the framework we’ve developed here. Specifically, if this framework works like we hope, it lets us define all of the operations that occur in a given node of a neural network, which means we can train models that let us ask questions of the form:
How many different tasks can be trained into a single MLP? Presumably MLPs can store exp(d_mlp) features, but are features and tasks in one-to-one correspondence?
Can a single attention head do multiple completely unrelated tasks? If so, how many completely unrelated tasks can it do?
Can we finetune pretrained language models to enforce and crystallize a known circuit? For example, can we take the gpt2-small IOI circuit and use SIIT or another technique to make it so that ~100% (rather than just most) of the computation is done in the “known” circuit?
We’d also be interested in research that interrogates whether IIT and SIIT actually do what we want them to do[5]. What would be the robust questions to ask to determine if the “unused” nodes in the transformer are actually not contributing to the answer in a meaningful way? Is there a better way to train known circuits into transformers in a less synthetic way than e.g., Tracr compilation? And, what are better ways to construct our high-level polysemantic benchmark other than the stuff that we’ve done to get things working this summer?
Lots of great questions – wish we had time to explore them all. If you’re interested in any of them, please feel free to reach out and dig in!
Acknowledgments
Evan is grateful to lots of people for helping him get his footing in the AI safety field; in particular: Adam Jermyn for being a great friend and mentor, Joseph Bloom for his mentorship, kindness, and great advice, Eoin Farrell and Aaron Scher for lots of great conversations and collaboration, Jason Hoelscher-Obermaier and Clement Neo for their mentorship during an Apart Lab sprint, and Adrià Garriga-Alonso for his mentorship during MATS. Evan’s also grateful to many more people he’s met in the AIS community over the course of the past year for some really excellent conversations – there are too many of you to list, but thanks for taking the time to chat! Evan was both a MATS scholar and KITP postdoctoral scholar during the time that this work was completed, so this work was partially supported by NSF grant PHY-2309135 and Simons Foundation grant (216179, LB) to the Kavli Institute for Theoretical Physics (KITP) and by MATS.
Code & Model Availability
Code used to define high level models/datasets, and notebooks for training and loading trained models are available in this git repository. Trained models are available online in this huggingface repo (see the git repo for a notebook that loads them).
Appendices
Custom High Level Models
Parentheses Balancer
This algorithmic task was inspired by this post. The task: Given a string that consists of open “(“ and closed “)” parentheses, determine if the string is balanced: that is, every open parenthesis is eventually closed, and no closing parenthesis appears without a corresponding open parenthesis before it. This can be broken down into a “horizon” test and an “elevation” test: first map “(“ → +1 and “)” → −1. To get the elevation, take the cumulative sum of these values over your string. The horizon test is passed if the string stays “above the horizon” (the current elevation is never less than zero). The elevation test is passed if the elevation is zero at the end of the string.
This task is a bit tricky to train into a transformer, because a good dataset isn’t trivial to create. A random string of “(“ and “)” sampled uniformly will almost always be imbalanced, and a transformer can get ~99% accuracy just by always saying that the string is not balanced. To get around this, we created custom training datasets that consisted of:
25% balanced strings
25% strings that pass the horizon test but fail the elevation test
25% strings that pass the elevation test but fail the horizon test
25% strings that fail both tests
Note that we distinguish between different types of failures to ensure that interchange interventions would have a higher probability of changing the downstream output than they would for randomly generated failures (which would mostly fail both tests).
In order to make training this task occur on a token-by-token basis (like with the other tasks below), we evaluate whether the string so far is balanced (or not) at each token position.
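A sketch of the per-token horizon and elevation tests, using the parenthesis token IDs from the vocabulary above and assuming the positions passed in contain only parenthesis tokens:

```python
import torch

def balanced_so_far(paren_tokens):
    """paren_tokens: [batch, n_ctx] with 2 = "(" and 3 = ")"."""
    signs = (paren_tokens == 2).long() * 2 - 1        # "(" -> +1, ")" -> -1
    elevation = signs.cumsum(dim=1)                   # running nesting depth
    horizon_ok = elevation.cummin(dim=1).values >= 0  # never dipped below zero
    elevation_ok = elevation == 0                     # depth is zero right now
    return horizon_ok & elevation_ok                  # per-position balance label
```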
A mapping between the high level model we designed and where we train its nodes into the low level model is shown in the figure below. Orange arrows show where high level model nodes map into the low level model.
Left Greater
The task: Given a string that consists of open “(“ and closed “)” parentheses, determine if there are more left parentheses than right parentheses. This is a simple task, and a dataset consisting of randomly-sampled 0s and 1s can train this task fine.
A mapping between the high level model we designed and where we train its nodes into the low level model is shown in the figure below. Orange arrows show where high level model nodes map into the low level model.
Duplicate Remover
This is inspired by case 19 in circuits-benchmark. The task: Given a string that consists of “a”, “b”, and “c” tokens, remove any instances of duplicated tokens, so “a a b c c a” becomes “a PAD b c PAD a”. This is a simple task, and a dataset consisting of randomly-sampled tokens is fine.
A mapping between the high level model we designed and where we train its nodes into the low level model is shown in the figure below. Orange arrows show where high level model nodes map into the low level model.
Unique Extractor
This is inspired by case 21 in circuits-benchmark. The task: Given a string that consists of “a”, “b”, and “c” tokens, remove any instances of tokens that have previously appeared, so “a a b c c a b c” becomes “a PAD b c PAD PAD PAD PAD”. This is a simple task, and a dataset consisting of randomly-sampled tokens is fine.
A mapping between the high level model we designed and where we train its nodes into the low level model is shown in the figure below. Orange arrows show where high level model nodes map into the low level model.
Post-training accuracies
Below is a table quoting various accuracies achieved by the pretrained models in this huggingface repo. These accuracies are calculated on the validation set used during model training, which is 2000 samples with n_ctx = 15, times the number of circuits in the model (2000 samples for a monosemantic model, 4000 samples for a polysemantic model with two circuits, etc.).
[1] Specifically, models that “play the training game” well; see e.g., Cotra 2022, Ngo 2022 for some narratives of such models.
[2] Upon reviewing these tasks, some of them do not work as intended, and some of them are pure memorization; we found their cases 3, 4, 18, 19, and 21 to be the most robust of their currently public models.
[3] There are other techniques that aren’t covered in that post, e.g., Distributed Alignment Search and SAE-based circuit identification.
[4] I referred to “interchange interventions” as “activation patching” before I came into MATS; they’re roughly the same, and if you want an intro to how to use a technique like this, the IOI ARENA notebook is great.
[5] Maybe there’s a bit of a chicken-and-egg problem here. What if the SAEs trained on our benchmarks don’t produce features with the expected high-level interpretation? Does that mean that the SAEs are bad, or are the SAEs actually good and they’re telling us SIIT isn’t working as intended? How do we disentangle those outcomes?