Great work! Obviously the results here speak for themselves, but I especially wanted to compliment the authors on the writing. I thought this paper was a pleasure to read, and easily a top-5% exemplar of clear technical writing. Thanks for putting in the effort on that.
I’ll post a few questions as children to this comment.
I believe that equation (10), giving the analytical solution to the optimization problem defining the relative reconstruction bias, is incorrect. I believe the correct expression should be $\gamma = \mathbb{E}_{x\sim\mathcal{D}}\!\left[\frac{\hat{x}\cdot x}{\|x\|_2^2}\right]$.
You could compute this by differentiating equation (9), setting the derivative equal to 0, and solving for $\gamma$. But here's a more geometrical argument.
By definition, $\gamma x$ is the multiple of $x$ closest to $\hat{x}$. Equivalently, this closest vector can be described as the projection $\mathrm{proj}_x(\hat{x}) = \frac{\hat{x}\cdot x}{\|x\|_2^2}\,x$. Setting these equal, we get the claimed expression for $\gamma$.
As a sanity check: when our vectors are 1-dimensional, with $x = 1$ and $\hat{x} = \tfrac{1}{2}$, my expression gives $\gamma = \tfrac{1}{2}$ (which is correct), but equation (10) in the paper gives $\tfrac{1}{\sqrt{3}}$.
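Spelling out the arithmetic for this 1-d check (under my reading of equation (9) as minimizing $\mathbb{E}[\|\hat{x} - \gamma x\|_2^2]$ over $\gamma$):
$$\frac{d}{d\gamma}\Big(\tfrac{1}{2} - \gamma\Big)^2 = -2\Big(\tfrac{1}{2} - \gamma\Big) = 0 \;\Longrightarrow\; \gamma = \tfrac{1}{2}.$$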
Oh oops, thanks so much. We’ll update the paper accordingly. Nit: it’s actually
$$\frac{\mathbb{E}_{x\sim\mathcal{D}}[\hat{x}\cdot x]}{\mathbb{E}_{x\sim\mathcal{D}}[\|x\|_2^2]}$$
(it’s just minimizing a quadratic)
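Spelling out the quadratic in $\gamma$:
$$\mathbb{E}_{x\sim\mathcal{D}}\big[\|\hat{x} - \gamma x\|_2^2\big] = \mathbb{E}[\|\hat{x}\|_2^2] - 2\gamma\,\mathbb{E}[\hat{x}\cdot x] + \gamma^2\,\mathbb{E}[\|x\|_2^2],$$
which is minimized at $\gamma = \mathbb{E}[\hat{x}\cdot x]\,/\,\mathbb{E}[\|x\|_2^2]$.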
ETA: the reason we have complicated equations is that we didn't compute $\mathbb{E}_{x\sim\mathcal{D}}[\hat{x}\cdot x]$ during training (this quantity is kinda weird). However, you can compute $\gamma$ from quantities that are usually tracked in SAE training. Specifically, $\gamma = \frac{1}{2}\left(1 + \frac{\mathbb{E}[\|\hat{x}\|_2^2] - \mathbb{E}[\|x-\hat{x}\|_2^2]}{\mathbb{E}[\|x\|_2^2]}\right)$, and all terms here are clearly helpful to track in SAE training.
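A quick numerical check of that identity (a standalone sketch on synthetic data, not the paper's setup; the shapes and the toy reconstruction rule are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-ins for activations x and SAE reconstructions x_hat.
x = rng.normal(size=(10_000, 64))
x_hat = 0.8 * x + 0.1 * rng.normal(size=x.shape)  # shrunk, noisy reconstruction

E = lambda v: v.mean()  # empirical expectation over the batch
gamma_direct = E((x_hat * x).sum(-1)) / E((x ** 2).sum(-1))
gamma_tracked = 0.5 * (1 + (E((x_hat ** 2).sum(-1)) - E(((x - x_hat) ** 2).sum(-1)))
                       / E((x ** 2).sum(-1)))
print(gamma_direct, gamma_tracked)  # agree up to floating-point error
```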
Oh, one other issue relating to this: in the paper it's claimed that if $\gamma$ is the argmin of $\mathbb{E}[\|\hat{x} - \gamma' x\|_2^2]$, then $1/\gamma$ is the argmin of $\mathbb{E}[\|\gamma'\hat{x} - x\|_2^2]$. However, this is not actually true: the argmin of the latter expression is $\frac{\mathbb{E}[x\cdot\hat{x}]}{\mathbb{E}[\|\hat{x}\|_2^2]} \neq \left(\frac{\mathbb{E}[x\cdot\hat{x}]}{\mathbb{E}[\|x\|_2^2]}\right)^{-1}$. To get an intuition here, consider the case where $\hat{x}$ and $x$ are very nearly perpendicular, with the angle between them just slightly less than $90^\circ$. Then you should be able to convince yourself that the best factor to scale either $x$ or $\hat{x}$ by in order to minimize the distance to the other will be just slightly greater than 0. Thus the optimal scaling factors cannot be reciprocals of each other.
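To make the non-reciprocity concrete (a worked instance of the near-perpendicular picture): take $x = (1, 0)$ and $\hat{x} = (\epsilon, 1)$ for small $\epsilon > 0$. Then
$$\operatorname*{arg\,min}_{\gamma'} \|\hat{x} - \gamma' x\|_2^2 = \frac{\hat{x}\cdot x}{\|x\|_2^2} = \epsilon, \qquad \operatorname*{arg\,min}_{\gamma'} \|\gamma' \hat{x} - x\|_2^2 = \frac{\hat{x}\cdot x}{\|\hat{x}\|_2^2} = \frac{\epsilon}{1+\epsilon^2},$$
so both optimal scaling factors are close to 0, rather than being reciprocals of one another.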
ETA: Thinking on this a bit more, this might actually reflect a general issue with the way we think about feature shrinkage; namely, that whenever there is a nonzero angle between two vectors of the same length, the best way to make either vector close to the other will be by shrinking it. I’ll need to think about whether this makes me less convinced that the usual measures of feature shrinkage are capturing a real thing.
ETA2: In fact, now I’m a bit confused why your figure 6 shows no shrinkage. Based on what I wrote above in this comment, we should generally expect to see shrinkage (according to the definition given in equation (9)) whenever the autoencoder isn’t perfect. I guess the answer must somehow be “equation (10) actually is a good measure of shrinkage, in fact a better measure of shrinkage than the ‘corrected’ version of equation (10).” That’s pretty cool and surprising, because I don’t really have a great intuition for what equation (10) is actually capturing.
Thinking on this a bit more, this might actually reflect a general issue with the way we think about feature shrinkage; namely, that whenever there is a nonzero angle between two vectors of the same length, the best way to make either vector close to the other will be by shrinking it.
This was actually the key motivation for building this metric in the first place, instead of just looking at the ratio $\frac{\mathbb{E}[\|\hat{x}\|_2]}{\mathbb{E}[\|x\|_2]}$. Looking at the $\gamma$ that would optimize the reconstruction loss ensures that we're capturing only bias from the L1 regularization, and not capturing the “inherent” need to shrink the vector given these nonzero angles. (In particular, if we computed $\frac{\mathbb{E}[\|\hat{x}\|_2]}{\mathbb{E}[\|x\|_2]}$ for Gated SAEs, I expect that would be below 1.)
I think the main thing we got wrong is that we accidentally treated $\mathbb{E}[\|\hat{x} - x\|_2^2]$ as though it were $\mathbb{E}[\|\hat{x} - \gamma x\|_2^2]$. To the extent that was the main mistake, I think it explains why our results still look how we expected them to: usually $\gamma$ is going to be close to 1 (and should be almost exactly 1 if shrinkage is solved), so in practice the error introduced by this mistake is going to be extremely small.
We’re going to take a closer look at this tomorrow, check everything more carefully, and post an update after doing that. I think it’s probably worth waiting for that—I expect we’ll provide much more detailed derivations that make everything a lot clearer.
Hey Sam, thanks—you’re right. The definition of reconstruction bias is actually the argmin of
$$\mathbb{E}\big[\|\hat{x}/\gamma' - x\|_2^2\big]$$
which I’d (incorrectly) rearranged as the expression in the paper. As a result, the optimum is
$$\gamma^{-1} = \mathbb{E}[\hat{x}\cdot x]\,/\,\mathbb{E}[\|\hat{x}\|_2^2]$$
That being said, the derivation we gave was not quite right, as I'd incorrectly substituted the optimised loss rather than the original reconstruction loss, which makes equation (10) incorrect. However, the difference between the two is small exactly when $\gamma$ is close to one (and indeed vanishes when there is no shrinkage), which is probably why we didn't pick this up. Anyway, we plan to correct these two equations, update the graphs, and submit a revised version.
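For completeness, a quick sketch of the corrected algebra (my own write-up, not quoted from the paper): writing $c = 1/\gamma'$, minimizing $\mathbb{E}[\|\hat{x}/\gamma' - x\|_2^2] = \mathbb{E}[\|c\hat{x} - x\|_2^2]$ over $c$ gives
$$c^\ast = \frac{\mathbb{E}[\hat{x}\cdot x]}{\mathbb{E}[\|\hat{x}\|_2^2]}, \qquad \text{i.e.} \qquad \gamma = \frac{\mathbb{E}[\|\hat{x}\|_2^2]}{\mathbb{E}[\hat{x}\cdot x]}.$$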
UPDATE: we've corrected equations (9) and (10) in the paper (screenshot of the draft below) and also added a footnote that hopefully helps clarify the derivation. I've also attached a revised figure 6, showing that this doesn't change the overall story (for the mathematical reasons I mentioned in my previous comment). These will go up on arXiv, along with some other minor changes (like remembering to mention the SAEs' widths), likely at some point next week. Thanks again, Sam, for pointing this out!
Updated equations (draft):
Updated figure 6 (shrinkage comparison for GELU-1L):
I'm a bit perplexed by the choice of loss function for training GSAEs (given by equation (8) in the paper). The intuitive (to me) thing to do here would be to have the $\mathcal{L}_{\text{reconstruct}}$ and $\mathcal{L}_{\text{sparsity}}$ terms, but not the $\mathcal{L}_{\text{aux}}$ term, since the point of $\pi_{\text{gate}}$ is to tell you which features should be active, not to itself provide good feature coefficients for reconstructing $x$. I can sort of see how not including this term might result in the coordinates of $\pi_{\text{gate}}$ all being extremely small (but barely positive when it's appropriate to use a feature), such that the sparsity term doesn't contribute much to the loss. Is that what goes wrong? Are there ablation experiments you can report for this? If so, including this $\mathcal{L}_{\text{aux}}$ term still seems to me like a pretty unprincipled way to deal with this; can the authors provide any flavor here?
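To make sure I'm parsing equation (8) correctly, here's a minimal PyTorch-style sketch of how I'm reading the three terms. The names, shapes, and the function itself are mine, not the paper's, and I've ignored details like the pre-encoder bias and any weight sharing between the gate and magnitude paths:

```python
import torch
import torch.nn.functional as F

def gated_sae_loss(x, W_gate, b_gate, W_mag, b_mag, W_dec, b_dec, lam):
    """Illustrative three-term Gated SAE loss, as I understand equation (8)."""
    pi_gate = x @ W_gate + b_gate              # gate pre-activations
    f_gate = (pi_gate > 0).to(x.dtype)         # Heaviside: which features are active
    f_mag = F.relu(x @ W_mag + b_mag)          # magnitude path: feature coefficients
    f_tilde = f_gate * f_mag                   # gated feature activations
    x_hat = f_tilde @ W_dec + b_dec            # reconstruction from gated features

    L_reconstruct = ((x - x_hat) ** 2).sum(-1).mean()
    L_sparsity = lam * F.relu(pi_gate).sum(-1).mean()  # L1 on the ReLU'd gate pre-activations
    # Auxiliary term: reconstruct x from ReLU(pi_gate) through a frozen (detached)
    # copy of the decoder, so W_gate/b_gate still get a reconstruction gradient.
    x_hat_aux = F.relu(pi_gate) @ W_dec.detach() + b_dec.detach()
    L_aux = ((x - x_hat_aux) ** 2).sum(-1).mean()
    return L_reconstruct + L_sparsity + L_aux
```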
Here are two ways that I've come up with for thinking about this loss function; let me know if either of these is on the right track. Let $f_{\text{gate,ReLU}}$ denote the gated encoder, but with a ReLU activation instead of the Heaviside. Note then that $f_{\text{gate,ReLU}}$ is just the standard SAE encoder from Towards Monosemanticity.
Perspective 1: The usual loss from Towards Monosemanticity for training SAEs is $\|x - \hat{x}(f_{\text{gate,ReLU}}(x))\|_2^2 + \lambda\|f_{\text{gate,ReLU}}(x)\|_1$ (this is the same as your $\mathcal{L}_{\text{sparsity}}$ and $\mathcal{L}_{\text{aux}}$, up to the detaching thing). But now you have this magnitude network which needs to get a gradient signal. Let's do that by adding an additional term $\|x - \hat{x}(\tilde{f}(x))\|_2^2$, i.e. your $\mathcal{L}_{\text{reconstruction}}$. So under this perspective, it's the reconstruction term which is new, with the sparsity and auxiliary terms being carried over from the usual way of doing things.
Perspective 2 (h/t Jannik Brinkmann): let's just add together the usual Towards Monosemanticity loss function for both the usual architecture and the new modified architecture: $\mathcal{L} = \mathcal{L}_{\text{reconstruction}}(\tilde{f}) + \mathcal{L}_{\text{sparsity}}(\tilde{f}) + \mathcal{L}_{\text{reconstruction}}(f_{\text{gate,ReLU}}) + \mathcal{L}_{\text{sparsity}}(f_{\text{gate,ReLU}})$.
However, the gradients of the second term in this sum vanish because of the Heaviside, so the gradient of this loss is the same as the gradient of the loss you actually used.
Possibly I'm missing something, but if you don't have $\mathcal{L}_{\text{aux}}$, then the only gradients to $W_{\text{gate}}$ and $b_{\text{gate}}$ come from $\mathcal{L}_{\text{sparsity}}$ (the binarizing Heaviside activation function kills the gradients from $\mathcal{L}_{\text{reconstruct}}$), and so $\pi_{\text{gate}}$ would always be non-positive, achieving a perfect zero sparsity loss. (That is, if you only optimize for L1 sparsity, the obvious solution is "none of the features are active".)
(You could use a smooth activation function as the gate, e.g. an element-wise sigmoid, and then you could just stick with $\mathcal{L}_{\text{incorrect}}$ from the beginning of Section 3.2.2.)
Ah thanks, you're totally right; that mostly resolves my confusion. I'm still a little bit dissatisfied, though, because the $\mathcal{L}_{\text{aux}}$ term is optimizing for something that we don't especially want (i.e. for $\hat{x}(\mathrm{ReLU}(\pi_{\text{gate}}(x)))$ to do a good job of reconstructing $x$). But I do see how you do need to have some sort of reconstruction-esque term that actually allows gradients to pass through to the gated network.
Yep, the intuition here was indeed that L1-penalised reconstruction seems to be okay at teaching a standard SAE's encoder to detect which features are on (even if features get shrunk as a result), so that is effectively what this auxiliary loss is teaching the gate sub-layer to do, alongside the sparsity penalty. (The key difference being that we freeze the decoder in the auxiliary task, which the ablation study shows helps performance.) Maybe to put it another way: this was an auxiliary task that we had good evidence would teach the gate sub-layer to detect active features reasonably well, and it turned out to give good results in practice. It's totally possible, though, that there are better auxiliary tasks (or even completely different loss functions) out there that we've not explored.
(The question in this comment is more narrow and probably not interesting to most people.)
The limitations section includes this paragraph:
One worry about increasing the expressivity of sparse autoencoders is that they will overfit when reconstructing activations (Olah et al., 2023, Dictionary Learning Worries), since the underlying model only uses simple MLPs and attention heads, and in particular lacks discontinuities such as step functions. Overall we do not see evidence for this. Our evaluations use held-out test data and we check for interpretability manually. But these evaluations are not totally comprehensive: for example, they do not test that the dictionaries learned contain causally meaningful intermediate variables in the model’s computation. The discontinuity in particular introduces issues with methods like integrated gradients (Sundararajan et al., 2017) that discretely approximate a path integral, as applied to SAEs by Marks et al. (2024).
I’m not sure I understand the point about integrated gradients here. I understand this sentence as meaning: since model outputs are a discontinuous function of feature activations, integrated gradients will do a bad job of estimating the effect of patching feature activations to counterfactual values.
If that interpretation is correct, then I guess I’m confused because I think IG actually handles this sort of thing pretty gracefully. As long as the number of intermediate points you’re using is large enough that you’re sampling points pretty close to the discontinuity on both sides, then your error won’t be too large. This is in contrast to attribution patching which will have a pretty rough time here (but not really that much worse than with the normal ReLU encoders, I guess). (And maybe you also meant for this point to apply to attribution patching?)
I haven't fully worked through the maths, but I think both IG and attribution patching break down here? The fundamental problem is that the discontinuity is invisible to IG, because it only takes derivatives. E.g. the ReLU and JumpReLU below look identical from the perspective of IG, but not from the perspective of activation patching, I think.
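To spell out one version of this point (assuming the comparison is between a JumpReLU with threshold $\theta$ and a ReLU shifted by $\theta$): for $z \neq \theta$,
$$\frac{d}{dz}\,\mathrm{JumpReLU}_\theta(z) = \frac{d}{dz}\,\mathrm{ReLU}(z - \theta) = \mathbf{1}[z > \theta], \qquad \mathrm{JumpReLU}_\theta(z) - \mathrm{ReLU}(z - \theta) = \theta\cdot\mathbf{1}[z > \theta],$$
so a method that only samples derivatives along a path assigns the two functions identical attributions, even though their outputs differ by the jump.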
Yep, you’re totally right—thanks!

Great work! Obviously the results here speak for themselves, but I especially wanted to compliment the authors on the writing. I thought this paper was a pleasure to read, and easily a top-5% exemplar of clear technical writing. Thanks for putting in the effort on that.

<3 Thanks so much, that's extremely kind. Credit entirely goes to Sen and Arthur, which is even more impressive given that they somehow took this from a blog post to a paper in a two-week sprint (including re-running all the experiments)!