Here’s my longer reply.
I’m extremely excited by the work on SAEs and their potential for interpretability. However, I think there is a subtle misalignment between the SAE architecture and loss function, on the one hand, and the actual desired objective function, on the other.
The SAE loss function is:
$$\mathcal{L}(x; W_{dec}, b_{dec}) = \mathbb{E}_x\left[\lVert x - \hat{x}\rVert^2 + \lambda\lVert f(x)\rVert_1\right],$$
where $\lVert f(x)\rVert_1 = \sum_i f_i(x)$ is the $\ell_1$-norm, or equivalently
$$\mathcal{L}(x) = \mathbb{E}_x\left[\lVert x - W_{dec}f(x) - b_{dec}\rVert^2 + \lambda\lVert f(x)\rVert_1\right].$$
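In code, this loss can be sketched roughly as follows (the toy dimensions, random initialization, and ReLU encoder here are my own illustrative assumptions, not anything from the post):

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative toy dimensions: d-dim activations, m-dim overcomplete code (m > d).
d, m, batch, lam = 16, 64, 8, 0.01

W_enc = rng.normal(size=(m, d)) / np.sqrt(d)
b_enc = np.zeros(m)
W_dec = rng.normal(size=(d, m)) / np.sqrt(m)
b_dec = np.zeros(d)

x = rng.normal(size=(batch, d))

# Single-layer encoder: f(x) = ReLU(x W_enc^T + b_enc); decoder: x_hat = f W_dec^T + b_dec
f = np.maximum(0.0, x @ W_enc.T + b_enc)
x_hat = f @ W_dec.T + b_dec

# L = E_x[ ||x - x_hat||^2 + lam * ||f(x)||_1 ]
# Since f >= 0 after the ReLU, ||f||_1 = sum_i f_i, matching the definition above.
loss = np.mean(np.sum((x - x_hat) ** 2, axis=1) + lam * np.sum(f, axis=1))
```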
I would argue, however, that what you are actually trying to solve is the sparse coding problem:
$$\mathcal{L}(x; W_{dec}, b_{dec}) = \mathbb{E}_x\left[\min_f \lVert x - W_{dec}f - b_{dec}\rVert^2 + \lambda\lVert f\rVert_1\right],$$
where, importantly, the inner optimization is solved separately (including at runtime).
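For concreteness, the inner minimization can be solved at runtime with a standard iterative solver such as ISTA (iterative shrinkage-thresholding). The sketch below is my own illustration of that inner problem, not the post's code; a production solver (e.g. FISTA) would converge faster:

```python
import numpy as np

def ista(x, W_dec, b_dec, lam=0.01, n_steps=500):
    """Approximately solve min_f ||x - W_dec f - b_dec||^2 + lam * ||f||_1 via ISTA."""
    m = W_dec.shape[1]
    f = np.zeros(m)
    # Step size 1/L, where L is the Lipschitz constant of the quadratic term's gradient.
    L = 2.0 * np.linalg.norm(W_dec, 2) ** 2
    for _ in range(n_steps):
        r = W_dec @ f + b_dec - x                 # residual
        g = 2.0 * W_dec.T @ r                     # gradient of the quadratic term
        z = f - g / L                             # gradient step
        f = np.sign(z) * np.maximum(np.abs(z) - lam / L, 0.0)  # soft-threshold (prox of l1)
    return f

rng = np.random.default_rng(0)
d, m = 16, 64
W_dec = rng.normal(size=(d, m)) / np.sqrt(m)
b_dec = np.zeros(d)

# Plant a 3-sparse code and check that ISTA recovers a sparse approximation of it.
f_true = np.zeros(m)
f_true[[3, 17, 41]] = [1.0, -2.0, 0.5]
x = W_dec @ f_true + b_dec

f_star = ista(x, W_dec, b_dec, lam=0.01, n_steps=500)
```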
Since $f$ is a code over an overcomplete basis, finding the $f^*$ that minimizes the inner loop (a problem also known as basis pursuit denoising [1]) is notoriously challenging, and a single-layer encoder is underpowered to compute it. The SAE’s encoder therefore introduces a significant error term $\tilde{f}_{enc}$, which means that your actual loss function is:
$$\mathcal{L}(x; \Theta) = \mathbb{E}_x\left[\lVert x - W_{dec}(f^* + \tilde{f}_{enc}) - b_{dec}\rVert^2 + \lambda\lVert f^* + \tilde{f}_{enc}\rVert_1\right]$$
The magnitude of this error would have to be determined empirically, but I suspect it is large enough to be a significant source of reconstruction error.
There are a few things you could do to reduce the error:
- Ensuring that $W_{dec}$ obeys the restricted isometry property [2] (i.e., a cap on the cosine similarity of decoder weights), or, barring that, adding a term to your loss function that at least penalizes large cosine similarities.
- Adding extra layers to your encoder, so that it is better at solving for $f^*$.
- Running empirical studies to see how large the feature error is and how much reconstruction error it adds.
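As a sketch of the first suggestion, one could monitor the decoder’s maximum pairwise cosine similarity (its mutual coherence) and add a soft penalty pushing it down. The function names and penalty form below are my own illustrative choices, not anything from the post:

```python
import numpy as np

def max_offdiag_cosine(W_dec):
    """Largest pairwise |cosine similarity| among decoder columns (dictionary atoms)."""
    W = W_dec / np.linalg.norm(W_dec, axis=0, keepdims=True)  # unit-norm atoms
    G = W.T @ W                                               # Gram matrix
    np.fill_diagonal(G, 0.0)                                  # ignore self-similarity
    return np.max(np.abs(G))

def coherence_penalty(W_dec):
    """Auxiliary loss term: sum of squared off-diagonal cosine similarities.

    Driving this toward zero encourages near-orthogonal atoms, a soft proxy for
    the restricted isometry property (illustrative, not a formal RIP guarantee).
    """
    W = W_dec / np.linalg.norm(W_dec, axis=0, keepdims=True)
    G = W.T @ W
    off = G - np.eye(G.shape[0])
    return np.sum(off ** 2)

rng = np.random.default_rng(0)
W_dec = rng.normal(size=(16, 64))  # toy decoder: 16-dim activations, 64 atoms
mu = max_offdiag_cosine(W_dec)
pen = coherence_penalty(W_dec)
```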
[1] Chen, Donoho & Saunders, “Atomic Decomposition by Basis Pursuit”: https://epubs.siam.org/doi/abs/10.1137/S003614450037906X
[2] Candès, “The restricted isometry property and its implications for compressed sensing”: http://www.numdam.org/item/10.1016/j.crma.2008.03.014.pdf
Interesting! You might be interested in a post from my team on inference-time optimization.
It’s not clear to me what the right call here is though, because you want $f$ to be something the model could extract. The encoder being so simple is in some ways a feature, not a bug—I wouldn’t want it to be, e.g., a deep model, because the LLM can’t easily extract that!
Thanks for sharing that study. It looks like your team is already well-versed in this subject!
You wouldn’t want something that’s too hard to extract, but I think restricting yourself to a single encoder layer is too conservative—LLMs don’t have to be able to fully extract the information from a layer in a single step.
I’d be curious to see how much closer a two-layer encoder would get to the ITO results.
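A minimal sketch of such a two-layer encoder (the widths, ReLU nonlinearities, and initialization here are illustrative assumptions on my part, not anything from the thread):

```python
import numpy as np

rng = np.random.default_rng(0)
d, m = 16, 64  # toy dimensions: d-dim activations, m-dim overcomplete code

# Hypothetical two-layer encoder: one extra hidden ReLU layer before the
# feature layer, giving the encoder one more "step" of computation with
# which to approximate the sparse code.
W1 = rng.normal(size=(4 * d, d)) / np.sqrt(d)
W2 = rng.normal(size=(m, 4 * d)) / np.sqrt(4 * d)

def encode_two_layer(x):
    h = np.maximum(0.0, x @ W1.T)    # hidden layer
    return np.maximum(0.0, h @ W2.T)  # nonnegative feature activations

x = rng.normal(size=(8, d))
f = encode_two_layer(x)
```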