This is neat, nice work!
I’m finding it quite hard to get a sense of what the actual Loss Recovered numbers you report mean, and to compare them concretely to other work. If possible, it’d be very helpful if you shared:
What the zero-ablation CE scores are for each model and SAE position. (I assume it’s much worse for the MLP and attention outputs than for the residual stream?)
What the baseline CE scores are for each model.
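For concreteness, the convention I have in mind when comparing these numbers is the usual one below (a quick sketch; the function and variable names are mine, so correct me if you compute it differently):

```python
def loss_recovered(ce_orig: float, ce_sae: float, ce_zero_abl: float) -> float:
    """Fraction of CE loss recovered by splicing in the SAE, relative to
    zero-ablating the hook point.

    1.0 = the SAE-inserted model matches the original model's CE loss,
    0.0 = it's as bad as zero-ablating the activations.
    """
    return (ce_zero_abl - ce_sae) / (ce_zero_abl - ce_orig)
```

With the baseline and zero-ablation CE scores per model/position, it'd be easy to convert between this ratio and raw CE deltas when comparing to other papers.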
Thanks Logan!
2. Unlike local SAEs, our e2e SAEs aren’t trained to reconstruct the current layer’s activations, so my expectation, at least, was that they would have worse reconstruction error at the current layer.
Improving training times wasn’t our focus for this paper, but I agree it would be interesting, and I expect there are big gains to be made by doing things like mixing training between the local and e2e+downstream objectives and/or training multiple SAEs at once (depending on how you do this, you may need to be more careful about the SAE-inserted model taking different computational pathways from the original network).
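To make the "mixing" idea a bit more concrete, here's roughly the kind of combined objective I have in mind. This is just a sketch we haven't tried: the SAE interface, coefficient names, and the choice of KL-to-original-logits as the downstream term are all assumptions on my part, not what we did in the paper.

```python
import torch.nn.functional as F

def mixed_sae_loss(acts, sae, logits_sae, logits_orig,
                   alpha=0.5, sparsity_coeff=1e-3):
    """Hypothetical mixed objective: interpolate between a local
    reconstruction term and an e2e term (KL between the SAE-inserted
    model's logits and the original model's logits), plus an L1
    sparsity penalty on the SAE feature activations."""
    # Assumed SAE interface: returns (feature activations, reconstruction).
    feats, recon = sae(acts)

    # Local term: reconstruct the current layer's activations.
    local_loss = F.mse_loss(recon, acts)

    # Downstream/e2e term: match the original model's output distribution.
    e2e_loss = F.kl_div(
        F.log_softmax(logits_sae, dim=-1),
        F.softmax(logits_orig, dim=-1),
        reduction="batchmean",
    )

    sparsity = feats.abs().sum(dim=-1).mean()
    return alpha * local_loss + (1 - alpha) * e2e_loss + sparsity_coeff * sparsity
```

Annealing alpha over training (local early, e2e later) is one obvious variant, but again, we haven't tested any of this.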
We didn’t iterate much on the e2e+downstream setup. I think it’s very likely that you could get similar performance with tweaks like the ones you suggested.