Cool work!
Did you run an ablation on the auxiliary losses for Q_clean ⋅ K_recon and Q_recon ⋅ K_clean? How important were they for stabilizing training?
Did you compare to training separate Q and K SAEs via typical reconstruction loss? Would be cool to see a side-by-side comparison, i.e. how large the benefit of this scheme is.
Thanks!
The auxiliary losses were something we settled on quite early, and we've made some improvements to the methodology since then for the current results, so I don't have great apples-to-apples comparisons for you. The losses didn't seem super important, though, in the sense that runs would still converge without them; they just took longer and ended with slightly worse reconstruction error. I think it's very likely that with a better training set-up and better hyperparameter tuning you could drop them entirely and be fine.
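For anyone wanting the gist of how those cross terms fit in, here's a rough sketch (illustrative only: the KL-based pattern loss, the function names, and the shapes are stand-ins, not our exact implementation):

```python
import torch
import torch.nn.functional as F

def attention_pattern(q, k, scale):
    # q, k: [batch, seq, d_head] -> attention pattern [batch, seq_q, seq_k]
    scores = torch.einsum("bqd,bkd->bqk", q, k) * scale
    return F.softmax(scores, dim=-1)

def pattern_loss(q_hat, k_hat, q, k, scale, aux_coeff=1.0):
    """Main loss: match the clean attention pattern using (q_hat, k_hat).
    Auxiliary losses: the mixed pairs (q_clean, k_hat) and (q_hat, k_clean)
    should also reproduce the clean pattern. Illustrative sketch only."""
    target = attention_pattern(q, k, scale).detach()

    def kl_to_target(pattern):
        # KL(target || pattern); F.kl_div expects log-probs as the first arg
        return F.kl_div(pattern.clamp_min(1e-9).log(), target,
                        reduction="batchmean")

    main = kl_to_target(attention_pattern(q_hat, k_hat, scale))
    aux = (kl_to_target(attention_pattern(q, k_hat, scale))
           + kl_to_target(attention_pattern(q_hat, k, scale)))
    return main + aux_coeff * aux
```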
Re: the comparison to SAEs, you mean what do the dictionaries/feature maps have to look like if you're explicitly targeting L2 reconstruction error and just getting pattern reconstruction as a side effect? If so, we also looked at this briefly early on. We didn't spend a huge amount of time on these, so they were probably not optimally trained, but we found that to get the L2 reconstruction error low enough to yield comparably good pattern reconstruction we needed to go up to a d_hidden of 16,000, i.e. comparable to residual SAEs for the same layer. I think that's another data point in favour of "a lot of the variance in head-space is attention-irrelevant and just inherited from the residual stream".
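For contrast, the baseline there is just a standard SAE trained on the Q or K activations with the usual L2 + L1 objective, with pattern quality only checked afterwards. Schematically (again illustrative; the architecture and hyperparameters here are placeholders, not our exact training config):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VanillaSAE(nn.Module):
    """Standard sparse autoencoder: L2 reconstruction + L1 sparsity.
    Sizes are illustrative; ~16k hidden features is roughly where pattern
    reconstruction started looking comparable in our quick tests."""
    def __init__(self, d_head: int, d_hidden: int = 16_000, l1_coeff: float = 1e-3):
        super().__init__()
        self.enc = nn.Linear(d_head, d_hidden)
        self.dec = nn.Linear(d_hidden, d_head)
        self.l1_coeff = l1_coeff

    def forward(self, x):
        f = F.relu(self.enc(x))   # sparse feature activations
        x_hat = self.dec(f)       # reconstruction of the head-space activation
        loss = F.mse_loss(x_hat, x) + self.l1_coeff * f.abs().mean()
        return x_hat, f, loss
```

You'd train one of these on Q activations and one on K activations, then plug the reconstructions into a pattern metric like the one above to see how much pattern reconstruction you get "for free" from L2 alone.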