A Bunch of Matryoshka SAEs
This work was done as part of MATS 7.0.
MATS provides a generous compute stipend, and towards the end of the program we found we had some unspent compute. Rather than let it go to waste, we trained batch top-k Matryoshka SAEs on all residual stream layers of Gemma-2-2b, Gemma-2-9b, and Gemma-3-1b, and are now releasing them publicly. The hyperparameters for these SAEs were not aggressively optimized, but they should hopefully be decent. Below we describe how these SAEs were trained and why, along with stats for each SAE. Key decisions:
We use narrower inner widths than in the original Matryoshka SAEs work, and increase each successive width by a larger amount. We do this to make it easier to study the highest-frequency features of the model.
We include standard and snap loss variants for Gemma-2-2b. Snap loss and the rationale behind it are described in our Feature Hedging post. There is probably not much practical difference between the snap and standard versions of the SAEs.
We do not stop gradients between Matryoshka levels. We find in toy models that hedging and absorption pull the encoder in opposite directions, so letting gradients flow between levels helps moderate the severity of feature hedging in the inner levels of Matryoshka SAEs.
I don’t care about any of that, just give me the SAEs!
You can load all of the SAEs using SAELens via the following releases:
Gemma-2-2b snap loss matryoshka: gemma-2-2b-res-snap-matryoshka-dc
Gemma-2-2b standard matryoshka: gemma-2-2b-res-matryoshka-dc
Gemma-2-9b standard matryoshka: gemma-2-9b-res-matryoshka-dc
Gemma-3-1b standard matryoshka: gemma-3-1b-res-matryoshka-dc
For each release, the SAE ID is just the corresponding TransformerLens post-residual-stream hook point, e.g. blocks.5.hook_resid_post for the layer 5 residual stream SAE.
Each SAE can be loaded in SAELens as follows:
from sae_lens import SAE
sae = SAE.from_pretrained("<release>", "<sae_id>")[0]
For instance, to load the layer 5 snap variant SAE for gemma-2-2b, this would look like the following:
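from sae_lens import SAE
sae = SAE.from_pretrained("gemma-2-2b-res-snap-matryoshka-dc", "blocks.5.hook_resid_post")[0]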
Matryoshka SAEs should be much better than standard SAEs at finding general, high-frequency concepts like parts of speech. In standard SAEs, latents tracking these concepts get shot to pieces by feature absorption, since they co-occur with so many other concepts. As Matryoshka SAEs should be much more resilient to absorption, we expect to find more meaningful high-density latents in them (although these latents may be messed up by feature hedging instead). For instance, here's a high-density latent from the first Matryoshka level of the layer 12 SAE for Gemma-2-2b, which appears to (very noisily) perform a grammatical function similar to the Penn Treebank IN (preposition or subordinating conjunction) part-of-speech tag.
Higher-frequency concepts should be concentrated in earlier latent indices: the highest-frequency concepts should appear in latents 0-127, the next highest in latents 128-511, and so on.
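As a minimal sketch of what this looks like in practice (the layer choice and the random input tensor below are just placeholders for illustration), you can slice the feature activations by Matryoshka level:

import torch
from sae_lens import SAE

# Load the layer 12 standard Matryoshka SAE for Gemma-2-2b
sae = SAE.from_pretrained("gemma-2-2b-res-matryoshka-dc", "blocks.12.hook_resid_post")[0]

# Stand-in for a real residual stream activation of shape (batch, seq, d_model)
resid = torch.randn(1, 8, sae.cfg.d_in)

feature_acts = sae.encode(resid)           # shape: (1, 8, 32768)
first_level = feature_acts[..., :128]      # highest-frequency concepts
second_level = feature_acts[..., 128:512]  # next-highest-frequency concepts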
Training info
All Matryoshka SAEs in this release are trained on 750M tokens from the Pile using a modified version of SAELens. The SAEs are all 32k wide, with the following Matryoshka levels: 128, 512, 2048, 8192, and 32768. We included two levels (128 and 512) that are much narrower than the model residual stream to make it easier to study the first features the SAE learns. These are all batch top-k SAEs, following the original Matryoshka SAEs work. We largely did not optimize hyperparameters for these SAEs, so it's likely possible to squeeze out more performance with better choices of learning rate and more training tokens, but hopefully these SAEs are decent.
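To make the nested-prefix setup concrete, here is a toy sketch of how a Matryoshka reconstruction loss over these levels could be computed. This is our illustration of the idea, not the actual (modified SAELens) training code, and the names are ours:

import torch

LEVELS = [128, 512, 2048, 8192, 32768]

def matryoshka_recon_loss(feature_acts, W_dec, b_dec, x):
    # feature_acts: (batch, 32768) activations after batch top-k sparsification
    # W_dec: (32768, d_model) decoder weights, b_dec: (d_model,) decoder bias
    # x: (batch, d_model) residual stream activations being reconstructed
    loss = 0.0
    for width in LEVELS:
        # Each level reconstructs x using only the first `width` latents
        recon = feature_acts[:, :width] @ W_dec[:width] + b_dec
        loss = loss + (recon - x).pow(2).sum(dim=-1).mean()
    return loss / len(LEVELS)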
Snap loss
One of the notable components of this release is the addition of snap loss variants of all SAEs for Gemma-2-2b. Snap loss is described in our post on Feature Hedging, and involves switching the reconstruction loss of the SAE from MSE to L2 midway through training. In practice, we don't see much difference in SAEs trained on LLMs using snap loss, but we are releasing these variants regardless in case others are curious to investigate the effect of snap loss, since we have the SAEs trained anyway. If you notice a meaningful difference in practice between the snap loss and standard variants of these SAEs, please let us know!
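As a rough sketch of what that switch could look like (our illustration of the idea, not the exact loss definitions from the Feature Hedging post):

import torch

def mse_loss(recon, x):
    # Mean squared error over all elements
    return (recon - x).pow(2).mean()

def l2_loss(recon, x):
    # L2 norm of the reconstruction error per example, averaged over the batch
    return (recon - x).norm(dim=-1).mean()

def snap_recon_loss(recon, x, step, snap_step):
    # Train with MSE early on, then "snap" to L2 partway through training
    return mse_loss(recon, x) if step < snap_step else l2_loss(recon, x)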
Balancing absorption and hedging
Intuitively, it might seem like we'd want the inner levels of Matryoshka SAEs to be insulated from gradients from outer levels: outer levels pull the inner latents towards absorption, which defeats the purpose of a Matryoshka SAE! However, in toy models, hedging and absorption have opposite effects on the SAE encoder, so allowing some absorption pressure can help counteract hedging and improve performance. We notice that the dictionary_learning implementation of Matryoshka SAEs also does not stop gradients between levels, likely because stopping gradients causes hedging to mess up the SAE more severely.
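To illustrate the design choice we decided against, here is one way a stop-gradient variant of the toy Matryoshka loss sketch above could look (again our own illustration, not the dictionary_learning or SAELens implementation):

import torch

def matryoshka_recon_loss_stopgrad(feature_acts, W_dec, b_dec, x, levels=(128, 512, 2048, 8192, 32768)):
    loss = 0.0
    prev = 0
    for width in levels:
        # Only latents added at this level (prev:width) receive gradients from this level's loss;
        # earlier latents are detached, so outer levels cannot pull them towards absorption.
        acts = torch.cat([feature_acts[:, :prev].detach(), feature_acts[:, prev:width]], dim=1)
        recon = acts @ W_dec[:width] + b_dec
        loss = loss + (recon - x).pow(2).sum(dim=-1).mean()
        prev = width
    return loss / len(levels)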
For a further investigation of balancing hedging and absorption in Matryoshka SAEs, check out this colab.
We suspect that it may be possible to intentionally balance hedging with absorption in a more optimal way, and we plan to investigate this in future work.
SAEs and stats
Below we list all the SAEs trained along with some core stats.
Gemma-2-2b
We trained both snap and standard variants of SAEs for Gemma-2-2b. These SAEs have the release IDs gemma-2-2b-res-snap-matryoshka-dc for the snap-loss variant and gemma-2-2b-res-matryoshka-dc for the standard variant.
Snap Matryoshka SAEs
| layer | SAE ID | width | L0 | explained variance |
|---|---|---|---|---|
| 0 | blocks.0.hook_resid_post | 32768 | 40 | 0.919964 |
| 1 | blocks.1.hook_resid_post | 32768 | 40 | 0.863969 |
| 2 | blocks.2.hook_resid_post | 32768 | 40 | 0.858767 |
| 3 | blocks.3.hook_resid_post | 32768 | 40 | 0.815844 |
| 4 | blocks.4.hook_resid_post | 32768 | 40 | 0.821094 |
| 5 | blocks.5.hook_resid_post | 32768 | 40 | 0.797083 |
| 6 | blocks.6.hook_resid_post | 32768 | 40 | 0.79815 |
| 7 | blocks.7.hook_resid_post | 32768 | 40 | 0.78946 |
| 8 | blocks.8.hook_resid_post | 32768 | 40 | 0.779236 |
| 9 | blocks.9.hook_resid_post | 32768 | 40 | 0.759022 |
| 10 | blocks.10.hook_resid_post | 32768 | 40 | 0.743998 |
| 11 | blocks.11.hook_resid_post | 32768 | 40 | 0.731758 |
| 12 | blocks.12.hook_resid_post | 32768 | 40 | 0.725974 |
| 13 | blocks.13.hook_resid_post | 32768 | 40 | 0.727936 |
| 14 | blocks.14.hook_resid_post | 32768 | 40 | 0.727065 |
| 15 | blocks.15.hook_resid_post | 32768 | 40 | 0.757408 |
| 16 | blocks.16.hook_resid_post | 32768 | 40 | 0.751874 |
| 17 | blocks.17.hook_resid_post | 32768 | 40 | 0.763654 |
| 18 | blocks.18.hook_resid_post | 32768 | 40 | 0.77644 |
| 19 | blocks.19.hook_resid_post | 32768 | 40 | 0.768622 |
| 20 | blocks.20.hook_resid_post | 32768 | 40 | 0.761658 |
| 21 | blocks.21.hook_resid_post | 32768 | 40 | 0.765593 |
| 22 | blocks.22.hook_resid_post | 32768 | 40 | 0.741098 |
| 23 | blocks.23.hook_resid_post | 32768 | 40 | 0.729718 |
| 24 | blocks.24.hook_resid_post | 32768 | 40 | 0.754838 |
Standard Matryoshka SAEs
| layer | SAE ID | width | L0 | explained variance |
|---|---|---|---|---|
| 0 | blocks.0.hook_resid_post | 32768 | 40 | 0.91832 |
| 1 | blocks.1.hook_resid_post | 32768 | 40 | 0.863454 |
| 2 | blocks.2.hook_resid_post | 32768 | 40 | 0.841324 |
| 3 | blocks.3.hook_resid_post | 32768 | 40 | 0.814794 |
| 4 | blocks.4.hook_resid_post | 32768 | 40 | 0.820418 |
| 5 | blocks.5.hook_resid_post | 32768 | 40 | 0.796252 |
| 6 | blocks.6.hook_resid_post | 32768 | 40 | 0.797322 |
| 7 | blocks.7.hook_resid_post | 32768 | 40 | 0.787601 |
| 8 | blocks.8.hook_resid_post | 32768 | 40 | 0.779433 |
| 9 | blocks.9.hook_resid_post | 32768 | 40 | 0.75697 |
| 10 | blocks.10.hook_resid_post | 32768 | 40 | 0.745011 |
| 11 | blocks.11.hook_resid_post | 32768 | 40 | 0.732177 |
| 12 | blocks.12.hook_resid_post | 32768 | 40 | 0.726209 |
| 13 | blocks.13.hook_resid_post | 32768 | 40 | 0.719405 |
| 14 | blocks.14.hook_resid_post | 32768 | 40 | 0.719056 |
| 15 | blocks.15.hook_resid_post | 32768 | 40 | 0.756888 |
| 16 | blocks.16.hook_resid_post | 32768 | 40 | 0.742889 |
| 17 | blocks.17.hook_resid_post | 32768 | 40 | 0.757294 |
| 18 | blocks.18.hook_resid_post | 32768 | 40 | 0.76921 |
| 19 | blocks.19.hook_resid_post | 32768 | 40 | 0.766661 |
| 20 | blocks.20.hook_resid_post | 32768 | 40 | 0.760939 |
| 21 | blocks.21.hook_resid_post | 32768 | 40 | 0.759883 |
| 22 | blocks.22.hook_resid_post | 32768 | 40 | 0.740612 |
| 23 | blocks.23.hook_resid_post | 32768 | 40 | 0.729678 |
| 24 | blocks.24.hook_resid_post | 32768 | 40 | 0.747313 |
Gemma-2-9b
These SAEs have the release ID gemma-2-9b-res-matryoshka-dc.
| layer | SAE ID | width | L0 | explained variance |
|---|---|---|---|---|
| 0 | blocks.0.hook_resid_post | 32768 | 60 | 0.942129 |
| 1 | blocks.1.hook_resid_post | 32768 | 60 | 0.900656 |
| 2 | blocks.2.hook_resid_post | 32768 | 60 | 0.869154 |
| 3 | blocks.3.hook_resid_post | 32768 | 60 | 0.84077 |
| 4 | blocks.4.hook_resid_post | 32768 | 60 | 0.816605 |
| 5 | blocks.5.hook_resid_post | 32768 | 60 | 0.826656 |
| 6 | blocks.6.hook_resid_post | 32768 | 60 | 0.798281 |
| 7 | blocks.7.hook_resid_post | 32768 | 60 | 0.796018 |
| 8 | blocks.8.hook_resid_post | 32768 | 60 | 0.790385 |
| 9 | blocks.9.hook_resid_post | 32768 | 60 | 0.775052 |
| 10 | blocks.10.hook_resid_post | 32768 | 60 | 0.756327 |
| 11 | blocks.11.hook_resid_post | 32768 | 60 | 0.741264 |
| 12 | blocks.12.hook_resid_post | 32768 | 60 | 0.718319 |
| 13 | blocks.13.hook_resid_post | 32768 | 60 | 0.714065 |
| 14 | blocks.14.hook_resid_post | 32768 | 60 | 0.709635 |
| 15 | blocks.15.hook_resid_post | 32768 | 60 | 0.706622 |
| 16 | blocks.16.hook_resid_post | 32768 | 60 | 0.687879 |
| 17 | blocks.17.hook_resid_post | 32768 | 60 | 0.695821 |
| 18 | blocks.18.hook_resid_post | 32768 | 60 | 0.691723 |
| 19 | blocks.19.hook_resid_post | 32768 | 60 | 0.690914 |
| 20 | blocks.20.hook_resid_post | 32768 | 60 | 0.684599 |
| 21 | blocks.21.hook_resid_post | 32768 | 60 | 0.691355 |
| 22 | blocks.22.hook_resid_post | 32768 | 60 | 0.705531 |
| 23 | blocks.23.hook_resid_post | 32768 | 60 | 0.702293 |
| 24 | blocks.24.hook_resid_post | 32768 | 60 | 0.707655 |
| 25 | blocks.25.hook_resid_post | 32768 | 60 | 0.721022 |
| 26 | blocks.26.hook_resid_post | 32768 | 60 | 0.721717 |
| 27 | blocks.27.hook_resid_post | 32768 | 60 | 0.745809 |
| 28 | blocks.28.hook_resid_post | 32768 | 60 | 0.753267 |
| 29 | blocks.29.hook_resid_post | 32768 | 60 | 0.76466 |
| 30 | blocks.30.hook_resid_post | 32768 | 60 | 0.763025 |
| 31 | blocks.31.hook_resid_post | 32768 | 60 | 0.765932 |
| 32 | blocks.32.hook_resid_post | 32768 | 60 | 0.760822 |
| 33 | blocks.33.hook_resid_post | 32768 | 60 | 0.73323 |
| 34 | blocks.34.hook_resid_post | 32768 | 60 | 0.746912 |
| 35 | blocks.35.hook_resid_post | 32768 | 60 | 0.738031 |
| 36 | blocks.36.hook_resid_post | 32768 | 60 | 0.730805 |
| 37 | blocks.37.hook_resid_post | 32768 | 60 | 0.722875 |
| 38 | blocks.38.hook_resid_post | 32768 | 60 | 0.715494 |
| 39 | blocks.39.hook_resid_post | 32768 | 60 | 0.7044 |
| 40 | blocks.40.hook_resid_post | 32768 | 60 | 0.711277 |
Gemma-3-1b
These SAEs have the release ID gemma-3-1b-res-matryoshka-dc.
Neuronpedia
Neuronpedia has generously hosted some of these SAEs, with more coming soon. Check them out at: https://www.neuronpedia.org/res-matryoshka-dc.