This is a very good explanation of why SAEs incentivize feature combinatorics. Nice! I hadn’t thought about the tradeoff between the MSE reduction from learning a rare feature and the L1 reduction from learning a common feature combination.
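To make that tradeoff concrete, here is a toy calculation. The setup is my own assumption, not something from the thread. Features a and b are orthogonal unit vectors that always fire together with activation 1 on a fraction `p_ab` of tokens. A rare feature r, orthogonal to both, fires with activation 1 on a fraction `p_r` of tokens. Decoder columns are unit-norm, and the dictionary already has latents for a and b with room for exactly one more. That spare latent can either represent r, or the merged unit direction (a+b)/√2, which reconstructs a+b with one coefficient of √2 instead of two coefficients of 1. `expected_loss` is a hypothetical helper.

```python
import math


def expected_loss(spare_latent: str, p_ab: float, p_r: float, l1_alpha: float) -> float:
    """Expected per-token loss (squared error + l1_alpha * L1) in the toy setup.

    spare_latent is "rare" (spend the extra latent on r) or "merged"
    (spend it on the unit direction (a + b) / sqrt(2)).
    """
    if spare_latent == "rare":
        # Everything is reconstructed exactly (squared error 0).
        # a+b tokens use two latents at 1.0 (L1 = 2); r tokens use one (L1 = 1).
        return p_ab * l1_alpha * 2.0 + p_r * l1_alpha * 1.0
    if spare_latent == "merged":
        # a+b tokens use one latent at sqrt(2) (L1 = sqrt(2)), still exact.
        # r is orthogonal to every latent, so it reconstructs to 0: squared error 1.
        return p_ab * l1_alpha * math.sqrt(2.0) + p_r * 1.0
    raise ValueError(f"unknown option: {spare_latent!r}")


if __name__ == "__main__":
    for p_r in (0.001, 0.01):
        rare = expected_loss("rare", p_ab=0.1, p_r=p_r, l1_alpha=0.1)
        merged = expected_loss("merged", p_ab=0.1, p_r=p_r, l1_alpha=0.1)
        winner = "merged" if merged < rare else "rare"
        print(f"p_r={p_r}: rare={rare:.4f} merged={merged:.4f} -> {winner}")
```

In this toy model, the merged latent wins whenever p_ab·α·(2−√2) > p_r·(1−α). The more common the combination is relative to the rare feature, the stronger the pull toward learning the combination.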
Freezing already-learned features and iteratively learning more and more features could work. Concretely, I think you would:
1. Learn an initial SAE with a much lower L0 (i.e., a higher l1-alpha) than you would normally want.
2. Learn a new SAE to predict the residual of (1), so the MSE is computed only on what (1) failed to reconstruct. The L1 penalty would also apply only to this new SAE (since the first one is frozen). You would still learn a new decoder bias, which is simply added to the old one.
3. Combine and repeat until the desired losses are reached.
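The steps above can be sketched in pure Python. This sketch shows the bookkeeping and makes no attempt at efficiency. Several details are my own assumptions, not from the comment. Each stage is a tiny ReLU SAE trained by plain SGD on squared error plus `l1_alpha` times the sum of activations. Each new stage takes the frozen stages' residual as both its input and its target. Decoder columns are projected back to unit norm after every step. `Stage`, `reconstruct`, and `train_iteratively` are hypothetical names.

```python
import math
import random


class Stage:
    """One toy SAE stage: f = relu(W_enc r + b_enc), out = W_dec f + b_dec."""

    def __init__(self, d: int, m: int, rng: random.Random):
        self.W_enc = [[rng.gauss(0.0, 0.1) for _ in range(d)] for _ in range(m)]
        self.b_enc = [0.0] * m
        self.W_dec = [[rng.gauss(0.0, 0.1) for _ in range(m)] for _ in range(d)]
        self.b_dec = [0.0] * d  # this stage's own decoder bias

    def encode(self, r):
        return [max(0.0, sum(w * ri for w, ri in zip(row, r)) + b)
                for row, b in zip(self.W_enc, self.b_enc)]

    def decode(self, f):
        return [sum(w * fi for w, fi in zip(row, f)) + b
                for row, b in zip(self.W_dec, self.b_dec)]

    def sgd_step(self, r, l1_alpha: float, lr: float) -> float:
        """One SGD step on ||out - r||^2 + l1_alpha * sum(f); returns the loss before the step."""
        f = self.encode(r)
        err = [o - t for o, t in zip(self.decode(f), r)]
        d, m = len(r), len(f)
        loss = sum(e * e for e in err) + l1_alpha * sum(f)
        # dLoss/d(pre-activation): decoder back-projection of the error plus the L1 term,
        # zeroed wherever the ReLU is inactive.
        g = [(sum(2 * err[j] * self.W_dec[j][i] for j in range(d)) + l1_alpha) if f[i] > 0 else 0.0
             for i in range(m)]
        for j in range(d):
            for i in range(m):
                self.W_dec[j][i] -= lr * 2 * err[j] * f[i]
            self.b_dec[j] -= lr * 2 * err[j]
        for i in range(m):
            for k in range(d):
                self.W_enc[i][k] -= lr * g[i] * r[k]
            self.b_enc[i] -= lr * g[i]
        # Project decoder columns back to unit norm so the L1 term can't be dodged by rescaling.
        for i in range(m):
            norm = math.sqrt(sum(self.W_dec[j][i] ** 2 for j in range(d))) or 1.0
            for j in range(d):
                self.W_dec[j][i] /= norm
        return loss


def reconstruct(stages, x):
    """Sum of all stage outputs; each stage sees the residual the earlier stages left.

    The combined decoder bias is therefore the sum of every stage's b_dec.
    """
    total = [0.0] * len(x)
    for s in stages:
        r = [xi - ti for xi, ti in zip(x, total)]
        total = [ti + oi for ti, oi in zip(total, s.decode(s.encode(r)))]
    return total


def train_iteratively(data, n_stages, m_per_stage, l1_alpha, lr=0.01, epochs=20, seed=0):
    rng = random.Random(seed)
    stages = []
    for _ in range(n_stages):
        # Earlier stages are frozen, so this stage's residual targets are fixed.
        residuals = [[xi - pi for xi, pi in zip(x, reconstruct(stages, x))] for x in data]
        stage = Stage(len(data[0]), m_per_stage, rng)
        for _ in range(epochs):
            for r in residuals:
                stage.sgd_step(r, l1_alpha, lr)
        stages.append(stage)  # step 3: combine, then repeat
    return stages


if __name__ == "__main__":
    data_rng = random.Random(1)
    data = [[data_rng.gauss(0.0, 1.0) for _ in range(4)] for _ in range(50)]
    stages = train_iteratively(data, n_stages=3, m_per_stage=2, l1_alpha=0.1)
    for k in range(1, len(stages) + 1):
        err = sum(sum((a - b) ** 2 for a, b in zip(x, reconstruct(stages[:k], x))) for x in data)
        print(f"after {k} stage(s): mean squared error {err / len(data):.3f}")
```

Because earlier stages are frozen, their residuals only need computing once per new stage. A real implementation would batch all of this in a tensor library.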
There are at least three hyperparameters to tune:
- L1-alpha (and do you keep it the same, or try to have a smaller number of features per iteration?)
- how many tokens to train each iteration on (and, I guess, whether to repeat data)
- how many new features to add each iteration
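One way to pin these knobs down is a small per-iteration config. The field names and the geometric decay of L1-alpha are illustrative assumptions, not something proposed in the thread. `IterativeSAEConfig` is a hypothetical name.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class IterativeSAEConfig:
    l1_alpha: float                  # L1 coefficient for the first iteration
    l1_decay: float                  # alpha is multiplied by this each iteration (1.0 = keep it the same)
    tokens_per_iteration: int        # training tokens for each new SAE stage
    repeat_data: bool                # reuse the same tokens every iteration, or draw fresh ones
    new_features_per_iteration: int  # dictionary size of each new stage

    def l1_alpha_at(self, iteration: int) -> float:
        """L1 coefficient for a 0-based iteration index."""
        if iteration < 0:
            raise ValueError("iteration must be non-negative")
        return self.l1_alpha * self.l1_decay ** iteration
```

Setting `l1_decay=1.0` corresponds to keeping L1-alpha the same across iterations.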
I believe the above should avoid problems. For example, suppose your first iteration perfectly reconstructs a datapoint: then the new SAE is incentivized to keep its L1 low by not activating at all on that datapoint.
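The zero-residual case can be checked directly. In my own notation, x̂₁ is the frozen first SAE's reconstruction, r is the residual, and f holds the new SAE's activations. I assume the new decoder bias is negligible on such datapoints:

```latex
\begin{aligned}
&\mathcal{L}_2 = \lVert r - (W_{\mathrm{dec}} f + b_{\mathrm{dec}}) \rVert_2^2 + \alpha \lVert f \rVert_1,
  \quad r = x - \hat{x}_1,\quad f = \mathrm{ReLU}(W_{\mathrm{enc}} r + b_{\mathrm{enc}}) \\
&r = 0,\ b_{\mathrm{dec}} = 0 \;\Longrightarrow\;
  \mathcal{L}_2 = \lVert W_{\mathrm{dec}} f \rVert_2^2 + \alpha \lVert f \rVert_1 \;\ge\; \alpha \lVert f \rVert_1 \;\ge\; 0
\end{aligned}
```

With α > 0, the loss is 0 at f = 0 and strictly positive for any f ≠ 0. So not firing is the unique best choice on such datapoints.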
Thanks! Yeah, I think those steps make sense for the iterative process, but I’m not sure whether you’re proposing that this would tackle the problem of feature combinations by itself. I’m still imagining it would require orthogonality regularisation with some weighting.
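For concreteness, here is one possible form of the weighted orthogonality term. This is an assumption; the thread doesn't specify one. It penalizes the squared cosine similarity between each new decoder column and every frozen decoder column, plus the similarity among the new columns themselves, with a separate weight for each part. `orthogonality_penalty`, `w_frozen`, and `w_new` are hypothetical names. This version only computes the value. In practice it would be written in an autodiff framework and added to the new stage's MSE + L1 loss.

```python
import math


def _cosine(u, v):
    """Cosine similarity between two vectors; 0.0 if either has zero norm."""
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    if nu == 0.0 or nv == 0.0:
        return 0.0
    return sum(a * b for a, b in zip(u, v)) / (nu * nv)


def orthogonality_penalty(new_cols, frozen_cols, w_frozen=1.0, w_new=1.0):
    """Weighted sum of squared cosine similarities between decoder directions.

    new_cols: decoder columns (feature directions) of the stage being trained.
    frozen_cols: decoder columns of all earlier, frozen stages.
    """
    to_frozen = sum(_cosine(u, v) ** 2 for u in new_cols for v in frozen_cols)
    among_new = sum(
        _cosine(new_cols[i], new_cols[j]) ** 2
        for i in range(len(new_cols))
        for j in range(i + 1, len(new_cols))
    )
    return w_frozen * to_frozen + w_new * among_new
```

Squaring the cosine penalizes aligned and anti-aligned directions equally. The two weights allow, for example, pushing new features away from frozen ones more strongly than from each other.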