I’m a bit perplexed by the choice of loss function for training GSAEs (given by equation (8) in the paper). The intuitive (to me) thing to do here would be to have the $\mathcal{L}_{\text{reconstruct}}$ and $\mathcal{L}_{\text{sparsity}}$ terms, but not the $\mathcal{L}_{\text{aux}}$ term, since the point of $\pi_{\text{gate}}$ is to tell you which features should be active, not to itself provide good feature coefficients for reconstructing $x$. I can sort of see how not including this term might result in the coordinates of $\pi_{\text{gate}}$ all being extremely small (but barely positive when it’s appropriate to use a feature), such that the sparsity term doesn’t contribute much to the loss. Is that what goes wrong? Are there ablation experiments you can report for this? If so, including this $\mathcal{L}_{\text{aux}}$ term still seems to me like a pretty unprincipled way to deal with it—can the authors provide any flavor here?
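To make sure I’m parsing equation (8) correctly, here’s a minimal PyTorch-style sketch of how I’m reading the three terms. This is my own paraphrase, not the authors’ implementation: the variable names are mine, I’m writing the gate and magnitude paths as separate affine maps, and I’m ignoring details like decoder-bias centering and any weight tying between the two paths.

```python
import torch
import torch.nn.functional as F

def gated_sae_loss(x, W_gate, b_gate, W_mag, b_mag, W_dec, b_dec, lam):
    # Gate and magnitude pre-activations (kept as separate affine maps for clarity).
    pi_gate = x @ W_gate.T + b_gate
    pi_mag = x @ W_mag.T + b_mag

    # Gated features: Heaviside on the gate path, ReLU on the magnitude path.
    f_tilde = (pi_gate > 0).float() * F.relu(pi_mag)

    # L_reconstruct: reconstruct x from the gated features.
    x_hat = f_tilde @ W_dec.T + b_dec
    l_reconstruct = ((x - x_hat) ** 2).sum(dim=-1).mean()

    # L_sparsity: L1 penalty on the ReLU'd gate pre-activations.
    l_sparsity = lam * F.relu(pi_gate).sum(dim=-1).mean()

    # L_aux: reconstruct x from ReLU(pi_gate) through a frozen copy of the decoder
    # (the "detaching thing": gradients reach the gate path but not the decoder).
    x_hat_aux = F.relu(pi_gate) @ W_dec.detach().T + b_dec.detach()
    l_aux = ((x - x_hat_aux) ** 2).sum(dim=-1).mean()

    return l_reconstruct + l_sparsity + l_aux
```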
Here are two ways that I’ve come up with for thinking about this loss function—let me know if either of these are on the right track. Let $f_{\text{gate,ReLU}}$ denote the gated encoder, but with a ReLU activation instead of a Heaviside. Note then that $f_{\text{gate,ReLU}}$ is just the standard SAE encoder from Towards Monosemanticity.
Perspective 1: The usual loss from Towards Monosemanticity for training SAEs is $\|x - \hat{x}(f_{\text{gate,ReLU}}(x))\|_2^2 + \lambda\|f_{\text{gate,ReLU}}(x)\|_1$ (this is the same as your $\mathcal{L}_{\text{sparsity}}$ and $\mathcal{L}_{\text{aux}}$ up to the detaching thing). But now you have this magnitude network which needs to get a gradient signal. Let’s do that by adding an additional term $\|x - \hat{x}(\tilde{f}(x))\|_2^2$ -- your $\mathcal{L}_{\text{reconstruction}}$. So under this perspective, it’s the reconstruction term which is new, with the sparsity and auxiliary terms being carried over from the usual way of doing things.
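Written out in full (my own rendering, treating $\mathcal{L}_{\text{aux}}$ as the un-detached reconstruction from the gate path), the Perspective-1 loss would be:

$$
\mathcal{L} \;=\; \underbrace{\|x - \hat{x}(\tilde{f}(x))\|_2^2}_{\text{new: }\mathcal{L}_{\text{reconstruction}}} \;+\; \underbrace{\|x - \hat{x}(f_{\text{gate,ReLU}}(x))\|_2^2}_{\approx\,\mathcal{L}_{\text{aux}}} \;+\; \underbrace{\lambda\,\|f_{\text{gate,ReLU}}(x)\|_1}_{\mathcal{L}_{\text{sparsity}}}.
$$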
Perspective 2 (h/t Jannik Brinkmann): let’s just add together the usual Towards Monosemanticity loss function for both the usual architecture and the new modified architecture: $\mathcal{L} = \mathcal{L}_{\text{reconstruction}}(\tilde{f}) + \mathcal{L}_{\text{sparsity}}(\tilde{f}) + \mathcal{L}_{\text{reconstruction}}(f_{\text{gate,ReLU}}) + \mathcal{L}_{\text{sparsity}}(f_{\text{gate,ReLU}})$.
However, the gradients with respect to the second term in this sum (the sparsity penalty on $\tilde{f}$) vanish because of the use of the Heaviside, so the gradient of this loss is the same as the gradient of the loss you actually used.
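To spell out that last step (my own arithmetic, not from the paper): up to the decoder detach in $\mathcal{L}_{\text{aux}}$, the Perspective-2 loss differs from the loss actually used only by the sparsity penalty on $\tilde{f}$,

$$
\mathcal{L}_{\text{Perspective 2}} - \mathcal{L}_{\text{actual}} \;\approx\; \lambda\,\|\tilde{f}(x)\|_1,
$$

so if the Heaviside kills the gradients of this extra term, the two losses would indeed give the same gradients.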
Possibly I’m missing something, but if you don’t have $\mathcal{L}_{\text{aux}}$, then the only gradients to $W_{\text{gate}}$ and $b_{\text{gate}}$ come from $\mathcal{L}_{\text{sparsity}}$ (the binarizing Heaviside activation function kills gradients from $\mathcal{L}_{\text{reconstruct}}$), and so $\pi_{\text{gate}}$ would always be non-positive to get perfect zero sparsity loss. (That is, if you only optimize for $L_1$ sparsity, the obvious solution is “none of the features are active”.)
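To illustrate the first point with a toy snippet (my own variable names, not our training code): when the Heaviside is implemented as a hard threshold, the reconstruction loss on its own sends no gradient at all to the gate parameters.

```python
import torch

torch.manual_seed(0)
d, n_feat = 8, 16
x = torch.randn(3, d)

W_gate = torch.randn(n_feat, d, requires_grad=True)
b_gate = torch.zeros(n_feat, requires_grad=True)
W_mag = torch.randn(n_feat, d, requires_grad=True)
b_mag = torch.zeros(n_feat, requires_grad=True)
W_dec = torch.randn(d, n_feat, requires_grad=True)

pi_gate = x @ W_gate.T + b_gate
gate = (pi_gate > 0).float()                 # Heaviside: piecewise constant, so zero gradient
f_tilde = gate * torch.relu(x @ W_mag.T + b_mag)

l_reconstruct = ((x - f_tilde @ W_dec.T) ** 2).sum()
l_reconstruct.backward()

print(W_gate.grad, b_gate.grad)                         # None None: nothing reaches the gate parameters
print(W_mag.grad is not None, W_dec.grad is not None)   # True True
```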
(You could use a smooth activation function as the gate, e.g. an element-wise sigmoid, and then you could just stick with $\mathcal{L}_{\text{incorrect}}$ from the beginning of Section 3.2.2.)
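Concretely, the smooth-gate variant would look something like the sketch below. This is a hypothetical forward pass to illustrate the idea, not something we trained, and the temperature parameter is just my own knob for how sharp the gate is. Because the sigmoid is differentiable, reconstruction gradients would reach $W_{\text{gate}}$ and $b_{\text{gate}}$ directly, with no need for $\mathcal{L}_{\text{aux}}$.

```python
import torch
import torch.nn.functional as F

def smooth_gated_encoder(x, W_gate, b_gate, W_mag, b_mag, temperature=1.0):
    # Replace the Heaviside with an element-wise sigmoid so the gate itself is differentiable.
    pi_gate = x @ W_gate.T + b_gate
    gate = torch.sigmoid(pi_gate / temperature)   # soft gate in (0, 1) instead of a hard 0/1
    magnitude = F.relu(x @ W_mag.T + b_mag)
    return gate * magnitude
```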
Ah thanks, you’re totally right—that mostly resolves my confusion. I’m still a little bit dissatisfied, though, because the $\mathcal{L}_{\text{aux}}$ term is optimizing for something that we don’t especially want (i.e. for $\hat{x}(\mathrm{ReLU}(\pi_{\text{gate}}(x)))$ to do a good job of reconstructing $x$). But I do see that you need some sort of reconstruction-esque term that actually allows gradients to pass through to the gating network.
Yep, the intuition here indeed was that $L_1$-penalised reconstruction seems to be okay for teaching a standard SAE’s encoder to detect which features are on (even if features get shrunk as a result), so that is effectively what this auxiliary loss is teaching the gate sub-layer to do, alongside the sparsity penalty. (The key difference being we freeze the decoder in the auxiliary task, which the ablation study shows helps performance.) Maybe to put it another way, this was an auxiliary task that we had good evidence would teach the gate sub-layer to detect active features reasonably well, and it turned out to give good results in practice. It’s totally possible though that there are better auxiliary tasks (or even completely different loss functions) out there that we’ve not explored.