You can’t do JumpReLU → ReLU in the current setup, as there would be no threshold parameters to train with the STE (i.e. it would basically train a ReLU SAE with no sparsity penalty). In principle you should be able to train the other SAE parameters by adding more pseudo-derivatives, but we found this didn’t work as well as just training the threshold (so there was then no point in trying this ablation). L0 → L1 leads to worse Pareto curves (I can’t remember off the top of my head which side of Gated it lands on, but definitely not significantly better than Gated). It’s a good question whether this resolves the high-frequency features; my guess is it would (as I think Gated SAEs basically approximate JumpReLU SAEs with an L1 loss), but we didn’t check this.
Senthooran Rajamanoharan
SAE Probing: What is it good for? Absolutely something!
Thanks! I was worried about how well KDE would work because of sparsity, but in practice it seems to work quite well. My intuition is that the penalty only matters for more frequent features (i.e. features that typically do fire, perhaps multiple times, within a 4096 batch), plus the threshold is only updated a small amount at each step (so the noise averages out reliably over many steps). In any case, we tried increasing batch size and turning on momentum (with the idea this would reduce noise), but neither seemed to particularly help performance (keeping the number of tokens constant). One thing I suspect would further improve things is to have a per-feature bandwidth parameter, as I imagine the constant bandwidth we use is too narrow for some features and too wide for others; but to do this we also need to find a reliable way to adapt the bandwidth during training.
JumpReLU SAEs + Early Access to Gemma 2 SAEs
We found that exactly that form of sparsity penalty did improve shrinkage with standard (ungated) SAEs, and provided a decent boost to loss recovered at low L0. (We didn’t evaluate interpretability though.) But then we hit upon Gated SAEs, which looked even better, and for which modifying the sparsity penalty in this way feels less necessary, so we haven’t experimented with combining the two.
Good question: we’re planning to post an update on this point about combining the new sparsity penalty from Anthropic with Gated SAEs. The TL;DR is that you can replace the L1 term in the Gated SAE loss with the analogous (gated feature magnitudes dotted with decoder magnitudes) sparsity term introduced by Anthropic and thereby do away with the decoder norms constraint and resampling. If you’re going to do this, you also need to either unfreeze the decoder in the auxiliary task, or freeze the decoder weights where they appear in the sparsity penalty; both attain reasonably similar performance, and are definitely better than having the decoder weights frozen in one place but not the other. Put together, this seems to take a marginal hit (versus the original Gated loss with L1 penalty and resampling) when comparing Pareto curves, but may be worth it to the extent that it simplifies training (with this loss function, the SAE training loop just becomes a vanilla neural network training loop).
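To illustrate why a decoder-weighted penalty removes the need for a norm constraint, here is a minimal sketch in plain Python. The function name, the toy values, and the choice of which activations feed the penalty are assumptions for illustration, not the training code. It computes the sum over features of activation times decoder-row norm, and checks that rescaling a decoder row while inversely rescaling its activation leaves the penalty unchanged.

```python
import math

def decoder_weighted_sparsity(feature_acts, decoder_rows):
    """Sum over features of (activation * L2 norm of that feature's decoder row)."""
    norms = [math.sqrt(sum(w * w for w in row)) for row in decoder_rows]
    return sum(a * n for a, n in zip(feature_acts, norms))

# Hypothetical toy values: 3 features, 2-dimensional inputs.
feature_acts = [2.0, 0.0, 1.0]
decoder_rows = [[3.0, 4.0], [1.0, 0.0], [0.0, 2.0]]
penalty = decoder_weighted_sparsity(feature_acts, decoder_rows)

# Rescale feature 0: double its decoder row and halve its activation.
# The reconstruction (activation * decoder row) is unchanged by this.
rescaled_acts = [1.0, 0.0, 1.0]
rescaled_rows = [[6.0, 8.0], [1.0, 0.0], [0.0, 2.0]]
rescaled_penalty = decoder_weighted_sparsity(rescaled_acts, rescaled_rows)
```

A plain L1 penalty on the activations alone would drop from 3.0 to 2.0 under the same rescaling, which is the loophole a decoder norm constraint normally closes.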
PS With either the original (L1-based) loss or the modified loss of the previous paragraph, some of the other improvements suggested in the Anthropic post still help: in particular, initialising the encoder weights to the transpose of the decoder weights (only at initialisation, not tying them thereafter), and warming up lambda. My point about the new loss not being Pareto better than L1 applies only if you compare like with like, i.e. apply these other improvements in both cases.
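A rough sketch of those two tweaks, in plain Python with nested lists standing in for weight matrices. The helper names, the linear shape of the warm-up, and the toy numbers are all assumptions; the comment above only says that the encoder starts as the decoder's transpose without being tied, and that lambda is warmed up.

```python
def transpose(matrix):
    """Return a transposed copy of a nested-list matrix."""
    return [list(col) for col in zip(*matrix)]

def warmed_up_lambda(step, warmup_steps, target_lambda):
    """Sparsity coefficient ramped linearly (assumed shape) from 0 to target_lambda."""
    return target_lambda * min(1.0, step / warmup_steps)

# Hypothetical decoder of shape (n_features, d_model).
decoder = [[0.6, 0.8], [1.0, 0.0], [0.0, -1.0]]

# Encoder starts as a copy of the transpose, shape (d_model, n_features).
encoder = transpose(decoder)

# Because it is a copy, later encoder updates do not touch the decoder,
# i.e. the weights are not tied after initialisation.
encoder[0][0] = 0.5
```

For example, `warmed_up_lambda(500, 1000, 0.04)` is halfway through the ramp, and any step past `warmup_steps` returns the full `target_lambda`.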
On , it’s unclear what a “natural” choice would be for setting this parameter in order to simplify the architecture further. One natural reference point is to set it to , but this corresponds to getting rid of the discontinuity in the Jump ReLU (turning the magnitude encoder into a ReLU on multiplicatively rescaled gate encoder preactivations). Effectively (removing the now unnecessary auxiliary task), this would give results similar to the “baseline + rescale & shift” benchmark in section 5.2 of the paper, although probably worse, as we wouldn’t have the shift.
UPDATE: we’ve corrected equations 9 and 10 in the paper (screenshot of the draft below) and also added a footnote that hopefully helps clarify the derivation. I’ve also attached a revised figure 6, showing that this doesn’t change the overall story (for the mathematical reasons I mentioned in my previous comment). These will go up on arXiv, along with some other minor changes (like remembering to mention SAEs’ widths), likely at some point next week. Thanks again Sam for pointing this out!
Updated equations (draft):
Updated figure 6 (shrinkage comparison for GELU-1L):
Yep, the intuition here indeed was that L1-penalised reconstruction seems to be okay for teaching a standard SAE’s encoder to detect which features are on (even if features get shrunk as a result), so that is effectively what this auxiliary loss is teaching the gate sub-layer to do, alongside the sparsity penalty. (The key difference being that we freeze the decoder in the auxiliary task, which the ablation study shows helps performance.) Maybe to put it another way, this was an auxiliary task that we had good evidence would teach the gate sub-layer to detect active features reasonably well, and it turned out to give good results in practice. It’s totally possible though that there are better auxiliary tasks (or even completely different loss functions) out there that we’ve not explored.
Hey Sam, thanks—you’re right. The definition of reconstruction bias is actually the argmin of
which I’d (incorrectly) rearranged as the expression in the paper. As a result, the optimum is
That being said, the derivation we gave was not quite right, as I’d incorrectly substituted the optimised loss rather than the original reconstruction loss, which makes equation (10) incorrect. However, the difference between the two is small exactly when gamma is close to one (and indeed vanishes when there is no shrinkage), which is probably why we didn’t pick this up. Anyway, we plan to correct these two equations and update the graphs, and will submit a revised version.
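For readers following the algebra, here is a sketch of how such an optimum can be worked out. It assumes reconstruction bias is defined as the scale \gamma that minimises the expected squared error between the rescaled SAE reconstruction \hat{x}/\gamma and the input x; the notation and the exact form of the objective are assumptions here, not copied from the paper.

```latex
% Assumed definition: \gamma^\ast = \arg\min_{\gamma} L(\gamma)
L(\gamma) = \mathbb{E}\left[ \left\lVert \hat{x}/\gamma - x \right\rVert_2^2 \right]
          = u^2 \, \mathbb{E}\left[ \lVert \hat{x} \rVert_2^2 \right]
            - 2u \, \mathbb{E}\left[ \hat{x} \cdot x \right]
            + \mathbb{E}\left[ \lVert x \rVert_2^2 \right],
\qquad u = 1/\gamma .

% Setting the derivative with respect to u to zero:
\frac{\partial L}{\partial u}
  = 2u \, \mathbb{E}\left[ \lVert \hat{x} \rVert_2^2 \right]
    - 2 \, \mathbb{E}\left[ \hat{x} \cdot x \right] = 0
\quad\Longrightarrow\quad
\gamma^\ast = \frac{1}{u^\ast}
  = \frac{\mathbb{E}\left[ \lVert \hat{x} \rVert_2^2 \right]}
         {\mathbb{E}\left[ \hat{x} \cdot x \right]} .
```

Under this assumed definition, \gamma^\ast = 1 whenever \mathbb{E}[\lVert \hat{x} \rVert_2^2] = \mathbb{E}[\hat{x} \cdot x], which covers the case of a perfect reconstruction, and \gamma^\ast < 1 when reconstructions are systematically shrunk.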
Good question! The short answer is that we tried the simplest thing and it happened to work well. (The additional computational cost of higher order kernels is negligible.) We did look at other kernels but my intuition, borne out by the results (to be included in a v2 shortly, along with fixes for some typos and bugs in the pseudo-code), was that this would not make much difference. This is because we’re using KDE in a very high variance regime already, and yet the SAEs seem to train fine: given a batch size of 4096 and bandwidth of 0.001, hardly any activations (even for the most frequent features) end up having kernels that capture the threshold (i.e. that are included in the gradient estimate). So it’s not obvious that improving the bias-variance trade-off slightly by switching to a different kernel is going to make that much difference. In this sense, the way we use KDE here is very different from e.g. using KDE to visualise empirical distributions, and so we need to be careful about how we transfer intuitions between these domains.
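To make the variance point concrete, here is a minimal sketch in plain Python of a rectangle-kernel estimate for a single feature. It assumes the pseudo-derivative of the step function with respect to the threshold takes the form -(1/bandwidth) * K((z - theta) / bandwidth). The function names, the synthetic Gaussian pre-activations, and the threshold value of 1.0 are hypothetical choices for illustration, not real SAE activations.

```python
import random

def rectangle_kernel(u):
    """Box kernel: 1 on (-1/2, 1/2), 0 elsewhere."""
    return 1.0 if -0.5 < u < 0.5 else 0.0

def threshold_pseudo_grad(preacts, theta, bandwidth):
    """Batch-mean pseudo-derivative of H(z - theta) w.r.t. theta, plus the
    number of pre-activations whose kernel window captures the threshold."""
    weights = [rectangle_kernel((z - theta) / bandwidth) for z in preacts]
    hits = int(sum(weights))
    grad = -sum(weights) / (bandwidth * len(preacts))
    return grad, hits

# Synthetic stand-in for one feature's pre-activations over a 4096-token batch.
random.seed(0)
batch = [random.gauss(0.0, 1.0) for _ in range(4096)]
grad, hits = threshold_pseudo_grad(batch, theta=1.0, bandwidth=0.001)
```

With a standard normal density of roughly 0.24 at the threshold, the expected number of hits per batch is about 0.24 × 0.001 × 4096 ≈ 1, so each step's estimate rests on a handful of samples at most and relies on averaging over many steps.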