Great paper! The gating approach is an interesting way to learn the JumpReLU threshold and it’s exciting that it works well. We’ve been working on some related directions at OpenAI based on similar intuitions about feature shrinking.
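To make the "gating learns a JumpReLU threshold" point concrete, here is a minimal NumPy sketch. It assumes the weight tying W_mag = diag(exp(r_mag)) W_gate, which is my reading of the Gated SAE setup and is not stated in this thread. Under that tying, the gated encoder coincides with a JumpReLU applied to the magnitude pre-activations, with per-latent threshold max(b_mag − exp(r_mag) ⊙ b_gate, 0). The function names, shapes, and random toy parameters are hypothetical and chosen only for illustration.

```python
import numpy as np


def gated_encoder(x, W_gate, b_gate, r_mag, b_mag, b_dec):
    """Gated SAE encoder with magnitude weights tied to the gate weights.

    Assumes the tying W_mag = diag(exp(r_mag)) @ W_gate. Shapes (hypothetical
    convention): x (n, d), W_gate (m, d), b_gate/r_mag/b_mag (m,), b_dec (d,).
    """
    proj = (x - b_dec) @ W_gate.T              # shared projection, shape (n, m)
    pi_gate = proj + b_gate                    # gate pre-activations
    pi_mag = np.exp(r_mag) * proj + b_mag      # magnitude pre-activations
    # The Heaviside gate picks which latents are on; the ReLU sets their size.
    return (pi_gate > 0) * np.maximum(pi_mag, 0.0)


def jumprelu(z, theta):
    """JumpReLU: keep z where z > theta, output 0 elsewhere."""
    return np.where(z > theta, z, 0.0)


rng = np.random.default_rng(0)
n, d, m = 32, 8, 16  # toy sizes, not the sizes used in the paper
W_gate = rng.normal(size=(m, d))
b_gate, r_mag, b_mag = (rng.normal(size=m) for _ in range(3))
b_dec = rng.normal(size=d)
x = rng.normal(size=(n, d))

f_gated = gated_encoder(x, W_gate, b_gate, r_mag, b_mag, b_dec)

# The same map written as a JumpReLU on the magnitude pre-activations,
# with per-latent threshold theta = max(b_mag - exp(r_mag) * b_gate, 0).
pi_mag = np.exp(r_mag) * ((x - b_dec) @ W_gate.T) + b_mag
theta = np.maximum(b_mag - np.exp(r_mag) * b_gate, 0.0)
f_jump = jumprelu(pi_mag, theta)
```

When b_mag − exp(r_mag) ⊙ b_gate is positive for a latent, that latent has a genuine jump at its threshold. When it is negative, the ReLU cuts in before the gate does, so the latent behaves like a plain ReLU.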
Some questions:
Is b_mag still necessary in the gated autoencoder?
Did you sweep learning rates for the baseline and your approach?
How large is the dictionary of the autoencoder?
We use learning rate 0.0003 for all Gated SAE experiments, and also for the GELU-1L baseline experiment. We chose this value by sweeping learning rates for the baseline SAE on GELU-1L.
For the Pythia-2.8B and Gemma-7B baseline SAE experiments, we divided the L2 loss by E||x||2, motivated by wanting better hyperparameter transfer, and so changed the learning rate to 0.001 or 0.00075 for all the runs (currently in Figure 1, only attention output pre-linear uses 0.00075; in the rerelease we’ll state all the values used). We didn’t see a noticeable difference in the Pareto frontier between 0.001 and 0.00075, so we did not sweep the baseline learning rate further.
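A minimal sketch of why dividing the reconstruction loss by a norm-based constant can help hyperparameters transfer across models and sites. It assumes the normalizer is the mean squared norm E[||x||_2^2]; the notation above could also denote the unsquared norm, in which case the invariance shown here would not hold exactly. The function name and toy data are hypothetical.

```python
import numpy as np


def normalized_l2_loss(x, x_hat):
    """Mean squared reconstruction error divided by the mean squared input norm.

    Assumption: normalizer = E[||x||_2^2]. Here it is estimated from the same
    batch for simplicity; a fixed constant precomputed from a sample of
    activations would be the more likely choice in practice.
    """
    sq_err = np.sum((x - x_hat) ** 2, axis=-1)   # per-example squared error
    norm_sq = np.sum(x ** 2, axis=-1)            # per-example squared norm
    return sq_err.mean() / norm_sq.mean()


rng = np.random.default_rng(0)
x = rng.normal(size=(64, 16))                    # stand-in activations
x_hat = x + 0.1 * rng.normal(size=x.shape)       # stand-in reconstructions

# Rescaling activations by a constant leaves the normalized loss unchanged, so
# sites with very different activation norms see losses on a comparable scale.
base = normalized_l2_loss(x, x_hat)
scaled = normalized_l2_loss(10.0 * x, 10.0 * x_hat)
```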
Thanks, that makes sense
Re dictionary width: 2**17 (~131K) for most Gated SAEs and 3*(2**16) for baseline SAEs, except at the (Pythia-2.8B, Residual Stream) sites, where we used 2**15 for Gated and 3*(2**14) for baseline, since early runs there had lots of feature death. (This’ll be added to the paper soon, sorry!) I’ll leave the other Qs for my co-authors.
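The quoted widths written out as plain numbers, as a quick arithmetic check. In both settings the baseline dictionary is 1.5 times as wide as the Gated one; the comment above does not give the reason for that ratio. The dictionary keys are hypothetical labels.

```python
# Dictionary widths quoted above, keyed by (site group, architecture).
widths = {
    ("most sites", "gated"): 2**17,
    ("most sites", "baseline"): 3 * 2**16,
    ("Pythia-2.8B residual", "gated"): 2**15,
    ("Pythia-2.8B residual", "baseline"): 3 * 2**14,
}
for (site, arch), w in widths.items():
    print(f"{site:22s} {arch:8s} {w:>7,d}")

# Baseline width divided by Gated width for each site group.
ratios = {
    site: widths[(site, "baseline")] / widths[(site, "gated")]
    for site in ("most sites", "Pythia-2.8B residual")
}
```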
Got it. Do you think that with a bit more tuning the feature death at larger scale could be eliminated, or would it be tough to manage with the reinitialization approach?
I’m not sure what you mean by “the reinitialization approach”, but feature death doesn’t seem to be a major issue at the moment. At all sites besides L27, our Gemma-7B SAEs had very little feature death (stats will be up at https://arxiv.org/pdf/2404.16014v2 in a few hours), and the Anthropic update also suggests the problem can be addressed even in small models.
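Feature-death statistics are commonly summarized as the fraction of latents that never activate over some window of tokens. A minimal sketch assuming that definition; the window length and any activation threshold behind the numbers in this thread are not stated here, and the function name is hypothetical.

```python
import numpy as np


def dead_latent_fraction(acts):
    """Fraction of latents that never fire over a window of tokens.

    acts: array of shape (n_tokens, n_latents) holding SAE feature activations.
    Assumption: a latent counts as dead if it is exactly zero on every token
    in the window; the window length is a free choice.
    """
    fired = (acts > 0).any(axis=0)   # did each latent fire at least once?
    return 1.0 - fired.mean()


# Toy example: 4 tokens, 5 latents; latents 1 and 3 never fire.
acts = np.array([
    [0.5, 0.0, 0.0, 0.0, 1.2],
    [0.0, 0.0, 0.3, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.7],
    [2.0, 0.0, 0.0, 0.0, 0.0],
])
print(dead_latent_fraction(acts))  # 0.4
```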
Sorry, I meant the Anthropic-like neuron resampling procedure.
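For readers unfamiliar with it, here is a simplified sketch of neuron resampling as I understand Anthropic’s description in Towards Monosemanticity. Each dead latent is re-initialized from an input sampled with probability proportional to its squared reconstruction loss: the decoder row becomes that input at unit norm, the encoder row the same direction scaled to 0.2 times the mean norm of live encoder rows, and the encoder bias zero. The function name, argument shapes, and toy data are hypothetical, and the optimizer-state reset from that description is omitted.

```python
import numpy as np


def resample_dead_latents(W_enc, b_enc, W_dec, dead, x, per_example_loss, rng,
                          enc_scale=0.2):
    """Re-initialize dead latents from poorly reconstructed inputs (simplified).

    Shapes (hypothetical convention): W_enc (n_latents, d), b_enc (n_latents,),
    W_dec (n_latents, d), dead (n_latents,) bool mask, x (n, d),
    per_example_loss (n,), assumed not all zero. Returns updated copies.
    """
    W_enc, b_enc, W_dec = W_enc.copy(), b_enc.copy(), W_dec.copy()
    dead_idx = np.flatnonzero(dead)
    if dead_idx.size == 0:
        return W_enc, b_enc, W_dec

    # Sample one input per dead latent, weighted by squared loss.
    weights = per_example_loss ** 2
    picks = rng.choice(len(x), size=dead_idx.size, p=weights / weights.sum())
    v = x[picks]
    v_unit = v / np.linalg.norm(v, axis=1, keepdims=True)

    # Decoder direction: the unit-norm sampled input.
    W_dec[dead_idx] = v_unit
    # Encoder row: same direction, scaled relative to the live encoder rows.
    alive = ~dead
    alive_norm = np.linalg.norm(W_enc[alive], axis=1).mean() if alive.any() else 1.0
    W_enc[dead_idx] = enc_scale * alive_norm * v_unit
    b_enc[dead_idx] = 0.0
    return W_enc, b_enc, W_dec


rng = np.random.default_rng(0)
n_latents, d, n = 6, 4, 50  # toy sizes
W_enc = rng.normal(size=(n_latents, d))
b_enc = rng.normal(size=n_latents)
W_dec = rng.normal(size=(n_latents, d))
dead = np.array([False, True, False, False, True, False])
x = rng.normal(size=(n, d))
loss = rng.uniform(0.1, 1.0, size=n)  # stand-in per-example losses
new_enc, new_b, new_dec = resample_dead_latents(W_enc, b_enc, W_dec, dead, x, loss, rng)
```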
I think I misread Neel’s comment: I thought he was saying that 131K was chosen because larger autoencoders would have too many dead latents (as opposed to this applying only to the Pythia residual stream).
Ah yeah, Neel’s comment makes no claims about feature death beyond the Pythia-2.8B residual streams. I trained 524K-width Pythia-2.8B MLP SAEs with <5% feature death (not in the paper), and Anthropic’s work gets to >1M live features (with no claims about interpretability). Together, these would make me surprised if 131K were near the maximum possible number of live features, even in small models.
On b_mag: it’s unclear what a “natural” choice for this parameter would be if we wanted to simplify the architecture further. One natural reference point is to set it to exp(r_mag) ⊙ b_gate, but this corresponds to getting rid of the discontinuity in the JumpReLU (turning the magnitude encoder into a ReLU on multiplicatively rescaled gate encoder preactivations). Effectively (after removing the now-unnecessary auxiliary task), this would give results similar to the “baseline + rescale & shift” benchmark in section 5.2 of the paper, though probably worse, as we wouldn’t have the shift.
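The claim that setting b_mag = exp(r_mag) ⊙ b_gate removes the discontinuity can be checked in a few lines. This assumes the tied parameterization W_mag = diag(exp(r_mag)) W_gate and the pre-activations written below, which are my reconstruction of the setup rather than something quoted in this thread:

```latex
\begin{aligned}
\pi_{\text{gate}}(x) &= W_{\text{gate}}(x - b_{\text{dec}}) + b_{\text{gate}},\\
\pi_{\text{mag}}(x) &= e^{r_{\text{mag}}} \odot W_{\text{gate}}(x - b_{\text{dec}}) + b_{\text{mag}},\\
f(x) &= \mathbb{1}[\pi_{\text{gate}}(x) > 0] \odot \operatorname{ReLU}\bigl(\pi_{\text{mag}}(x)\bigr).\\[4pt]
\text{With } b_{\text{mag}} = e^{r_{\text{mag}}} \odot b_{\text{gate}}:\quad
\pi_{\text{mag}}(x) &= e^{r_{\text{mag}}} \odot \bigl(W_{\text{gate}}(x - b_{\text{dec}}) + b_{\text{gate}}\bigr)
  = e^{r_{\text{mag}}} \odot \pi_{\text{gate}}(x).\\
\text{Since } e^{r_{\text{mag}}} > 0 \text{ elementwise:}\quad
f(x) &= \mathbb{1}[\pi_{\text{gate}}(x) > 0] \odot \operatorname{ReLU}\bigl(e^{r_{\text{mag}}} \odot \pi_{\text{gate}}(x)\bigr)
  = \operatorname{ReLU}\bigl(e^{r_{\text{mag}}} \odot \pi_{\text{gate}}(x)\bigr).
\end{aligned}
```

The gate and the ReLU now switch on at the same point, so each latent is a continuous ReLU of a rescaled gate pre-activation, and the effective JumpReLU threshold b_mag − exp(r_mag) ⊙ b_gate is pinned at zero.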
Makes sense that the shift would be helpful