Has anyone tried training an SAE using the performance of the patched model as the loss function? I guess this would be a lot more expensive, but given that it is the metric we actually care about, it seems sensible to optimise for it directly.
This is a good idea and is something we (Apollo + MATS stream) are working on atm. We’re planning on releasing our agenda related to this and, of course, results whenever they’re ready to share.
I’ve heard this idea floated a few times and am a little worried that “When a measure becomes a target, it ceases to be a good measure” will apply here. OTOH, you can directly check whether the MSE / variance explained diverges significantly, so at least you can track how useful the resulting SAE still is for decomposition. I’d be pretty surprised if an SAE trained with this objective became vastly more performant. You could also check whether the activations downstream of the reconstructed activations were off distribution. So overall, I’m pretty excited to see what you get!
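To make the "track MSE alongside the downstream loss" check concrete, here is a rough toy sketch in plain Python. Everything in it is an assumption for illustration: the "rest of the model" is a single linear readout `W_out`, the SAE is a ReLU encoder with a linear decoder, the downstream loss is taken to be the KL divergence between the original and patched output distributions, and all function names are hypothetical. It is not the setup anyone in this thread is actually using.

```python
import math

def matvec(W, x):
    # Multiply a matrix (list of rows) by a vector.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sae_reconstruct(x, W_enc, b_enc, W_dec, b_dec):
    # Toy SAE: ReLU encoder followed by a linear decoder.
    pre = [a + b for a, b in zip(matvec(W_enc, x), b_enc)]
    feats = [max(0.0, a) for a in pre]
    return [a + b for a, b in zip(matvec(W_dec, feats), b_dec)]

def downstream_kl(x, x_hat, W_out):
    # KL(original || patched), with a linear readout standing in
    # for everything downstream of the patched activation.
    p = softmax(matvec(W_out, x))
    q = softmax(matvec(W_out, x_hat))
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def mse(xs, x_hats):
    # Mean squared reconstruction error over all samples and dimensions.
    n = sum(len(x) for x in xs)
    return sum((a - b) ** 2
               for x, xh in zip(xs, x_hats)
               for a, b in zip(x, xh)) / n

def variance_explained(xs, x_hats):
    # 1 - residual / total variance; assumes xs are not all identical.
    d = len(xs[0])
    means = [sum(x[i] for x in xs) / len(xs) for i in range(d)]
    total = sum((x[i] - means[i]) ** 2 for x in xs for i in range(d))
    resid = sum((a - b) ** 2
                for x, xh in zip(xs, x_hats)
                for a, b in zip(x, xh))
    return 1.0 - resid / total

if __name__ == "__main__":
    # The readout ignores the second activation dimension, so a
    # reconstruction that drops it entirely is invisible downstream.
    W_out = [[1.0, 0.0], [0.0, 0.0]]
    xs = [[1.0, 2.0], [3.0, 0.0]]
    x_hats = [[1.0, 0.0], [3.0, 0.0]]
    kls = [downstream_kl(x, xh, W_out) for x, xh in zip(xs, x_hats)]
    print("mean downstream KL:", sum(kls) / len(kls))
    print("MSE:", mse(xs, x_hats))
    print("variance explained:", variance_explained(xs, x_hats))
```

In this toy case the downstream KL is zero while the MSE is 1.0 and the variance explained is 0.0, which is the kind of divergence I’d want to log during training: the downstream objective alone can’t see reconstruction error in directions the rest of the model ignores.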