Cool post! I often find myself confused/unable to guess why people I don’t know are excited about SAEs (there seem to be a few vaguely conflicting reasons), and this was a very clear description of your agenda.
I’m a little confused by this point:
> The reconstruction loss trains the SAE features to approximate what the network does, thus optimizing for mathematical description accuracy
It’s not clear to me that ‘approximating what the network does’ is the right framing of this loss. In my mind, the reconstruction loss is more of a non-degeneracy control to encourage almost-orthogonality between features. In toy settings, SAEs are able to recover ground-truth directions while still having imperfect reconstruction loss, and it seems very plausible that we could exploit this (e.g. maybe through gradient-based attribution) without optimising heavily for reconstruction loss. Optimising heavily for it might degrade scalability (which seems very important for this agenda) and monosemanticity compared to currently-unexplored alternatives.
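For reference, the loss under discussion is usually a reconstruction term plus a weighted sparsity penalty on the feature activations. The sketch below assumes the common form: a ReLU encoder, a squared-error reconstruction term and an $L_1$ penalty scaled by a coefficient `lam`. The function and parameter names are hypothetical and not taken from the post. The second usage line shows the clipped coordinate appearing as reconstruction error even though the one active feature is "correct".

```python
def relu(v):
    return [max(0.0, t) for t in v]

def matvec(M, v):
    # M is a list of rows; returns the matrix-vector product M @ v
    return [sum(m * t for m, t in zip(row, v)) for row in M]

def sae_loss(x, W_enc, b_enc, W_dec, b_dec, lam):
    """Return (reconstruction, sparsity, total) for a single input vector x.

    Assumed (hypothetical) shapes: W_enc is n_features x d, W_dec is d x n_features.
    """
    centered = [xi - bi for xi, bi in zip(x, b_dec)]
    f = relu([h + b for h, b in zip(matvec(W_enc, centered), b_enc)])  # feature activations
    x_hat = [h + b for h, b in zip(matvec(W_dec, f), b_dec)]           # reconstruction of x
    recon = sum((xi - hi) ** 2 for xi, hi in zip(x, x_hat))             # squared error
    sparsity = sum(abs(fi) for fi in f)                                  # L1 norm of activations
    return recon, sparsity, recon + lam * sparsity

# Usage: a 2-feature SAE with identity encoder and decoder, zero biases.
I = [[1.0, 0.0], [0.0, 1.0]]
print(sae_loss([1.0, 2.0], I, [0.0, 0.0], I, [0.0, 0.0], lam=0.1))
print(sae_loss([1.0, -1.0], I, [0.0, 0.0], I, [0.0, 0.0], lam=0.1))  # ReLU clips the -1
```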
> In my mind, the reconstruction loss is more of a non-degeneracy control to encourage almost-orthogonality between features.
I don’t currently see why reconstruction would encourage features to point in different directions from each other in any way unless paired with an $L_p$ penalty with $0<p<1$. And I specifically don’t mean $L_1$, because in toy data settings with reconstruction + $L_1$, you can end up with features pointing in exactly the same direction.
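A toy calculation illustrates the duplicate-feature point. Assume two SAE features share exactly the same unit-norm decoder direction, and that the input needs a total activation of 1 along it. The direction `d` and the helper names below are arbitrary hypothetical choices. Every split of that activation between the two features gives the same reconstruction, so the reconstruction term is indifferent to the split. An $L_1$ penalty is also indifferent. An $L_p$ penalty with $0<p<1$ (here $p = 0.5$), by contrast, strictly prefers putting all the activation on a single feature.

```python
def lp_penalty(acts, p):
    # Sum of |a_i|^p over feature activations; p = 1 gives the usual L1 penalty.
    return sum(abs(a) ** p for a in acts)

def compare_splits(total=1.0, p=0.5):
    d = [0.6, 0.8]  # hypothetical shared unit-norm decoder direction
    rows = []
    for frac in (1.0, 0.75, 0.5):
        a1, a2 = frac * total, (1.0 - frac) * total
        # Both features write along the same direction d, so only a1 + a2 matters.
        x_hat = [a1 * di + a2 * di for di in d]
        rows.append((frac, x_hat, lp_penalty((a1, a2), 1.0), lp_penalty((a1, a2), p)))
    return rows

for frac, x_hat, l1, lp in compare_splits():
    print(f"split {frac:.2f}/{1 - frac:.2f}: x_hat={x_hat}, L1={l1:.4f}, L0.5={lp:.4f}")
```

Under the $L_1$ penalty, all three splits cost the same (1.0) and reconstruct the same point. The $L_{0.5}$ penalty rises from 1.0 for the one-hot split to roughly 1.414 for the even split. So nothing in reconstruction + $L_1$ penalises keeping the duplicate around.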
Thanks Aidan!
I’m not sure I follow this bit:
> I don’t currently see why reconstruction would encourage features to point in different directions from each other in any way unless paired with an $L_p$ penalty with $0<p<1$. And I specifically don’t mean $L_1$, because in toy data settings with reconstruction + $L_1$, you can end up with features pointing in exactly the same direction.