What a cool paper! Congrats!:)
What’s cool:
1. e2e SAEs learn very different features on every seed. I’m glad y’all checked! This seems bad.
2. e2e SAEs have worse intermediate reconstruction loss than local SAEs. I would’ve predicted the opposite, actually.
3. e2e+downstream seems to get all the benefits of plain e2e (same performance at lower L0) at the same compute cost, without the “intermediate activations aren’t similar” problem.
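A rough way to make point 1 concrete is to match each feature in one seed's dictionary to its most similar feature in another seed's dictionary and average those best-match cosine similarities. This is a minimal sketch that assumes features are represented by decoder directions given as plain Python lists; `mean_max_cosine` and the toy dictionaries are hypothetical, not the paper's code or data.

```python
import math

def cosine(u, v):
    # Cosine similarity between two dictionary directions.
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)

def mean_max_cosine(dict_a, dict_b):
    # For each feature in dict_a, find its best match in dict_b,
    # then average those best-match similarities.
    best = [max(cosine(f, g) for g in dict_b) for f in dict_a]
    return sum(best) / len(best)

# Hypothetical toy dictionaries from three "seeds".
seed0 = [[1.0, 0.0], [0.0, 1.0]]
seed1 = [[0.0, 1.0], [1.0, 0.0]]   # same features, permuted
seed2 = [[1.0, 1.0], [1.0, -1.0]]  # same span, rotated 45 degrees

print(mean_max_cosine(seed0, seed1))
print(mean_max_cosine(seed0, seed2))
```

A score near 1 means every feature has a close counterpart in the other seed. The rotated toy dictionary scores only about 0.71 even though it spans the same space, which is one way two SAEs can reconstruct equally well yet learn different features.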
It looks like you’ve left post-training SAE_local on KL or downstream loss as future work, but that’s a very interesting part! Specifically, how well that approximates SAE_e2e+downstream as a function of the number of training tokens.
Did y’all try ablations on SAE_e2e+downstream? For example, training only on the next layer’s reconstruction loss, or on the next N layers’ reconstruction losses?
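For concreteness, here is one possible form of the e2e+downstream objective with the "next N layers" ablation exposed as a parameter. It is a minimal sketch assuming the loss is a KL term on the output distribution, plus an L1 sparsity penalty on feature activations, plus the mean per-layer MSE between downstream activations of the original and SAE-spliced runs; the coefficients, the averaging over layers, and all function names are assumptions rather than the paper's exact formulation.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

def kl_divergence(p_logits, q_logits):
    # KL(P || Q): P is the original model's output distribution,
    # Q the output distribution with the SAE spliced in.
    p, q = softmax(p_logits), softmax(q_logits)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def e2e_downstream_loss(orig_logits, sae_logits, feature_acts,
                        orig_downstream, sae_downstream,
                        sparsity_coeff=1.0, downstream_coeff=1.0,
                        n_layers=None):
    """Hypothetical e2e+downstream objective.

    orig_downstream / sae_downstream: per-layer activations after the
    SAE's layer, from the original run and the SAE-spliced run.
    n_layers: if set, only the next n_layers contribute to the
    reconstruction term (the "next N layers" ablation); None uses all.
    """
    kl = kl_divergence(orig_logits, sae_logits)
    sparsity = sum(abs(c) for c in feature_acts)  # L1 penalty
    layers = list(zip(orig_downstream, sae_downstream))
    if n_layers is not None:
        layers = layers[:n_layers]
    recon = sum(mse(a, b) for a, b in layers) / len(layers) if layers else 0.0
    return kl + sparsity_coeff * sparsity + downstream_coeff * recon

# Toy usage: identical logits (so KL = 0), three downstream layers.
orig_acts = [[1.0, 2.0], [0.5, 0.5], [3.0, 0.0]]
sae_acts = [[1.0, 2.0], [0.5, 1.5], [3.0, 2.0]]
logits = [2.0, 1.0, 0.0]
feats = [0.0, 0.5, 0.0]
loss_all = e2e_downstream_loss(logits, logits, feats, orig_acts, sae_acts)
loss_next1 = e2e_downstream_loss(logits, logits, feats, orig_acts, sae_acts,
                                 n_layers=1)
```

In this toy case the next layer happens to be reconstructed perfectly, so restricting to `n_layers=1` hides the larger errors further downstream; that is the kind of trade-off the ablation would probe.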
Thanks Logan!
2. Unlike local SAEs, our e2e SAEs aren’t trained to reconstruct the current layer’s activations, so my expectation, at least, was that they would get a worse reconstruction error at that layer.
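A toy illustration of this point, assuming a hypothetical "rest of the network" whose output depends only on the first coordinate of a 2-d activation: a reconstruction that keeps just that coordinate has zero output KL but large current-layer MSE, while a reconstruction with small error everywhere has low MSE but nonzero KL. An objective that only scores the output is indifferent to the discarded direction.

```python
import math

def downstream(act):
    # Hypothetical toy "rest of the network": the output logits
    # depend only on the first coordinate of the 2-d activation.
    return [act[0], -act[0]]

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

def layer_mse(act, recon):
    # Reconstruction error at the SAE's own layer (what a local SAE minimizes).
    return sum((a - r) ** 2 for a, r in zip(act, recon)) / len(act)

def output_kl(act, recon):
    # KL(original output || output with the reconstruction spliced in),
    # roughly what an e2e SAE minimizes (ignoring the sparsity term).
    p = softmax(downstream(act))
    q = softmax(downstream(recon))
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

act = [1.5, 3.0]          # true activation at the SAE's layer
recon_e2e = [1.5, 0.0]    # keeps only the output-relevant direction
recon_local = [1.4, 2.9]  # small error in every direction

for name, recon in [("e2e-like", recon_e2e), ("local-like", recon_local)]:
    print(f"{name}: layer MSE={layer_mse(act, recon):.3f}, "
          f"output KL={output_kl(act, recon):.5f}")
```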
Improving training times wasn’t our focus for this paper, but I agree it would be interesting, and I expect there are big gains to be had from things like mixing local and e2e+downstream training and/or training multiple SAEs at once (depending on how you do this, you may need to be more careful about taking different computational pathways through the original network).
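One hypothetical way to mix the two kinds of training is to weight the local loss and the e2e+downstream loss with a schedule. The sketch below (names and defaults are illustrative assumptions, not something the paper tested) starts with the local loss only, which is cheaper because it needs no forward or backward pass through the layers after the SAE, then shifts linearly to the e2e+downstream loss. Setting `ramp_fraction=0` reduces it to training locally and then fine-tuning end-to-end, i.e. the post-training idea raised in the original comment.

```python
def loss_weights(step, total_steps, warm_fraction=0.5, ramp_fraction=0.0):
    """Hypothetical schedule for mixing local and e2e+downstream training.

    Returns (w_local, w_e2e_ds): the weights on the local loss and on the
    e2e+downstream loss at this step. For the first warm_fraction of
    training only the local loss is used; the weight then moves linearly
    to the e2e+downstream loss over ramp_fraction of training.
    """
    start = warm_fraction * total_steps
    ramp = ramp_fraction * total_steps
    if step < start:
        alpha = 0.0
    elif ramp == 0 or step >= start + ramp:
        alpha = 1.0
    else:
        alpha = (step - start) / ramp
    return 1.0 - alpha, alpha

# Print the weights at a few points of a 100-step run with a 20% ramp.
for step in (0, 50, 60, 70, 99):
    print(step, loss_weights(step, 100, warm_fraction=0.5, ramp_fraction=0.2))
```

At each step the total loss would then be `w_local * local_loss + w_e2e_ds * e2e_ds_loss`.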
We didn’t iterate on the e2e+downstream setup much. I think it’s very likely that you could get similar performance by making tweaks like the ones you suggested.