I think we may be close to figuring out a general mathematical framework for circuits in superposition.
I suspect that we can get a proof that roughly shows:
If we have a set of T different transformers, with parameter counts N1, …, NT, implementing e.g. solutions to T different tasks
And those transformers are robust to size ϵ noise vectors being applied to the activations at their hidden layers
Then we can make a single transformer with N = O(N1 + … + NT) total parameters that can do all T tasks, provided any given input only asks for k<<T tasks to be carried out
Crucially, the total number of superposed operations we can carry out scales linearly with the network’s parameter count, not its neuron count or attention head count. E.g. if each little subnetwork uses n neurons per MLP layer and m dimensions in the residual stream, a big network with d1 neurons per MLP connected to a d0-dimensional residual stream can implement about O(d0·d1/(m·n)) subnetworks, not just O(d1/n).
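To spell the counting out (this is just the arithmetic behind the claim above, in my notation): one MLP layer of the big network has on the order of d0·d1 parameters in its in- and out-projections, while the corresponding layer of a single subnetwork has on the order of m·n, so

```latex
T_{\max} \;\sim\; \frac{\text{params per layer of the big network}}{\text{params per layer of one subnetwork}}
        \;=\; \frac{d_0\, d_1}{m\, n},
\qquad \text{versus } \frac{d_1}{n} \text{ if each subnetwork needed its own dedicated neurons.}
```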
This would be a generalization of the construction for boolean logic gates in superposition. It’d use the same central trick, but show that it can be applied to any set of operations or circuits, not just boolean logic gates. For example, you could superpose an MNIST image classifier network and a modular addition network with this.
So, we don’t just have superposed variables in the residual stream. The computations performed on those variables are also carried out in superposition.
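To make “computation in superposition” a bit more concrete, here is a minimal numerical sketch of the sort of construction I have in mind. Everything in it is my illustrative choice (random near-orthogonal embeddings, linear toy subnetworks, the particular sizes), not the actual construction from the proof. In particular, I keep the subnetworks linear, because commuting the big network’s nonlinearity past the embeddings is exactly the part the real error analysis has to handle.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes, chosen only for illustration.
T, k = 600, 3          # superposed subnetworks; only k of them are active per input
m, n = 2, 4            # each subnetwork: m residual-stream dims, n hidden neurons
d0, d1 = 1024, 2048    # big network: residual-stream width, MLP width
# Note: T = 600 exceeds the neuron-count bounds d1/n = 512 and d0/m = 512,
# but is far below the parameter-count bound d0*d1/(m*n) = 262144.

# T small subnetworks. To keep the sketch honest they are *linear* maps
# (W_out[t] @ W_in[t]); handling the big network's nonlinearity is exactly
# what the real construction and its error terms have to deal with.
W_in  = [rng.normal(0, 1 / np.sqrt(m), (n, m)) for _ in range(T)]
W_out = [rng.normal(0, 1 / np.sqrt(n), (m, n)) for _ in range(T)]

# Random, nearly-orthonormal embeddings of each subnetwork's variables
# into the shared residual stream (E) and the shared hidden layer (F).
E = [rng.normal(0, 1 / np.sqrt(d0), (d0, m)) for _ in range(T)]
F = [rng.normal(0, 1 / np.sqrt(d1), (d1, n)) for _ in range(T)]

# The central trick: the big network's weights are sums of embedded subnetwork weights.
W1 = sum(F[t] @ W_in[t] @ E[t].T for t in range(T))   # (d1, d0)
W2 = sum(E[t] @ W_out[t] @ F[t].T for t in range(T))  # (d0, d1)

# One forward pass: k randomly chosen tasks place their inputs in the residual stream.
active = rng.choice(T, size=k, replace=False)
x = {t: rng.normal(size=m) for t in active}
resid = sum(E[t] @ x[t] for t in active)

out = W2 @ (W1 @ resid)  # the big network's (linear) MLP acting on the shared stream
for t in active:
    target = W_out[t] @ W_in[t] @ x[t]   # what subnetwork t would have computed alone
    recovered = E[t].T @ out             # read task t's answer back out of the stream
    err = np.linalg.norm(recovered - target) / np.linalg.norm(target)
    print(f"task {t}: relative error {err:.2f}")  # stays well below 1 for k << T, wide d0, d1
```

The interference between subnetworks shows up as the residual error printed at the end; the point of the sketch is just that it stays manageable even though T is larger than the number of “slots” a neuron-by-neuron assignment would allow.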
Remarks:
What the subnetworks are doing doesn’t have to line up much with the components and layers of the big network. Things can be implemented all over the place. A single MLP and attention layer in a subnetwork could be implemented by a mishmash of many neurons and attention heads across a bunch of layers of the big network. Call it cross-layer superposition if you like.
This framing doesn’t really assume that the individual subnetworks are using one-dimensional ‘features’ represented as directions in activation space. The individual subnetworks can be doing basically anything they like in any way they like. They just have to be somewhat robust to noise in their hidden activations.
You could generalize this from T subnetworks doing unrelated tasks to T “circuits” each implementing some part of a big master computation. The crucial requirement is that only k<<T circuits are used on any one forward pass.
I think formulating this for transformers, MLPs and CNNs should be relatively straightforward. It’s all pretty much the same trick. I haven’t thought about e.g. Mamba yet.
Implications, if we buy that real models work somewhat like this toy model does:
There is no superposition in parameter space. A network can’t have more independent operations than parameters. Every operation we want the network to implement takes some bits of description length in its parameters to specify, so the total description length scales linearly with the number of distinct operations. Overcomplete bases are only a thing in activation space.
There is a set of T Cartesian directions in the loss landscape that parametrize the T individual superposed circuits.
If the circuits don’t interact with each other, I think the learning coefficient of the whole network might roughly equal the sum of the learning coefficients of the individual circuits?
If that’s the case, training a big network to solve T different tasks, with only k<<T of them required per data point, would be somewhat equivalent to T parallel training runs, each trying to learn the circuit for one task over its own subdistribution. This works because each of those runs has a solution with a low learning coefficient, so one task won’t be trying to use effective parameters that another task needs. In a sense, this would be showing how the low-hanging fruit prior works.
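Written out (my notation, and very much a conjecture rather than a result): if w* is the superposed solution and λ_t the learning coefficient of the circuit for task t on its own subdistribution, the claim in the last two points is roughly

```latex
\lambda(w^{*}) \;\approx\; \sum_{t=1}^{T} \lambda_t(w^{*}_t),
\qquad \text{provided the circuits don't interact with each other.}
```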
Main missing pieces:
I don’t have the proof yet. I think I basically see what to do to get the constructions, but I actually need to sit down and crunch through the error propagation terms to make sure they check out.
With the right optimization procedure, I think we should be able to get the parameter vectors corresponding to the T individual circuits back out of the network. Apollo’s interp team is playing with a setup right now that I think might be able to do this. But it’s early days. We’re just calibrating on small toy models at the moment.
Spotted just now. At a glance, this still seems to be about boolean computation though. So I think I should still write up the construction I have in mind.
Status on the proof: I think it basically checks out for residual MLPs. Hoping to get an early draft of that done today. This will still be pretty hacky in places, and definitely not well presented. Depending on how much time I end up having and how many people collaborate with me, we might finish a writeup for transformers in the next two weeks.