I believe that equation (10), giving the analytical solution to the optimization problem defining the relative reconstruction bias, is incorrect. I believe the correct expression should be $\gamma = \mathbb{E}_{x\sim\mathcal{D}}\left[\frac{\hat{x}\cdot x}{\|x\|_2^2}\right]$.
You could compute this by differentiating equation (9), setting the derivative equal to 0, and solving for $\gamma$. But here's a more geometric argument.
By definition, $\gamma x$ is the multiple of $x$ closest to $\hat{x}$. Equivalently, this closest such vector can be described as the projection $\mathrm{proj}_x(\hat{x}) = \frac{\hat{x}\cdot x}{\|x\|_2^2}\,x$. Setting these equal, we get the claimed expression for $\gamma$.
As a sanity check: when our vectors are 1-dimensional, with $x = 1$ and $\hat{x} = \tfrac{1}{2}$, my expression gives $\gamma = \tfrac{1}{2}$ (which is correct), but equation (10) in the paper gives $\tfrac{1}{\sqrt{3}}$.
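Not from the original comment, but the projection argument and the 1-D sanity check are easy to verify numerically. Here's a small numpy sketch (the data and variable names are mine):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=8)
x_hat = rng.normal(size=8)

# Projection coefficient: the scalar gamma such that gamma * x is
# the multiple of x closest to x_hat.
gamma = np.dot(x_hat, x) / np.dot(x, x)

# Check that gamma minimizes ||x_hat - g * x|| against nearby candidates.
best = np.linalg.norm(x_hat - gamma * x)
for g in gamma + np.linspace(-0.1, 0.1, 201):
    assert np.linalg.norm(x_hat - g * x) >= best - 1e-12

# 1-D sanity check: x = 1, x_hat = 1/2 gives gamma = 1/2.
print(np.dot([0.5], [1.0]) / np.dot([1.0], [1.0]))  # 0.5
```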
Oh oops, thanks so much. We’ll update the paper accordingly. Nit: it’s actually
$$\gamma = \frac{\mathbb{E}_{x\sim\mathcal{D}}[\hat{x}\cdot x]}{\mathbb{E}_{x\sim\mathcal{D}}[\|x\|_2^2]}$$
(it’s just minimizing a quadratic)
ETA: the reason we have complicated equations is that we didn't compute $\mathbb{E}_{x\sim\mathcal{D}}[\hat{x}\cdot x]$ during training (this quantity is kinda weird). However, you can compute $\gamma$ from quantities that are usually tracked in SAE training. Specifically, $\gamma = \frac{1}{2}\left(1 + \frac{\mathbb{E}[\|\hat{x}\|_2^2] - \mathbb{E}[\|x - \hat{x}\|_2^2]}{\mathbb{E}[\|x\|_2^2]}\right)$, and all terms here are clearly helpful to track in SAE training.
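Not part of the original comment, but the equivalence of the two forms (which follows from the polarization identity $\mathbb{E}[\hat{x}\cdot x] = \tfrac{1}{2}(\mathbb{E}[\|\hat{x}\|^2] + \mathbb{E}[\|x\|^2] - \mathbb{E}[\|x - \hat{x}\|^2])$) can be sanity-checked on synthetic data; the stand-in activations below are mine:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 16))                       # stand-in activations x ~ D
X_hat = 0.8 * X + 0.1 * rng.normal(size=(1000, 16))   # imperfect "reconstructions"

# Direct form: gamma = E[x_hat . x] / E[||x||^2]
gamma_direct = np.mean(np.sum(X_hat * X, axis=1)) / np.mean(np.sum(X * X, axis=1))

# Form built from quantities usually tracked in SAE training:
e_rec = np.mean(np.sum(X_hat**2, axis=1))             # E[||x_hat||^2]
e_err = np.mean(np.sum((X - X_hat)**2, axis=1))       # E[||x - x_hat||^2]
e_x = np.mean(np.sum(X**2, axis=1))                   # E[||x||^2]
gamma_tracked = 0.5 * (1 + (e_rec - e_err) / e_x)

assert np.isclose(gamma_direct, gamma_tracked)
```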
Oh, one other issue relating to this: in the paper it's claimed that if $\gamma$ is the argmin of $\mathbb{E}[\|\hat{x} - \gamma' x\|_2^2]$ then $1/\gamma$ is the argmin of $\mathbb{E}[\|\gamma'\hat{x} - x\|_2^2]$. However, this is not actually true: the argmin of the latter expression is $\frac{\mathbb{E}[x\cdot\hat{x}]}{\mathbb{E}[\|\hat{x}\|_2^2]} \neq \left(\frac{\mathbb{E}[x\cdot\hat{x}]}{\mathbb{E}[\|x\|_2^2]}\right)^{-1}$. To get an intuition here, consider the case where $\hat{x}$ and $x$ are very nearly perpendicular, with the angle between them just slightly less than $90^\circ$. Then you should be able to convince yourself that the best factor to scale either $x$ or $\hat{x}$ by, in order to minimize the distance to the other, will be just slightly greater than 0. Thus the optimal scaling factors cannot be reciprocals of each other.
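The near-perpendicular intuition above is not hard to make concrete. A short numpy sketch (the specific angle is my choice) showing that both optimal scaling factors are small and positive, hence clearly not reciprocals:

```python
import numpy as np

theta = np.deg2rad(89.0)                              # just under 90 degrees
x = np.array([1.0, 0.0])
x_hat = np.array([np.cos(theta), np.sin(theta)])      # same length as x

# Optimal factor for scaling x toward x_hat, and x_hat toward x.
g1 = np.dot(x_hat, x) / np.dot(x, x)          # argmin_g ||x_hat - g * x||
g2 = np.dot(x, x_hat) / np.dot(x_hat, x_hat)  # argmin_g ||g * x_hat - x||

print(g1, g2)  # both equal cos(89 deg), roughly 0.0175: small and positive
assert g1 > 0 and g2 > 0
assert not np.isclose(g1, 1 / g2)  # not reciprocals of each other
```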
ETA: Thinking on this a bit more, this might actually reflect a general issue with the way we think about feature shrinkage; namely, that whenever there is a nonzero angle between two vectors of the same length, the best way to make either vector close to the other will be by shrinking it. I’ll need to think about whether this makes me less convinced that the usual measures of feature shrinkage are capturing a real thing.
ETA2: In fact, now I’m a bit confused why your figure 6 shows no shrinkage. Based on what I wrote above in this comment, we should generally expect to see shrinkage (according to the definition given in equation (9)) whenever the autoencoder isn’t perfect. I guess the answer must somehow be “equation (10) actually is a good measure of shrinkage, in fact a better measure of shrinkage than the ‘corrected’ version of equation (10).” That’s pretty cool and surprising, because I don’t really have a great intuition for what equation (10) is actually capturing.
> Thinking on this a bit more, this might actually reflect a general issue with the way we think about feature shrinkage; namely, that whenever there is a nonzero angle between two vectors of the same length, the best way to make either vector close to the other will be by shrinking it.
This was actually the key motivation for building this metric in the first place, instead of just looking at the ratio $\frac{\mathbb{E}[\|\hat{x}\|_2]}{\mathbb{E}[\|x\|_2]}$. Looking at the $\gamma$ that would optimize the reconstruction loss ensures that we're capturing only bias from the L1 regularization, and not capturing the "inherent" need to shrink the vector given these nonzero angles. (In particular, if we computed $\frac{\mathbb{E}[\|\hat{x}\|_2]}{\mathbb{E}[\|x\|_2]}$ for Gated SAEs, I expect that would be below 1.)
I think the main thing we got wrong is that we accidentally treated $\mathbb{E}[\|\hat{x} - x\|^2]$ as though it were $\mathbb{E}[\|\hat{x} - \gamma x\|^2]$. To the extent that was the main mistake, I think it explains why our results still look how we expected them to—usually $\gamma$ is going to be close to 1 (and should be almost exactly 1 if shrinkage is solved), so in practice the error introduced from this mistake is going to be extremely small.
We’re going to take a closer look at this tomorrow, check everything more carefully, and post an update after doing that. I think it’s probably worth waiting for that—I expect we’ll provide much more detailed derivations that make everything a lot clearer.
Hey Sam, thanks—you’re right. The definition of reconstruction bias is actually the argmin of
$$\mathbb{E}[\|\hat{x}/\gamma' - x\|^2]$$
which I’d (incorrectly) rearranged as the expression in the paper. As a result, the optimum is
$$\gamma^{-1} = \mathbb{E}[\hat{x}\cdot x]\,/\,\mathbb{E}[\|\hat{x}\|^2]$$
That being said, the derivation we gave was not quite right, as I'd incorrectly substituted the optimised loss rather than the original reconstruction loss, which makes equation (10) incorrect. However, the difference between the two is small exactly when $\gamma$ is close to one (and indeed vanishes when there is no shrinkage), which is probably why we didn't pick this up. Anyway, we plan to correct these two equations, update the graphs, and submit a revised version.
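The corrected closed form can be checked numerically: the $\gamma$ with $\gamma^{-1} = \mathbb{E}[\hat{x}\cdot x]/\mathbb{E}[\|\hat{x}\|^2]$ should minimize $\mathbb{E}[\|\hat{x}/\gamma' - x\|^2]$. A sketch (synthetic data of my own choosing, compared against a grid search):

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(2000, 8))                        # stand-in activations
X_hat = 0.7 * X + 0.2 * rng.normal(size=(2000, 8))    # shrunken reconstructions

# Closed form: 1/gamma = E[x_hat . x] / E[||x_hat||^2]
inv_gamma = np.mean(np.sum(X_hat * X, axis=1)) / np.mean(np.sum(X_hat**2, axis=1))
gamma = 1.0 / inv_gamma

def loss(g):
    """Empirical E[||x_hat / g - x||^2] for a candidate scaling g."""
    return np.mean(np.sum((X_hat / g - X)**2, axis=1))

# The analytic gamma should match the grid minimizer to within grid resolution.
grid = np.linspace(0.5 * gamma, 2.0 * gamma, 501)
assert abs(grid[np.argmin([loss(g) for g in grid])] - gamma) < 0.01
```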
UPDATE: we’ve corrected equations 9 and 10 in the paper (screenshot of the draft below) and also added a footnote that hopefully helps clarify the derivation. I’ve also attached a revised figure 6, showing that this doesn’t change the overall story (for the mathematical reasons I mentioned in my previous comment). These will go up on arXiv, along with some other minor changes (like remembering to mention SAEs’ widths), likely some point next week. Thanks again Sam for pointing this out!
Updated equations (draft):
Updated figure 6 (shrinkage comparison for GELU-1L):