Awesome work with this! Definitely looks like a big improvement over standard SAEs for absorption. Some questions/thoughts:
In the decoder cos sim plot, it looks like there’s still some slight mixing of features in co-occurring latent groups including some slight negative cos sim, although definitely a lot better than in the standard SAE. Given the underlying features are orthogonal, I’m curious why the Matryoshka SAE doesn’t fully drive this to 0 and perfectly recover the underlying true features? Is it due to the sampling, so there’s still some chance for the SAE to learn some absorption-esque feature mixes when the SAE latent isn’t sampled? If there was no sampling and each latent group had its loss calculated each step (I know this is really inefficient in practice), would the SAE perfectly recover the true features?
It looks like Matryoshka SAEs will solve absorption as long as the parent feature in the hierarchy is learned before the child feature, but this doesn't seem to be guaranteed. If the child feature happens to fire with much higher magnitude than the parent, then I would suspect the SAE would learn the child latent first to minimize expected MSE loss, and still end up with absorption. E.g. if a parent feature fires with probability 0.3 and magnitude 2.0 (expected MSE = 0.3 * 2.0^2 = 1.2), and a child feature fires with probability 0.15 but magnitude 10.0 (expected MSE = 0.15 * 10^2 = 15.0), I would expect the SAE to learn the child feature before the parent and merge the parent representation into the child, resulting in absorption. This might never happen in real LLMs, so it's possibly not an issue, but it could be something to look out for when training Matryoshka SAEs on real LLMs.
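The expected-MSE figures in that example are just P(fires) × magnitude², i.e. the squared error a feature contributes if no latent represents it. A quick check of the arithmetic; the helper name `expected_sq_contribution` is only for illustration, and "learned first" is the heuristic from the comment, not something this code shows:

```python
# Expected squared-norm contribution of a feature: P(fires) * magnitude^2.
# This is the expected MSE incurred if no latent represents the feature, so
# (heuristically) the feature with the larger value is worth learning first.
def expected_sq_contribution(prob: float, magnitude: float) -> float:
    return prob * magnitude ** 2

parent = expected_sq_contribution(0.3, 2.0)   # 0.3 * 4
child = expected_sq_contribution(0.15, 10.0)  # 0.15 * 100
print(f"parent: {parent:.2f}, child: {child:.2f}")  # parent: 1.20, child: 15.00
```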
Even with all possible prefixes included in every batch, the toy model learns the same small mixing between parent and children (this was the best of 2 runs; in the first run, the Matryoshka SAE didn't represent one of the features): https://sparselatents.com/matryoshka_toy_all_prefixes.png
Here’s a hypothesis that could explain most of this mixing. If the hypothesis is true, then even if every possible prefix is included in every batch, there will still be mixing.
Hypothesis:
Regardless of the number of prefixes, there will be some prefix loss terms where:
1. a parent and child feature are both active,
2. the parent latent is included in the prefix, and
3. the child latent isn't included in the prefix.
The MSE loss in these prefix loss terms is pretty large because the child feature isn't represented at all. This nudges the parent decoder vector to slightly represent all of its children.
To compensate for this, when a child feature is active and the child latent is included in the prefix, the child latent undoes the parent decoder vector's contribution to the features of the parent's other children.
This could explain these weird properties of the heatmap:
- Parent decoder vector has small positive cosine similarity with child features
- Child decoder vectors have small negative cosine similarity with other child features
Still unexplained by this hypothesis:
- Child decoder vectors have very small negative cosine similarity with the parent feature.
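A small numeric sketch of the compensation step in this hypothesis, under simplifying assumptions that aren't in the original: 3d one-hot stand-ins for the orthogonal features, unnormalized decoder vectors, all latent activations equal to 1, and a hypothetical tilt `a = 0.1` of the parent decoder toward each child. It reproduces the two explained signs, and shows why the third property isn't predicted:

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def cos_sim(u, v):
    return dot(u, v) / (math.sqrt(dot(u, u)) * math.sqrt(dot(v, v)))

# Orthonormal true features (a 3d stand-in for the 20d setup).
f0, f1, f2 = [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]

# Hypothetical tilt: the parent decoder picks up a small amount `a` of each child.
a = 0.1
d0 = [p + a * (c1 + c2) for p, c1, c2 in zip(f0, f1, f2)]

# If the parent and child-1 latents both fire with activation 1 on x = f0 + f1,
# exact reconstruction needs d0 + d1 = x, i.e. d1 = x - d0 = (1 - a) f1 - a f2.
x = [p + c for p, c in zip(f0, f1)]
d1 = [xi - di for xi, di in zip(x, d0)]

print(f"cos(d0, f1) = {cos_sim(d0, f1):+.3f}")  # positive: parent contains child
print(f"cos(d1, f2) = {cos_sim(d1, f2):+.3f}")  # negative: child cancels sibling
print(f"cos(d1, f0) = {cos_sim(d1, f0):+.3f}")  # zero: this step alone gives no parent component
```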
I tried digging into this some more and think I have an idea what's going on. As I understand it, the base assumption for why Matryoshka SAEs should solve absorption is that a narrow SAE should perfectly reconstruct parent features in a hierarchy, so absorption patterns can't arise between child and parent features. However, it seems like this assumption is not correct: narrow SAEs still learn messed-up latents when parent and child features in a hierarchy co-occur, and this messes up what the Matryoshka SAE learns.
Apologies for the long comment; this might make more sense as its own post. I'm curious to get others' thoughts on this. It's also possible I'm doing something dumb.
The problem: Matryoshka latents don’t perfectly match true features
In the post, the Matryoshka latents seem to have the following problematic properties:
- The latent tracking a parent feature contains components of child features
- The latents tracking child features have negative components of the other child features
The setup: simplified hierarchical features
I tried to investigate this using a simpler version of the setup in this post, focusing on a single parent/child relationship between features. This is like zooming in on a single set of parent/child features. Our setup has 3 true features in a hierarchy, with feature 0 as the parent and features 1 and 2 as its children, as below:
These features have higher firing probabilities than in the original post's setup, to make the highlighted trends more obvious. All features fire with magnitude 1.0 and have a 20d representation with no superposition (all features are mutually orthogonal).
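A minimal pure-Python sketch of how data for this setup could be generated. The actual firing probabilities aren't stated in the text, so `P_PARENT` and `P_CHILD_GIVEN_PARENT` are hypothetical placeholders, and letting the two children fire independently given the parent is also an assumption of this sketch:

```python
import math
import random

D = 20  # representation dimension from the setup
# HYPOTHETICAL firing probabilities (not taken from the original experiment).
P_PARENT = 0.3
P_CHILD_GIVEN_PARENT = 0.4

def orthonormal_features(n, dim, rng):
    """Gram-Schmidt on random Gaussian vectors -> n mutually orthogonal unit vectors."""
    feats = []
    while len(feats) < n:
        v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        for f in feats:
            proj = sum(a * b for a, b in zip(v, f))
            v = [a - proj * b for a, b in zip(v, f)]
        norm = math.sqrt(sum(a * a for a in v))
        if norm > 1e-8:
            feats.append([a / norm for a in v])
    return feats

def sample_activations(rng):
    """Feature 0 is the parent; features 1 and 2 can only fire if it fires."""
    parent = rng.random() < P_PARENT
    acts = [1.0 if parent else 0.0]
    for _ in range(2):
        acts.append(1.0 if parent and rng.random() < P_CHILD_GIVEN_PARENT else 0.0)
    return acts  # every feature fires with magnitude 1.0

def embed(acts, feats):
    """Input vector = sum of active feature directions (no superposition)."""
    return [sum(a * f[i] for a, f in zip(acts, feats)) for i in range(len(feats[0]))]

rng = random.Random(0)
features = orthonormal_features(3, D, rng)
batch = [sample_activations(rng) for _ in range(1000)]
inputs = [embed(acts, features) for acts in batch]
```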
Simplified Matryoshka SAE
I used a simpler Matryoshka SAE that doesn't use feature sampling or reshuffling of latents and doesn't take the log of losses. Since we already know the hierarchy of the underlying features in this setup, I used a Matryoshka SAE with an inner SAE width of 1 latent to track the single parent feature, and an outer SAE width of 3 to match the number of true features. So the Matryoshka SAE sizes are as below:
size 0: latents [0]
size 1: latents [0, 1, 2]
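A minimal sketch of the simplified Matryoshka reconstruction loss described here, assuming a ReLU encoder and a plain sum of per-prefix squared errors. The sparsity penalty and exact normalization aren't described in the text, so they're left out; `matryoshka_recon_loss` is a hypothetical helper, and `prefix_sizes=(1, 3)` corresponds to size 0 (latent 0) and size 1 (latents 0, 1, 2):

```python
def relu(v):
    return max(0.0, v)

def matryoshka_recon_loss(x, W_enc, b_enc, W_dec, b_dec, prefix_sizes=(1, 3)):
    """Sum of squared-error reconstruction losses, one per nested prefix of latents.

    No prefix sampling, no latent reshuffling, no log of losses. Any sparsity
    penalty is omitted (an assumption of this sketch).
    """
    # Encoder: one ReLU activation per latent (rows of W_enc are latents).
    acts = [relu(sum(w * xi for w, xi in zip(row, x)) + b) for row, b in zip(W_enc, b_enc)]
    total = 0.0
    for k in prefix_sizes:
        # Reconstruct using only latents [0, k).
        recon = list(b_dec)
        for j in range(k):
            recon = [r + acts[j] * d for r, d in zip(recon, W_dec[j])]
        total += sum((xi - ri) ** 2 for xi, ri in zip(x, recon))
    return total

# Usage: "perfect" identity SAE in a 3d stand-in, input = parent + child 1.
I3 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
x = [1.0, 1.0, 0.0]
print(matryoshka_recon_loss(x, I3, [0.0] * 3, I3, [0.0] * 3))  # 1.0
```

Even the ideal SAE pays the full child error in the size-0 term: the parent latent is in that prefix and the child latent isn't, which is the kind of prefix loss term from the hypothesis earlier in the thread.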
The cosine similarities between the encoder and decoder of the Matryoshka SAE and the true features are shown below:
The Matryoshka decoder matches what we saw in the original post: the latent tracking the parent feature has positive cosine sim with the child features, and the latents tracking the child features have negative cosine sim with the other child feature. Our Matryoshka inner SAE, consisting of just latent 0, does track the parent feature as we expected though! What's going on here? How is it possible for the inner Matryoshka SAE to represent a merged version of the parent and child features?
Narrow SAEs do not correctly reconstruct parent features
The core idea behind Matryoshka SAEs is that a narrower SAE should learn a clean representation of parent features despite their co-occurrence with child features. Once we have a clean representation of a parent feature in a hierarchy, adding child latents to the SAE should not allow any absorption.
Surprisingly, this assumption is incorrect: narrow SAEs merge child representations into the parent latent.
I tried training a standard SAE with a single latent on our toy example, expecting that the 1-latent SAE would learn only the parent feature without any signs of absorption. Below is the plot of the cosine similarities between the SAE encoder and decoder with the true features.
This single-latent SAE learns a representation that merges the representation of the child features into the parent latent, exactly as we saw in our Matryoshka SAE and in the original post's results! Our narrow SAE is not learning the correct representation of feature 0 as we would hope. Instead, it's learning feature 0 + weaker representations of child features 1 and 2.
Why does this happen?
Likely this reduces MSE loss compared with learning the correct representation of feature 0 on its own. When there are fewer latents than features, the SAE always has to accept some MSE error. Since the child features can only fire when the parent fires, inputs that activate the parent latent often contain a child feature too, and merging some of each child feature into the parent latent partially reconstructs those inputs at the cost of a smaller error on parent-only inputs.
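A minimal sketch of why the merged latent can win, under idealized assumptions not in the original: the single latent fires with activation exactly 1 whenever the parent fires, there's no sparsity penalty, features are 3d one-hot stand-ins, and each child fires independently with a hypothetical `P_CHILD = 0.4` given the parent. With the activation fixed, the squared-error-minimizing decoder vector is the mean of the inputs the latent fires on, which contains a piece of each child. This ignores the encoder's freedom to vary the activation, so it only illustrates the direction of the effect:

```python
import itertools
import math

P_CHILD = 0.4  # HYPOTHETICAL P(child fires | parent fires), children independent

# One-hot stand-ins for the orthogonal features f0 (parent), f1, f2 (children).
f0, f1, f2 = [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]

def combine(*terms):
    """Sum of (coefficient, vector) pairs."""
    return [sum(c * v[i] for c, v in terms) for i in range(3)]

# Enumerate every input that can occur when the parent fires, with its probability.
cases = []
for c1, c2 in itertools.product([0, 1], repeat=2):
    prob = (P_CHILD if c1 else 1 - P_CHILD) * (P_CHILD if c2 else 1 - P_CHILD)
    cases.append((prob, combine((1, f0), (c1, f1), (c2, f2))))

def expected_sq_error(d):
    """E[||x - d||^2 | parent fires], with the latent's activation fixed at 1."""
    return sum(p * sum((xi - di) ** 2 for xi, di in zip(x, d)) for p, x in cases)

clean = f0                                               # the "correct" parent latent
merged = combine((1, f0), (P_CHILD, f1), (P_CHILD, f2))  # E[x | parent fires]

print(f"clean:  {expected_sq_error(clean):.2f}")   # clean:  0.80
print(f"merged: {expected_sq_error(merged):.2f}")  # merged: 0.48
# f1 is a unit one-hot vector, so cos(merged, f1) = merged[1] / ||merged||.
print(f"cos(merged, f1) = {merged[1] / math.sqrt(sum(v * v for v in merged)):.3f}")
```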
What does this mean for Matryoshka SAEs?
This issue should affect any Matryoshka SAE, since the base assumption underlying Matryoshka SAEs is that a narrow SAE will correctly represent general parent features without any issues due to co-occurrence with specific child features. Since that assumption is not correct, we should not expect a Matryoshka SAE to completely fix absorption issues. I would expect that the topk SAEs from https://www.lesswrong.com/posts/rKM9b6B2LqwSB5ToN/learning-multi-level-features-with-matryoshka-saes would also suffer from this problem, although I didn't test that in this toy setting since topk SAEs are trickier to evaluate in toy settings (it's not obvious what K to pick).
It's possible the issues shown in this toy setting are more extreme than in a real LLM, since the firing probabilities of these features may be higher than those of many features in a real LLM. That said, it's hard to say anything concrete about the firing probabilities of features in real LLMs since we have no ground truth data on true LLM features.
I did this investigation in the following colab: https://colab.research.google.com/drive/1sG64FMQQcRBCNGNzRMcyDyP4M-Sv-nQA?usp=sharing