Computational Superposition in a Toy Model of the U-AND Problem
Thanks to @Linda Linsefors and @Kola Ayonrinde for reviewing the draft.
tl;dr:
I built a toy model of the Universal-AND Problem described in Toward A Mathematical Framework for Computation in Superposition (CiS).
It successfully learnt a solution to the problem, providing evidence that such computational superposition[1] can occur in the wild. In this post, I’ll describe the circuits I’ve found that the model learns to use, and explain why they work.
The learned circuit is different from the construction given in CiS. The paper gives a sparse construction, while the learned model is dense—every neuron is useful for computing every possible output.
I sketch out how this works and why this is plausible and superior to the hypothesized construction.
This result has implications for other computational superposition problems. In particular, it provides a new answer for why XOR so commonly appears in linear representations.
The Universal-AND Problem
The Universal-AND problem considers a set of $m$ boolean inputs $b_1, \dots, b_m$ taking values in $\{0, 1\}$. The inputs are $s$-sparse, i.e. at most $s$ are active at once, $s \ll m$. The problem is to create a one-layer model that computes a vector $a$ of size $d$, called the neuron activations, such that it’s possible to read out every $b_i \wedge b_j$ using an appropriate linear transform.
The problem has elaborations where the inputs are stored in superposition in a smaller activation vector, or noise is introduced, but I have not directly experimented on these.
The idea of this problem is that when $d$ is much smaller than the number of input pairs, there is not enough compute or output bandwidth to compute every possible AND pair separately. But if we assume the inputs are sparse ($s \ll m$), then the model can take advantage of that sparsity to re-use the same model weights for unrelated calculations. This phenomenon is called Computational Superposition[1]. It is thought to occur in real-life models, by analogy with Activations in Superposition, where sparsity is leveraged to store many more pieces of data in a vector than could otherwise fit.
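I model this as:

$$a = \mathrm{ReLU}(Wb + c), \qquad \hat{y} = Ra + r$$

where $W \in \mathbb{R}^{d \times m}$, $c \in \mathbb{R}^{d}$, $R \in \mathbb{R}^{m^2 \times d}$ and $r \in \mathbb{R}^{m^2}$. In other words, $W$ and $c$ describe the “compute layer”, while $R$ and $r$ describe the “readout layer”.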
In this formulation, the output has dimension $m^2$, corresponding to ordered pairs $(i, j)$ selecting pairs of inputs. So the target output is

$$y_{ij} = b_i \wedge b_j = b_i b_j$$

The diagonals $y_{ii} = b_i$ are therefore unused and only exist to make indexing convenient.
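A minimal NumPy sketch of this setup may make the shapes concrete. It assumes the compute layer is $a = \mathrm{ReLU}(Wb + c)$ followed by an affine readout over all $m^2$ ordered pairs; the helper names (`sample_inputs`, `targets`, `forward`) and the sizes are hypothetical and not taken from the post's notebook.

```python
import numpy as np


def sample_inputs(batch, m, s, rng):
    """Boolean inputs of shape (batch, m) with exactly s active entries per row."""
    b = np.zeros((batch, m))
    # Pick s distinct active positions per row by ranking random keys.
    active = np.argsort(rng.random((batch, m)), axis=1)[:, :s]
    np.put_along_axis(b, active, 1.0, axis=1)
    return b


def targets(b):
    """y[:, i, j] = b_i AND b_j for every ordered pair; the diagonal is unused."""
    return b[:, :, None] * b[:, None, :]


def forward(b, W, c, R, r):
    """Compute layer a = ReLU(W b + c), then an affine readout for every ordered pair."""
    a = np.maximum(b @ W.T + c, 0.0)   # (batch, d) neuron activations
    y_hat = a @ R.T + r                # (batch, m*m), one output per ordered pair
    return a, y_hat.reshape(b.shape[0], b.shape[1], b.shape[1])


rng = np.random.default_rng(0)
m, d, s = 20, 64, 3  # illustrative sizes only
W = rng.normal(size=(d, m))
c = np.zeros(d)
R = rng.normal(size=(m * m, d))
r = np.zeros(m * m)

b = sample_inputs(8, m, s, rng)
a, y_hat = forward(b, W, c, R, r)
print(a.shape, y_hat.shape, targets(b).shape)  # (8, 64) (8, 20, 20) (8, 20, 20)
```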
Related Work
Toward A Mathematical Framework for Computation in Superposition supplies the original statement of the U-AND problem, and a sparse construction that solves it with provable bounds on the error term. But I find in practice the learned model differs from the sparse construction in some key ways.
Circuits in Superposition: Compressing many small neural networks into one also looks at a problem statement that forces computational superposition, and likewise uses a sparse construction to calculate theoretical bounds.
Superposition is not “just” neuron polysemanticity discusses the distinction between superposition and other possible reasons neurons might be polysemantic. The circuit found in this article is a good example of Example 1 - non-neuron aligned features.
What’s up with LLMs representing XORs of arbitrary features? and its comments go into some depth on linear readout of Boolean circuits as observed in the wild. I suggest an alternative explanation for these observations.
Training
I trained the above two-layer model on synthetic data where exactly $s$ randomly chosen inputs are active (i.e. take value 1). I used RMS loss.
I sample uniformly, so test cases where $b_i = b_j = 0$ are much more common than those where exactly one of the two is active, or where both are. I up-weighted the loss so that the expected contribution from each of those three cases is equal. This encourages the model to focus on the few active results rather than the vast sea of inactive results. We can justify this because the second layer is intended as a “readout” layer. For training purposes, we read out every possible pair every time, but if the first layer were actually embedded in part of a larger network, presumably not all pairs would actually be desired.
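One way to implement this up-weighting is sketched below, assuming each off-diagonal pair gets weight $1 / (3 \cdot P(\text{case}))$, where the case is how many of $b_i, b_j$ are active. The function names are hypothetical and the notebook's exact scheme may differ. A side effect worth noting: with exactly $s$ active inputs, every sample has the same number of pairs in each case, so the three cases balance exactly per sample, not just in expectation.

```python
import numpy as np


def case_probabilities(m, s):
    """P(an off-diagonal ordered pair has 0, 1 or 2 active inputs) when exactly
    s of the m inputs are active. Requires 2 <= s <= m - 2 (no zero probabilities)."""
    total = m * (m - 1)
    p_none = (m - s) * (m - s - 1) / total
    p_one = 2 * s * (m - s) / total
    p_both = s * (s - 1) / total
    return np.array([p_none, p_one, p_both])


def pair_loss_weights(b, s):
    """Weight 1 / (3 * P(case)) per pair so the three cases contribute equally.
    Diagonal pairs (i == j) are unused and get weight 0."""
    m = b.shape[1]
    case_weight = 1.0 / (3.0 * case_probabilities(m, s))
    n_active = (b[:, :, None] + b[:, None, :]).astype(int)  # values in {0, 1, 2}
    w = case_weight[n_active]
    idx = np.arange(m)
    w[:, idx, idx] = 0.0
    return w


def weighted_rms_loss(y_hat, b, s):
    """Square root of the weighted mean squared error over all ordered pairs."""
    y = b[:, :, None] * b[:, None, :]
    w = pair_loss_weights(b, s)
    return np.sqrt(np.sum(w * (y_hat - y) ** 2) / np.sum(w))


b = np.zeros((2, 10))
b[0, [0, 1, 2]] = 1.0
b[1, [3, 5, 9]] = 1.0
y = b[:, :, None] * b[:, None, :]
print(weighted_rms_loss(y, b, s=3))                 # 0.0 for perfect predictions
print(weighted_rms_loss(np.zeros_like(y), b, s=3))  # sqrt(1/3): the both-active third is wrong
```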
Some weight decay is used to regularize the network. This matches real-world training runs, and encourages the model to focus on the optimal circuits. I trained each model with 6000 epochs of 10k batches to encourage this.
I typically ran with , but experimented with other values too.
The full details can be found in the corresponding notebook.
Results
This section contains empirical results, and can be skipped if you just want the crisp formula I extract from the behaviour.
We find that at low values, the model does find solutions that are capable of solving the U-AND problem, even extending to extremely low values of . The model weights take on a simple pattern of binary weights described below.
At higher values of (starting at 10 for ), the model starts to prefer more degenerate solutions, particularly for . It either learns pure-additive circuits that roughly correspond to or it fails to update the layer one weights at all, and relies entirely on readout weights.
The Binary Weighted Circuit
The model generally tends towards neuron weights that are binary, i.e. each weight takes on only one of two different values. As far as I can tell, the choice between the two values is random, with a roughly even balance between them.
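Expressed mathematically:

$$W_{ki} = \begin{cases} w_{1,k} & \text{with probability } p \\ w_{0,k} & \text{otherwise} \end{cases}$$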
I discuss why this is an effective choice in Circuit Analysis.
Binary Weight Charts
We established that with sufficient overtraining, the model trends towards binary weights with just three parameters per neuron: $w_{0,k}$, $w_{1,k}$ and the bias $c_k$. That means we can draw a scatter plot with one point per neuron. Error bars show the 90th-percentile deviation of each weight from whichever of the two values (upper or lower) it is nearest to.
Readout Charts
Another way of viewing the neurons is in terms of how they are read out by the readout matrix.
I pick two arbitrary inputs (say $b_1$ and $b_2$), and plot each neuron in a scatter chart based on its weights for those inputs (i.e. $W_{k1}$, $W_{k2}$). I then color the neurons based on their readout weight for $b_1 \wedge b_2$ (i.e. $R_{(1,2),k}$).
The four corners of this chart correspond to classes A (top right), C (bottom left), B1 and B2 as described in Circuit Analysis. Class C has higher readout weights per neuron, as there are fewer neurons in that class than in class A.
Even in cases where the weights don’t form a Binary Weighted Circuit, these charts give a clear indication of how readout works.
You can tell this is a different regime / circuit, as there are no neurons with negative input weights but positive readout weights (yellow dots in the bottom left).
Variant Experiments
I also evaluated the same models with randomized compute-layer weights, optimizing only the readout matrix. This variant exhibits similar loss, just worse by a constant factor. This matches observations from comments on the CiS post that random weights already have a number of desirable properties.
I changed the synthetic input to a mix of sparsity values. In this variant, you no longer see such a neat binary split in the weights, but this is not that surprising, as different weight values will work best at different sparsity levels. I did not spot any surprising differences here.
I tried using a model that forces the weight matrix to match the binary weights pattern, with some smoothing to allow for gradients. This model has very few free parameters, so I expected it to train very well. But it converged more slowly than simply training the whole matrix, so I found nothing valuable here.
Analysis
Circuit Analysis
I found that, for a wide range of parameters, the model tends towards binary weights: for any given neuron, the weights it uses take on only two different values, with no discernible pattern as to which input gets which value. Expressed mathematically:

$$W_{ki} = \begin{cases} w_{1,k} & \text{with probability } p \\ w_{0,k} & \text{otherwise} \end{cases}$$

This is very close to the CiS Construction, which followed this exact pattern, with values $w_{0,k} = 0$ and $w_{1,k} = 1$.
In the CiS Construction $p$ is small and $w_{0,k}$ is zero, meaning the neurons were only sparsely connected. This property was key in proving an upper bound on the loss of the model.
But in the learned Binary Weighted Circuit we see quite different values. I found these values representative[2]:
Notation: When the values are a constant, I’ll drop the subscript.
Unlike the original construction, this is dense. Every neuron reads a significant value for every possible input.
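To make the sparse/dense contrast concrete, here is a small sketch that samples weight matrices following the pattern $W_{ki} = w_1$ with probability $p$, otherwise $w_0$. The CiS-style values ($w_0 = 0$, $w_1 = 1$, small $p$) follow the construction; the dense values ($p = 0.5$, $w_0 = -0.2$, $w_1 = 0.1$) are illustrative assumptions, not the fitted values from my runs.

```python
import numpy as np


def binary_weights(d, m, p, w0, w1, rng):
    """Each entry is independently w1 with probability p, otherwise w0."""
    return np.where(rng.random((d, m)) < p, w1, w0)


rng = np.random.default_rng(0)
d, m = 200, 500

# CiS-style construction: sparse 0/1 weights (this p is an illustrative value).
W_cis = binary_weights(d, m, p=0.05, w0=0.0, w1=1.0, rng=rng)

# Dense binary-weighted variant; these values are assumptions, not fitted ones.
W_dense = binary_weights(d, m, p=0.5, w0=-0.2, w1=0.1, rng=rng)

print("CiS nonzero fraction:", np.mean(W_cis != 0))      # close to 0.05
print("dense nonzero fraction:", np.mean(W_dense != 0))  # exactly 1.0
```

In the dense variant every neuron reads every input, which is what "dense" means here.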
Let’s explain exactly how neurons in this structure can be used to read out a good approximation of AND. Without loss of generality, let’s consider the first two inputs, i.e. we want to read out $b_1 \wedge b_2$ from the activations of neurons that all share the same values of $w_0$/$w_1$/$c$, and differ only by the random choice in the weights.
We can subdivide the neurons into 4 classes based on the random choice of weight for the first two inputs.
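The class assignment can be sketched as below, assuming binary weights with a single high value `w1` shared across neurons. With connection probability $p$ for the high weight, the expected class fractions are $p^2$ (A), $p(1-p)$ (B1 and B2) and $(1-p)^2$ (C). The helper `neuron_classes` and the values used are hypothetical; $p = 0.6$ is chosen only to show that, when the high weight is the more common one, class C ends up smaller than class A.

```python
import numpy as np


def neuron_classes(W, i, j, w_high):
    """Label each neuron by its weights on inputs i and j:
    A = (high, high), B1 = (high, low), B2 = (low, high), C = (low, low)."""
    hi_i = W[:, i] == w_high
    hi_j = W[:, j] == w_high
    labels = np.empty(W.shape[0], dtype="<U2")
    labels[hi_i & hi_j] = "A"
    labels[hi_i & ~hi_j] = "B1"
    labels[~hi_i & hi_j] = "B2"
    labels[~hi_i & ~hi_j] = "C"
    return labels


rng = np.random.default_rng(1)
d, m, p = 4000, 50, 0.6     # illustrative values
w0, w1 = -0.2, 0.1          # assumed weight values
W = np.where(rng.random((d, m)) < p, w1, w0)

labels = neuron_classes(W, 0, 1, w_high=w1)
expected = {"A": p * p, "B1": p * (1 - p), "B2": (1 - p) * p, "C": (1 - p) ** 2}
for name, frac in expected.items():
    print(name, np.mean(labels == name), "expected", round(frac, 2))
```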
For simplicity of analysis, let’s ignore the other inputs shown as (called the interference terms in the CiS paper), and set , though you’ll see these details won’t be too critical.
We can then draw up a truth table for the results of each class of neuron for the 4 possible values of $(b_1, b_2)$.

| $b_1$ | $b_2$ | A | B1 | B2 | C | $b_1 \wedge b_2$ |
|---|---|---|---|---|---|---|
| 0 | 0 | 0.05 | 0.05 | 0.05 | 0.05 | 0 |
| 0 | 1 | 0.15 | 0 | 0.15 | 0 | 0 |
| 1 | 0 | 0.15 | 0.15 | 0 | 0 | 0 |
| 1 | 1 | 0.25 | 0 | 0 | 0 | 1 |
Taking the right linear combination of the 4 truth tables, we can recreate the AND truth table. This linear combination is a close match for the ones seen in Readout Charts.
Taking a linear combination will always be possible if the 4 classes are linearly independent, which is true for quite a wide range of choices for the key parameters[3].
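The linear-combination step can be checked numerically. The sketch below rebuilds the truth table from a bias of 0.05 and a high weight of 0.1 (both read off the table), plus a low weight of −0.2, which is my assumption: any low weight at or below −0.15 reproduces the same table. It then confirms that the four class columns are linearly independent and solves for the total readout weight per class that yields AND.

```python
import numpy as np

bias, w_hi, w_lo = 0.05, 0.1, -0.2  # w_lo is an assumption; any value <= -0.15 works
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)
class_weights = {  # weights on (b1, b2) for each class
    "A": (w_hi, w_hi),
    "B1": (w_hi, w_lo),
    "B2": (w_lo, w_hi),
    "C": (w_lo, w_lo),
}

# Column k holds class k's activation for each of the 4 input rows (00, 01, 10, 11).
M = np.column_stack([np.maximum(inputs @ np.array(w) + bias, 0.0)
                     for w in class_weights.values()])
print(np.round(M, 2))

assert np.linalg.matrix_rank(M) == 4  # the 4 classes are linearly independent

and_target = np.array([0.0, 0.0, 0.0, 1.0])
readout = np.linalg.solve(M, and_target)  # total readout weight per class
print({k: round(float(v), 3) for k, v in zip(class_weights, readout)})
# {'A': 4.0, 'B1': -4.0, 'B2': -4.0, 'C': 4.0}
```

These are totals per class; in the full model each total would be spread across all neurons in that class.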
Of course, the interference term means that these truth tables are not accurate. But in this toy model, the different inputs are almost completely independent, so the interference term can be modeled as noise. This will pull the 4 classes towards collinearity, but for an appropriate bias there will always be a difference that can be exploited.
As there are many neurons in each class, the readout matrix can average over them, reducing the noise to reasonable levels. This is similar to the proofs of noise bounds in CiS, which I discuss a bit further in Circuit Efficiency.
XOR Circuits
Note that linear combinations of these classes can recreate any desired truth table[4]; we only got AND because of our choice of readout matrix. This may go some way towards explaining the observation that it is often possible to read out XOR.
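The "any truth table" claim can be sketched by enumerating all 16 two-input truth tables against the same four class columns. The parameter values are the same assumptions as for the AND example (bias 0.05, high weight 0.1, assumed low weight −0.2); since the class matrix is invertible, every table has an exact readout, and XOR is simply one of them.

```python
import itertools
import numpy as np

bias, w_hi, w_lo = 0.05, 0.1, -0.2  # assumed values, consistent with the truth table
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)
pairs = [(w_hi, w_hi), (w_hi, w_lo), (w_lo, w_hi), (w_lo, w_lo)]  # A, B1, B2, C
M = np.column_stack([np.maximum(inputs @ np.array(w) + bias, 0.0) for w in pairs])

# Each 2-input truth table is a target vector over the rows (00, 01, 10, 11).
solutions = {}
for table in itertools.product([0.0, 1.0], repeat=4):
    target = np.array(table)
    solutions[table] = np.linalg.solve(M, target)
    assert np.allclose(M @ solutions[table], target)  # exact readout exists

xor = solutions[(0.0, 1.0, 1.0, 0.0)]
print(len(solutions), "truth tables reachable; XOR class weights:", np.round(xor, 2))
```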
I explain the XOR readout phenomenon in two parts. Firstly, it seems $b_i \oplus b_j$ will be linearly readable whenever there is a demand for boolean logic involving $b_i$ and $b_j$, regardless of whether XOR is specifically useful or used. This is because the circuit for computing any truth table is the same, differing only in the readout.
But secondly, $b_i$ and $b_j$ might not even need to be related logically. The circuit described above involves all neurons, as that allows errors to average out. In the toy example, we actually are testing all possible pairs. But if we were only interested in pairs drawn from a subset, the model would still want to use all the neurons to maximize error averaging. So if there were two disjoint AND problems that shared the same model, the model might find it efficient to learn one big binary weighted circuit that covers inputs from both problems. It would only read out what is actually needed, but it would be possible to find readouts that connect both sets of inputs.
Generalizability of Linear Probes
Sam Marks makes the argument that if XOR feature directions are readily available, then a linear probe for a feature $a$ trained only on data where $b = 0$ will fail to generalize to $b = 1$, because it’s just as likely to identify $a \oplus b$ as it is $a$; the two agree on the training data but give opposite answers out of distribution.
For the parameters I used above, we find the following readout weights (for the total of each class):
So although both formulas are learnable, XOR requires larger weights, and so will be disfavoured by regularization[5]. I therefore predict probes would still generalize, despite XOR being readable.
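A rough illustration of the size comparison, using the same assumed class parameters as the truth-table example (bias 0.05, high weight 0.1, low weight −0.2). Taking $a = b_1$ and $b = b_2$, the probe targets $b_1$ and $b_1 \oplus b_2$ agree whenever $b_2 = 0$, so they are exactly the two candidates in the argument above. This only shows the direction of the effect for one parameter choice; as the footnote notes, the ordering can reverse for some choices.

```python
import numpy as np

bias, w_hi, w_lo = 0.05, 0.1, -0.2  # assumed values, consistent with the truth table
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)
pairs = [(w_hi, w_hi), (w_hi, w_lo), (w_lo, w_hi), (w_lo, w_lo)]  # A, B1, B2, C
M = np.column_stack([np.maximum(inputs @ np.array(w) + bias, 0.0) for w in pairs])

formulas = {  # targets over the rows (00, 01, 10, 11)
    "b1": [0, 0, 1, 1],
    "b1 XOR b2": [0, 1, 1, 0],
    "b1 AND b2": [0, 0, 0, 1],
}
norms = {}
for name, table in formulas.items():
    weights = np.linalg.solve(M, np.array(table, dtype=float))
    norms[name] = np.linalg.norm(weights)  # L2 size of the total class readout weights
    print(f"{name:10s} weights={np.round(weights, 2)} norm={norms[name]:.2f}")
```

With equal class sizes ($p = 0.5$), spreading each class total evenly over its neurons gives a per-neuron L2 penalty proportional to the squared norm computed here, so the comparison carries over to weight decay on individual neurons.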
Circuit Efficiency
So why is this dense Binary Weighted Circuit preferred over the CiS Construction? Possibly it’s just a natural effect of weight decay, which tends to encourage circuits to spread out to keep individual weights low in magnitude.
I think that this circuit produces lower loss values (absolutely and asymptotically). I’ll give an approximating argument.
To recap, the CiS Construction fills the weight matrix with zeros and ones, with each entry being one with probability $p$, while the Binary Weighted Circuit fills the weight matrix with $w_0$ and $w_1$, with and .
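Let’s return to the view of a neuron $k$ from the previous section, where we have two distinguished inputs, $b_1$ and $b_2$ w.l.o.g.:

$$a_k = \mathrm{ReLU}\left(W_{k1} b_1 + W_{k2} b_2 + \epsilon_k + c_k\right)$$

Here I’ve explicitly labelled the interference term, and $c_k$ is the neuron’s bias:

$$\epsilon_k = \sum_{i \geq 3} W_{ki} b_i$$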
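Most of the terms in the interference sum don’t contribute, as most $b_i$ will be zero. There will be at most $s$ nonzero terms, each randomly contributing one of the two weight values depending on the random choice of $W_{ki}$. So we can approximate the interference term as binomially distributed[6].
We can treat the interference terms of different neurons as independent, as the randomness comes from the independent random choices of weights.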
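A quick Monte Carlo check of the binomial approximation, under the simplifying assumptions used here: one fixed input with `s_other` active inputs besides the two distinguished ones, and each weight independently equal to `w1` with probability `p`, otherwise `w0`. If $K$ counts the `w1` draws, then $K \sim \mathrm{Binomial}(s, p)$ and $\epsilon = s\,w_0 + (w_1 - w_0)K$, so $\operatorname{Var}(\epsilon) = s\,p(1-p)(w_1 - w_0)^2$. The parameter values are illustrative, not the trained ones.

```python
import numpy as np


def interference(n_neurons, s_other, p, w0, w1, rng):
    """Interference term for one fixed input, across many independently drawn neurons.
    Only active inputs contribute, so only their weights need to be sampled."""
    w_active = np.where(rng.random((n_neurons, s_other)) < p, w1, w0)
    return w_active.sum(axis=1)


rng = np.random.default_rng(0)
s_other, p, w0, w1 = 5, 0.5, -0.2, 0.1  # illustrative values
eps = interference(200_000, s_other, p, w0, w1, rng)

# K ~ Binomial(s_other, p) counts w1 draws, so eps = s_other * w0 + (w1 - w0) * K.
predicted_mean = s_other * (p * w1 + (1 - p) * w0)
predicted_var = s_other * p * (1 - p) * (w1 - w0) ** 2
print(eps.mean(), predicted_mean)
print(eps.var(), predicted_var)
```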
It’s hard to characterize exactly how the ReLU function interacts with , but I’m going to assume it affects variance by a constant factor, , giving:
Now we can use this estimate to compute the variance of the readout values.
In the CiS Construction, the readout is the mean value of all neurons that are in class A (i.e. have a weight of 1 for both $b_1$ and $b_2$). There are $p^2 d$ such neurons (in expectation), so we can divide by that to get the variance of the mean.
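Spelling out that averaging step, assuming the class A activations $a_k$ are independent with a common variance $\operatorname{Var}(a_k)$ and that $|A| \approx p^2 d$:

$$\operatorname{Var}\left(\frac{1}{|A|}\sum_{k \in A} a_k\right) = \frac{1}{|A|^2}\sum_{k \in A}\operatorname{Var}(a_k) = \frac{\operatorname{Var}(a_k)}{|A|} \approx \frac{\operatorname{Var}(a_k)}{p^2 d}$$

So, for a fixed per-neuron variance, doubling $d$ halves the variance of the readout.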
Meanwhile in the Binary Weighted Circuit, we pick readout weights based on the inverse of the truth table matrix. It can be shown that the total weight for each class is approximately [7].
So we have class A neurons, each with readout weight . This gives total variance
Both variance formulas are pretty similar in terms of constants and asymptotic behaviour, which is unsurprising as so far we’ve just evaluated different choices of and . The dense case is worse by a factor of , and includes some extra terms that were previously ignored because is small in the CiS Construction.
To compare these properly, we need to consider the different values of . For the CiS Construction, this was asymptotically small , while in the Binary Weighted Circuit was a constant .
When we work out the asymptotic behaviour we get
So the choice of circuit depends on , with the Binary Weighted circuit preferred when grows slower than .
In retrospect, I think this result makes sense. By increasing , you are “using” more neurons. This increases the variance of individual neurons but gives you many more to average over. determines that per-neuron noise, so determines the trade-off. The CiS Construction merely aimed to minimize the interference term of a single neuron, as this was critical for establishing provable error bounds.
Additionally, aside from accuracy considerations, dense circuits are more likely to result from training than equivalent sparse ones as they score better for regularization loss. Or in other words, adding weight decay adds an inductive bias towards “spreading out” circuits as much as possible.
Is This Just Feature Superposition?
In a sense, the circuit found here bears a lot of resemblance to simply randomly embedding the inputs in $d$-dimensional space. In both cases, you get a mix of neurons with different response patterns, and you can approximate the AND operation by taking a dense linear combination of the neurons.
I re-ran my code with frozen weights for the calculation layer to test this, and you do get a similar pattern of readouts.
The main difference between a randomized model and the learned one was that the learned weights take on specific binary values. But this is unlikely to replicate outside of toy models. Using binary weights simplifies analysis and improves the loss, but does not appear to be better than random uniform values except by a constant factor.
Viewed in this paradigm, the key takeaways of the results section are slightly different:
* Mixing / superposing input variables is useful for computation, even when there is no pressure to store activations in a narrow dimension.
* If you randomly embed sparse inputs, sum the embeddings and apply a ReLU, then any boolean logic can be read out with the appropriate linear combination.
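A small sketch of the second takeaway, in the spirit of the frozen-weights re-run: the compute layer is random and never trained, and only a linear readout (with bias) is fitted by least squares, for AND, OR and XOR of the same two inputs. The sizes and seed are hypothetical. With only 28 possible inputs and 64 neurons an exact fit is expected, so this is a sanity check of the readout mechanics rather than a test of the heavily compressed regime.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
m, d = 8, 64                    # illustrative sizes only
W = rng.normal(size=(d, m))     # frozen random compute weights (never trained)
c = rng.normal(size=d)          # frozen random biases

# Every input with exactly two active features (28 patterns for m = 8).
patterns = np.array([[1.0 if k in pair else 0.0 for k in range(m)]
                     for pair in itertools.combinations(range(m), 2)])
features = np.maximum(patterns @ W.T + c, 0.0)               # ReLU(W b + c)
design = np.hstack([features, np.ones((len(patterns), 1))])  # readout bias column

b0, b1 = patterns[:, 0], patterns[:, 1]
truth_tables = {"AND": b0 * b1, "OR": np.maximum(b0, b1), "XOR": (b0 + b1) % 2}
for name, y in truth_tables.items():
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)  # fit only the linear readout
    max_err = np.abs(design @ coef - y).max()
    print(f"{name}: max readout error {max_err:.1e}")
```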
Limitations and Future Work
Toy models by their nature are not very representative of real situations. That said, I think there are some ways this model could be expanded:
* Explore inputs in superposition, or at least some input noise.
* Explore a loss that uses only a subset of all AND pairings, or that weights input co-occurrence unevenly.
* Is RMS a realistic way to score this?
I would also like to see the model run for a sufficient range of parameters so that scaling laws can be established. Does the loss actually correspond to the variance formula I derive? What is the tradeoff of sparsity and number of neurons?
I’ve established that for this simple problem, it is better to use a dense circuit than a sparse one, even though the input features are sparse. That naturally suggests that there may be dense constructions for other problems, like sparse boolean circuits or circuit compression. There may also be implications for interpretability techniques that depend on assumptions of circuit sparsity to work.
Finally, I think a logical step would be to see if a similar sort of circuitry can be found in a real LLM. There’s indirect evidence, but no solid example.
Conclusion
This work provides some empirical evidence for Toward A Mathematical Framework for Computation in Superposition. It’s clear that many boolean logic equations can be computed from one set of weights, even for small dimension sizes.
We have solidified the understanding of how basic boolean logic is performed: specifically, that despite sparse features, there is a tendency for dense circuitry to be used. The asymptotic error is characterized in terms of sparsity and hidden dimension size.
We’ve discussed applicability to XOR representations and linear probes.
Dense circuits like these challenge the common assumption that circuits can be found by finding a sparse subset of connections inside a larger model. If it’s better to have many shared noisy calculations than a smaller set of isolated reliable ones, then a different set of techniques is needed to detect them.
- ^
I draw a distinction between a model’s ability to work directly on activations in superposition (computation in superposition) and the ability to share weights between multiple overlapping circuits (computational superposition).
Though closely related, I think these can be investigated separately.
- ^
As the readout matrix can supply an arbitrary positive scaling, the only important details of / are their signs and ratio.
The magnitude tends to be around as this minimizes regularization loss (weight decay) in a two layer network.
- ^
@jake_mendel uses a very similar argument in this comment.
- ^
The design also extends to truth tables of more than two inputs without much difficulty.
- ^
Actually when using norm there are some choices of where this result reverses. But the point still holds that one of the two possible probes will be favoured.
- ^
This approximation elides the correlation between and , , which for high values of can be important, but I don’t think it affects the circuits examined significantly.
- ^
This comes from having variance proportional to , but the neuron classes only differ from each other by a constant translation of at most . As increases, the classes’ ReLU zero-points are at increasingly similar points on the probability distribution.
I also verified this empirically.
- ^
Even at I found no evidence of the model learning a naive 1:1 neuron:output strategy. I believe that the presence of weight decay makes the binary weighted circuit preferable even though it is a bit more noisy.
Forgot to tell you this when you showed me the draft: The comp in sup paper actually had a dense construction for UAND included already. It works differently than the one you seem to have found though, using Gaussian weights rather than binary weights.
Yes, I don’t think the exact distribution of weights (Gaussian/uniform/binary) really makes that much difference; you can see the difference in loss in some of the charts above. The extra efficiency probably comes from the fact that every neuron contributes to everything fully, whereas with Gaussian weights some of them will be close to zero.
Some other advantages:
* But they are somewhat easier to analyse than Gaussian weights.
* They can be skewed (p≠0.5) which seems advantageous for an unknown reason. Possibly it makes AND circuits better at the expense of other possible truth tables.