(I haven’t had the chance to read part 3 in detail, and I also haven’t checked the proofs except insofar as they seem reasonable on first viewing. Will probably have a lot more thoughts after I’ve had more time to digest.)
This is very cool work! I like the choice of U-AND task, which seems way more amenable to theoretical study (and is also a much more interesting task) than the absolute value task studied in Anthropic’s Toy Models of Superposition (hereafter TMS). It’s also nice to study this toy task with asymptotic theoretical analysis as opposed to the standard empirical analysis, thereby allowing you to use a different set of tools than usual.
The most interesting part of the results was the discussion on the universality of universal calculation—it reminds me of the interpretations of the lottery ticket hypothesis that claim some parts of the network happen to be randomly initialized to have useful features at the start.
Some examples that are likely to be boolean-interpretable are bigram-finding circuits and induction heads. However, it’s possible that most computations are continuous rather than boolean[31].
My guess is that most computations are indeed closer to continuous than to boolean. While it’s possible to construct boolean interpretations of bigram circuits or induction heads, my impression (having not looked at either in detail on real models) is that neither of these cleanly occurs inside LMs. For example, induction heads demonstrate a wide variety of other behavior, and even on induction-like tasks they often seem to be implementing induction heuristics that involve some degree of semantic content.
Consequently, I’d be especially interested in exploring either the universality of universal calculation, or the extension to arithmetic circuits (or other continuous/more continuous models of computation in superposition).
Some nitpicks:
The post would probably be a lot more readable if it were chunked into 4 parts. The 88-minute read time is pretty scary, and I’d like to be able to comment on just the parts I’ve read.
Section 2:
Two reasons why this loss function might be principled are
If there is reason to think of the model as a Gaussian probability model
If we would like our loss function to be basis independent
A big reason to use MSE as opposed to eps-accuracy in the Anthropic model is for optimization purposes (you can’t gradient descent cleanly through eps-accuracy).
Section 5:
This should be labeled as section 5.
Appendix to the Appendix:
(TeX compilation failure)
Also, here’s a summary I posted in my lab notes:
A few researchers (at Apollo, Cadenza, and IHES) posted this document today (22k words, LW says ~88 minutes).
They propose two toy models of computation in superposition.
First, they posit an MLP setting where a single-layer MLP is used to compute the pairwise ANDs of m boolean input variables up to epsilon-accuracy, where the input is sparse (in the sense that only l < m of them are active at once). Notably, in this setup, instead of using O(m^2) neurons to represent each pair of inputs, you can use O(polylog(m)) neurons, each connected to a random subset of the inputs, and “read off” the ANDs by adding together all neurons that contain the pair of inputs. They also show that you can extend this to cases where the inputs themselves are in superposition, though you need O(sqrt(m)) neurons. (Also, insofar as real neural networks implement tricks like this, this probably incidentally answers Sam Marks’s XOR puzzle.)
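To make the construction concrete, here’s a rough numpy sketch of the flavour of thing described (the connection probability, bias, and sizes are my own illustrative guesses, not the post’s exact parameters):

```python
import numpy as np

rng = np.random.default_rng(0)
m, n_neurons, p = 500, 4000, 0.05  # sizes and connection probability are illustrative

# Each neuron is connected to a random subset of the inputs with weight 1; with a shared
# bias of -1 and a ReLU, a neuron only fires when at least two of its inputs are on.
W = (rng.random((n_neurons, m)) < p).astype(float)

def neurons(x):
    return np.maximum(W @ x - 1.0, 0.0)

def read_and(h, i, j):
    # "Read off" AND(i, j) by averaging the neurons connected to both i and j.
    mask = (W[:, i] > 0) & (W[:, j] > 0)
    return h[mask].mean()

# A sparse input: only 4 of the m features are on.
x = np.zeros(m)
active = rng.choice(m, size=4, replace=False)
x[active] = 1.0
h = neurons(x)

i, j = active[0], active[1]
off = next(t for t in range(m) if t not in set(active))  # a feature that is off
print(read_and(h, i, j))    # ~1, up to interference from the other active features
print(read_and(h, i, off))  # ~0, up to similar interference
```

Averaging over the roughly n_neurons·p² neurons that see both inputs is what keeps the readoff noise small when only a few inputs are active.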
They then consider a setting involving the QK matrix of an attention head, where the task is to attend between a pair of activations in a transformer whenever the first activation contains feature i and the second contains feature j. While the naive construction can only check for d_head bigrams, they provide a construction involving superposition that allows the QK matrix to approximately check for Theta(d_head * d_residual) bigrams (that is, up to ~parameter count; this involves placing the input features in superposition).
If I’m understanding it correctly, these seem like pretty cool constructions, and certainly a massive step up from what the toy models of superposition looked like in the past. In particular, these constructions do not depend on human notions of what a natural “feature” is. In fact, here the dimensions in the MLP are just sums of random subsets of the input; no additional structure needed. Basically, what it shows is that for circuit size reasons, we’re going to get superposition just to get more computation out of the network.
Thanks for the kind feedback!
I’d be especially interested in exploring either the universality of universal calculation
Do you mean the thing we call genericity in the further work section? If so, we have some preliminary theoretical and experimental evidence that genericity of U-AND is true. We trained networks on the U-AND task and the analogous U-XOR task, with a narrow 1-layer MLP and looked at the size of the interference terms after training with a suitable loss function. Then, we reinitialised and froze the first layer of weights and biases, allowing the network only to learn the linear readoff directions, and found that the error terms were comparably small in both cases.
This figure shows the size of the errors for d = d0 = 100 (which is pretty small), for readoffs which should be zero (blue) and readoffs which should be one (yellow); we want all of these errors to be close to zero.
This suggests that the AND/XOR directions were ϵ-linearly readoffable at initialisation, but the evidence at this stage is weak because we don’t have a good sense yet of what a reasonable value of ϵ is for considering the task to have been learned correctly: to answer this we want to fiddle around with loss functions and training for longer. For context, an affine readoff (linear + bias) directly on the inputs can read off f1 ∧ f2 as (f1 + f2)/2 − 1/4, which has an error of ϵ = 1/4: over the four corners (f1, f2) ∈ {0,1}² it gives −1/4, 1/4, 1/4, 3/4 against targets 0, 0, 0, 1. This is larger than all but the largest errors here, and you can’t do anything like this for XOR with an affine readoff: any affine v satisfies v(0,0) − v(0,1) − v(1,0) + v(1,1) = 0, while the XOR targets give −2, so some corner must be off by at least 1/2.
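For concreteness, here is a minimal PyTorch sketch of this frozen-first-layer setup (the sizes, number of active inputs, loss exponent, and weighting below are illustrative choices and won’t reproduce the exact figure above):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d0 = d = 100
l = 3  # number of active inputs per sample (illustrative)

def batch(n=256):
    # Boolean inputs with exactly l features on, and the AND of every ordered pair as target.
    x = torch.zeros(n, d0)
    idx = torch.stack([torch.randperm(d0)[:l] for _ in range(n)])
    x.scatter_(1, idx, 1.0)
    y = (x.unsqueeze(2) * x.unsqueeze(1)).flatten(1)  # n x d0^2
    return x, y

first_layer = nn.Linear(d0, d)
with torch.no_grad():
    first_layer.weight.normal_()  # reinitialise: W ~ N(0, 1)
    first_layer.bias.zero_()
for param in first_layer.parameters():
    param.requires_grad_(False)   # frozen; only the readoff below is trained

readoff = nn.Linear(d, d0 * d0, bias=False)
opt = torch.optim.Adam(readoff.parameters(), lr=1e-3)

for step in range(1000):
    x, y = batch()
    pred = readoff(torch.relu(first_layer(x)))
    loss = ((pred - y).abs() ** 4 * (1.0 + 9.0 * y)).mean()  # L4, ones upweighted
    opt.zero_grad()
    loss.backward()
    opt.step()

# Compare the size of the errors for readoffs that should be 1 vs 0.
x, y = batch(1024)
with torch.no_grad():
    pred = readoff(torch.relu(first_layer(x)))
print((pred[y == 1] - 1).abs().mean().item(), pred[y == 0].abs().mean().item())
```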
After we did this, Kaarel came up with an argument that, with suitable readoffs, networks randomly initialised with weights from a standard Gaussian and zero bias solve U-AND with inputs not in superposition (although it can probably be generalised to the superposition case). To sketch the idea:
Let Wi be the vector of weights from the ith input to the neurons. Then consider the linear readoff vector with kth component given by:
(AND_{i,j})_k = α (1[W_ik > 0] ∧ 1[W_jk > 0]) + β 1[W_ik > 0] + γ 1[W_jk > 0] + δ
where 1 is the indicator function. There are 4 free parameters here, which are set by 4 constraints given by requiring that the expectation of this vector dotted with the activation vector has the correct value in the 4 cases fi,fj∈{0,1}. In the limit of large d the value of the dot product will be very close to its expectation and we are done. There are a bunch of details to work out here and, as with the experiments, we aren’t 100% sure the details all work out, but we wanted to share these new results since you asked.
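Here’s a rough numpy sketch of this argument (sizes are illustrative; as a stand-in for setting α, β, γ, δ by the expectations, we solve the small linear system on one pair of features and then reuse the same coefficients on a different pair, with all other inputs held at zero):

```python
import numpy as np

rng = np.random.default_rng(0)
d0, d = 20, 50_000                  # d large so the dot products concentrate
W = rng.standard_normal((d, d0))    # random Gaussian weights, zero bias

def basis(i, j):
    # The four indicator "building blocks" of the proposed readoff for the pair (i, j).
    u, v = W[:, i] > 0, W[:, j] > 0
    return np.stack([u & v, u, v, np.ones(d, dtype=bool)], axis=1).astype(float)

def acts(i, j, fi, fj):
    # Neuron activations when feature i is fi, feature j is fj, and every other input is 0.
    x = np.zeros(d0)
    x[i], x[j] = fi, fj
    return np.maximum(W @ x, 0.0)

cases = [(0, 0), (0, 1), (1, 0), (1, 1)]
targets = np.array([fi * fj for fi, fj in cases], dtype=float)

# Set (alpha, beta, gamma, delta) using one pair of features...
A = np.stack([acts(0, 1, fi, fj) for fi, fj in cases])
coef = np.linalg.lstsq(A @ basis(0, 1), targets, rcond=None)[0]

# ...and check that the same coefficients read off AND for a *different* pair,
# which is the sense in which they are really fixed by the expectations.
readoff = basis(2, 3) @ coef
for fi, fj in cases:
    print(fi, fj, round(float(readoff @ acts(2, 3, fi, fj)), 3))  # ~ fi AND fj
```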
A big reason to use MSE as opposed to eps-accuracy in the Anthropic model is for optimization purposes (you can’t gradient descent cleanly through eps-accuracy).
We’ve suggested that perhaps it would be more principled to use something like an Lp loss with p larger than 2, as this is closer to ϵ-accuracy. It’s worth mentioning that we are currently finding that the best loss function for the task seems to be something like Lp with extra weighting on the target values that should be 1. We do this to avoid the problem that if the inputs are sparse, then the ANDs are sparse too, and the model can get good loss on Lp (for low p) by sending all inputs to the zero vector. Once we weight the ones appropriately, we find that lower values of p may be better for training dynamics.
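In code, the kind of loss we mean is something like the following (the exponent and the weight on the ones are illustrative values, and we are still experimenting with both):

```python
import torch

def weighted_lp_loss(pred, target, p=4, one_weight=10.0):
    # Upweight the (sparse) positions where the target AND is 1, so that outputting all
    # zeros is no longer a low-loss solution; taking p > 2 moves the loss closer to
    # eps-accuracy.
    w = 1.0 + (one_weight - 1.0) * (target > 0.5).float()
    return (w * (pred - target).abs() ** p).mean()
```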
or the extension to arithmetic circuits (or other continuous/more continuous models of computation in superposition)
We agree and are keen to look into that!
Thanks—fixed.
Fascinating, thanks for the update!
How are you setting p when d0 = 100? I might be totally misunderstanding something, but log²(d0)/√d ≈ 2.12 at d0 = d = 100, so it feels like you need to push d up towards ~2k to get something reasonable? (and the argument in 1.4 for using 1/log²(d0) clearly doesn’t hold here because it’s not greater than log²(d0)/d^(1/k) for this range of values).
So, all our algorithms in the post are hand constructed with their asymptotic efficiency in mind, but without any guarantees that they will perform well at finite d. They haven’t even really been optimised hard for asymptotic efficiency—we think the important point is in demonstrating that there are algorithms which work in the large d limit at all, rather than in finding the best algorithms at any particular d or in the limit. Also, all the quantities we talk about are at best up to constant factors which would be important to track for finite d. We certainly don’t expect that real neural networks implement our constructions with weights that are exactly 0 or 1. Rather, neural networks probably do a messier thing which is (potentially substantially) more efficient, and we are not making predictions about the quantitative sizes of errors at a fixed d.
In the experiment in my comment, we randomly initialised a weight matrix with each entry drawn from N(0,1), and set the bias to zero, and then tried to learn the readoff matrix R, in order to test whether U-AND is generic. This is a different setup to the U-AND construction in the post, and I offered a suggestion of readoff vectors for this setup in the comment, although that construction is also asymptotic: for finite d and a particular random seed, there are almost definitely choices of readoff vectors that achieve lower error.
FWIW, the average error in this random construction (for fixed compositeness; a different construction would be required for inputs with varying compositeness) is (we think) Θ(1/√d), with a constant that can be found by solving some ugly Gaussian integrals but which I would guess is less than 10, and the max error is Θ(log d/√d) with high probability, with a constant that involves some even uglier Gaussian integrals.