I'm a bit confused about why the KL-only e2e SAEs have worse CE than the e2e+downstream SAEs across dictionary sizes:
This is true for layers 2 & 6. I’m unsure whether this means that training on KL directly is harder or less stable, with the intermediate MSE acting as a useful prior, or whether it reflects a difference between KL and CE (i.e. the e2e SAE does in fact do better on KL but worse on CE than e2e+downstream).
Here’s a wandb report that includes plots for the KL divergence. e2e+downstream indeed performs better for layer 2, so it’s possible that the intermediate losses help training a little. But I wouldn’t be surprised if better hyperparameters eliminated this difference; we put more effort into optimising the SAE_local hyperparameters than the SAE_e2e and SAE_e2e+ds ones.
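For anyone wondering how the "KL vs CE" possibility raised above could arise at all, here is a toy sketch. It assumes that "KL" means KL(original model || SAE-spliced model) over next-token distributions and that "CE" is cross-entropy against the true next token. The function names and all of the numbers are made up for illustration and are not taken from the wandb report.

```python
import numpy as np

def log_softmax(logits):
    # Numerically stable log-softmax over the last (vocab) axis.
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def kl_to_original(orig_logits, sae_logits):
    # Mean KL(p_orig || p_sae) over positions (hypothetical helper).
    log_p = log_softmax(orig_logits)
    log_q = log_softmax(sae_logits)
    return float((np.exp(log_p) * (log_p - log_q)).sum(axis=-1).mean())

def ce_on_labels(logits, labels):
    # Mean cross-entropy against the true next tokens (hypothetical helper).
    log_q = log_softmax(logits)
    return float(-log_q[np.arange(len(labels)), labels].mean())

# Toy setup: one position, 3-token vocab, true next token is index 0.
labels = np.array([0])
orig = np.log(np.array([[0.5, 0.3, 0.2]]))   # original model's distribution
sae_a = np.log(np.array([[0.5, 0.3, 0.2]]))  # matches the original exactly
sae_b = np.log(np.array([[0.7, 0.2, 0.1]]))  # further from the original, more mass on the label

kl_a, kl_b = kl_to_original(orig, sae_a), kl_to_original(orig, sae_b)
ce_a, ce_b = ce_on_labels(sae_a, labels), ce_on_labels(sae_b, labels)

print(f"A: KL={kl_a:.4f}  CE={ce_a:.4f}")
print(f"B: KL={kl_b:.4f}  CE={ce_b:.4f}")
```

In this toy case A is better on KL while B is better on CE, so the two rankings can disagree in principle. Whether that explains the gap in the actual runs is a separate question that the KL plots address directly.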