I tried replicating your statistics using my own evaluation code (in `evaluation.py` here). I pseudo-randomly chose layer 1 and layer 7. Sadly, my results look rather different from yours:
| Layer | MSE Loss | % Variance Explained | L1 | L0 | % Alive | CE Reconstructed |
|-------|----------|----------------------|-----|------|---------|------------------|
| 1 | 0.11 | 92 | 44 | 17.5 | 54 | 5.95 |
| 7 | 1.1 | 82 | 137 | 65.4 | 95 | 4.29 |
Places where our metrics agree: L1 and L0.
Places where our metrics disagree, but probably for a relatively benign reason:
Percent variance explained: my numbers are slightly lower than yours, and from a brief skim of your code I think it’s because you’re calculating variance slightly incorrectly: you’re not subtracting off the activation’s mean before doing `.pow(2).sum(-1)`. This will slightly overestimate the variance of the original activations, so probably also overestimate percent variance explained.
Percent alive: my numbers are slightly lower than yours, and this is probably because I determined whether neurons are alive on a (somewhat small) batch of 8192 tokens. So my number is probably an underestimate and yours is correct.
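The centring point above can be made concrete with a small sketch (PyTorch, with made-up shapes and function names, not taken from either of our codebases):

```python
import torch

def frac_variance_explained(x, x_hat):
    """Fraction of the variance of x explained by the reconstruction x_hat.

    The total variance must be computed about the mean of x; skipping the
    centring step inflates the denominator (second moment instead of
    variance) and so overestimates percent variance explained.
    """
    residual_var = (x - x_hat).pow(2).sum(-1)    # per-token residual
    total_var = (x - x.mean(0)).pow(2).sum(-1)   # centred, not raw x.pow(2)
    return 1 - residual_var.sum() / total_var.sum()

# Sanity check: a mean-only "reconstruction" explains 0% of the variance.
x = torch.randn(1024, 16) + 5.0                  # activations with a large mean
x_hat = x.mean(0).expand_as(x)
print(float(frac_variance_explained(x, x_hat)))  # → 0.0
```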
Our metrics disagree strongly on CE reconstructed, and this is a bit alarming. It means that either you have a bug which significantly underestimates reconstructed CE loss, or I have a bug which significantly overestimates it. I think I’m 50/50 on which it is. Note that according to my stats, your MSE loss is kinda bad, which would suggest that you should also have high CE reconstructed (especially for residual stream dictionaries, in contrast to e.g. MLP dictionaries, which are much more forgiving).
Spitballing a possible cause: when computing CE loss, did you exclude padding tokens? If not, then it’s possible that many of the tokens on which you’re computing CE are padding tokens, which is artificially making your CE look extremely good.
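To illustrate the failure mode: if padding positions are included, the CE average gets diluted by tokens the model can predict trivially (or that are meaningless). A minimal sketch of the fix using `ignore_index` (the `PAD_ID`, shapes, and vocab here are all made up):

```python
import torch
import torch.nn.functional as F

PAD_ID = 0
torch.manual_seed(0)
logits = torch.randn(10, 50)             # (n_tokens, vocab_size)
targets = torch.randint(1, 50, (10,))
targets[6:] = PAD_ID                     # pretend the last 4 positions are padding

# Naive CE averages over the padding positions too...
ce_naive = F.cross_entropy(logits, targets)
# ...while ignore_index drops them from both the sum and the denominator.
ce_masked = F.cross_entropy(logits, targets, ignore_index=PAD_ID)
```

If the model has learned to predict padding confidently, the naive version can look much better than the masked one.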
Here is my code. You’ll need to `pip install nnsight` before running it. Many thanks to Caden Juang for implementing the UnifiedTransformer functionality in nnsight, which is a crazy Frankenstein marriage of nnsight and transformer_lens; it would have been very hard for me to attempt this replication without this feature.
Oh no. I’ll look into this and get back to you shortly. One obvious candidate is that I was reporting CE for some batch at the end of training that was very small and so the statistics likely had high variance and the last datapoint may have been fairly low. In retrospect I should have explicitly recalculated this again post training. However, I’ll take a deeper dive now to see what’s up.
I’ve run some of the SAEs through more thorough eval code this morning (getting variance explained with the centring and calculating mean CE losses with more batches). As far as I can tell the CE loss is not that high at all and the MSE loss is quite low. I’m wondering whether you might be using the wrong hooks? These are `resid_pre`, so layer 0 is just the embeddings, layer 1 is after the first transformer block, and so on. One other possibility is that you are using a different dataset? I trained these SAEs on OpenWebText. I don’t much padding at all, that might be a big difference too. I’m curious to get to the bottom of this.
One sanity check I’ve done is just sampling from the model when using the SAE to reconstruct activations and it seems to be about as good, which I think rules out CE loss in the ranges you quote above.
For percent alive neurons, a batch of 8192 tokens would be far too few to estimate dead neurons (since many neurons have a feature sparsity < 10**-3).
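To see why: at a firing probability of 1e-4 per token, a neuron has roughly a 44% chance of looking dead in a single batch of 8192 tokens, but is essentially guaranteed to fire somewhere in ~10^6 tokens. A sketch (hypothetical helper, not either codebase's eval):

```python
import torch

def frac_alive(batches, dict_size, thresh=0.0):
    """A neuron counts as alive if it fires (> thresh) on any token seen."""
    fired = torch.zeros(dict_size, dtype=torch.bool)
    for acts in batches:                  # each acts: (n_tokens, dict_size)
        fired |= (acts > thresh).any(dim=0)
    return fired.float().mean().item()

torch.manual_seed(0)
# Simulate 4 neurons that each fire with probability 1e-4 per token.
batches = [(torch.rand(8192, 4) < 1e-4).float() for _ in range(128)]
print(frac_alive(batches[:1], dict_size=4))  # one small batch: may miss neurons
print(frac_alive(batches, dict_size=4))      # ~10^6 tokens: finds them all
```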
You’re absolutely right about missing the centring in percent variance explained. I’ve estimated variance explained again for the same layers and get very similar results to what I had originally. I’ll make some updates to my code to produce CE score metrics that have less variance in the future, at the cost of slightly more train time.
If we don’t find a simple answer I’m happy to run some more experiments, but I’d guess an 80% probability that there’s a simple bug which would explain the difference in what you get. Rank order of most likely: using the wrong activations, using datapoints with lots of padding, using a different dataset (I tried the Pile and it wasn’t that bad either).
In the notebook I link in my original comment, I check that the activations I get out of nnsight are the same as the activations that come from transformer_lens. Together with the fact that our sparsity statistics broadly align, I’m guessing that the issue isn’t that I’m extracting different activations than you are.
Repeating my replication attempt with data from OpenWebText, I get this:
| Layer | MSE Loss | % Variance Explained | L1 | L0 | % Alive | CE Reconstructed |
|-------|----------|----------------------|-----|------|---------|------------------|
| 1 | 0.069 | 95 | 40 | 15 | 46 | 6.45 |
| 7 | 0.81 | 86 | 125 | 59.2 | 96 | 4.38 |
Broadly speaking, same story as above, except that the MSE losses look better (still not great), and that the CE reconstructed looks very bad for layer 1.
> I don’t much padding at all, that might be a big difference too.
Seems like there was a typo here—what do you mean?
Logan Riggs reports that he tried to replicate your results and got something more similar to you. I think Logan is making decisions about padding and tokenization more like the decisions you make, so it’s possible that the difference is down to something around padding and tokenization.
Possible next steps:
Can you report your MSE Losses (instead of just variance explained)?
Can you try to evaluate the residual stream dictionaries in the 5_32768 set released here? If you get CE reconstructed much better than mine, then it means that we’re computing CE reconstructed in different ways, where your way consistently reports better numbers. If you get CE reconstructed much worse than mine, then it might mean that there’s a translation error between our codebases (e.g. using different activations).
Another sanity check: when you compute CE loss with the same code you use for autoencoder-reconstructed activations, but plug the correct activations back in instead of the autoencoder’s output, do you get the same answer (~3.3) as when you evaluate CE loss normally?
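This check can be sketched with a toy stand-in model (none of the components below are GPT2-small or either of our codebases): patching the unmodified activations back into the forward pass must reproduce the ordinary CE loss exactly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Stand-in "language model": embed -> block -> unembed, with a patch point
# on the residual stream between block and unembed.
d_model, vocab, n_tokens = 16, 50, 64
emb = nn.Embedding(vocab, d_model)
block1 = nn.Linear(d_model, d_model)
unembed = nn.Linear(d_model, vocab)
toks = torch.randint(0, vocab, (n_tokens,))
targets = torch.randint(0, vocab, (n_tokens,))

def ce_with_patch(patch_fn):
    resid = block1(emb(toks))            # activations at the patch point
    logits = unembed(patch_fn(resid))    # patch_fn: autoencoder or identity
    return F.cross_entropy(logits, targets)

ce_clean = ce_with_patch(lambda x: x)    # identity patch == normal CE loss
ce_lossy = ce_with_patch(lambda x: x + 0.5 * torch.randn_like(x))  # fake "reconstruction"
```

If the identity patch does not reproduce the normal CE loss, the substitution machinery itself is buggy.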
MSE Losses were in the WandB report (screenshot below).
I’ve loaded in your weights for one SAE, and at first I got very bad performance (high L0, high L1, and bad MSE loss).
It turns out that this is because my forward pass uses a tied decoder bias which is subtracted from the initial activations and added as part of the decoder forward pass. AFAICT, you don’t do this.
To verify this, I added the decoder bias to the activations of your SAE prior to running a forward pass with my code (to effectively remove the decoder bias subtraction from my method) and got reasonable results.
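For reference, a sketch of that tied-bias forward pass (dimensions made up; this mirrors the Towards Monosemanticity description rather than either of our exact implementations):

```python
import torch
import torch.nn as nn

class TiedBiasSAE(nn.Module):
    """Tied decoder bias: b_dec is subtracted from the input before
    encoding and added back by the decoder."""
    def __init__(self, d_model=16, d_hidden=64):
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(d_model, d_hidden) * 0.1)
        self.b_enc = nn.Parameter(torch.zeros(d_hidden))
        self.W_dec = nn.Parameter(torch.randn(d_hidden, d_model) * 0.1)
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def encode(self, x):
        # Centre the input on b_dec before encoding.
        return torch.relu((x - self.b_dec) @ self.W_enc + self.b_enc)

    def decode(self, f):
        # Add b_dec back as part of the decoder.
        return f @ self.W_dec + self.b_dec

    def forward(self, x):
        return self.decode(self.encode(x))
```

If weights trained with this convention are loaded into an `encode()` that skips the `- self.b_dec`, the encoder sees inputs shifted by `b_dec`, which produces exactly the symptom above (high L0/L1, bad MSE).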
I’ve screenshotted the Towards Monosemanticity results which describes the tied decoder bias below as well.
I’d be pretty interested in knowing if my SAEs seem good now based on your evals :) Hopefully this was the only issue.
My SAEs also have a tied decoder bias which is subtracted from the original activations; here’s the relevant code in `dictionary.py`. Note that I checked that our SAEs have the same input-output behavior in my linked colab notebook. I think I’m a bit confused why subtracting off the decoder bias had to be done explicitly in your code—maybe you used `dictionary.encoder` and `dictionary.decoder` instead of `dictionary.encode` and `dictionary.decode`? (Sorry, I know this is confusing.) ETA: Simple things I tried based on the hypothesis “one of us needs to shift our inputs by +/- the decoder bias” only made things worse, so I’m pretty sure that you had just initially converted my dictionaries into your infrastructure in a way that messed up the initial decoder bias, and therefore had to hand-correct it.
I note that the MSE Loss you reported for my dictionary actually is noticeably better than any of the MSE losses I reported for my residual stream dictionaries! Which layer was this? Seems like something to dig into.
Ahhh, I see. Sorry, I was way too hasty to jump at this as the explanation. Your code does use the tied decoder bias (and yeah, it was a little harder to read because of how your module is structured). It is strange that assuming that bug seemed to help on some of the SAEs, but I ran my evals over all your residual stream SAEs and it only worked for some and not others, so it certainly didn’t seem like a good explanation after I’d run it on more than one.
I’ve been talking to Logan Riggs, who says he was able to load in my SAEs and saw fairly similar reconstruction performance to me, but that outside of the context length of 128 tokens, performance markedly decreases. He also mentioned your eval code uses very long prompts whereas mine limits to 128 tokens, so this may be the main cause of the difference. Logan mentioned you had discussed this with him, so I’m guessing you’ve got more details on this than I have? I’ll build some evals specifically to look at this in the future, I think.
Scientifically, I am fairly surprised about the token length effect and want to try training on activations from much longer context sizes now. I have noticed (anecdotally) that the number of features I get sometimes increases over the prompt, so an SAE trained on activations from shorter prompts is plausibly going to have a much easier time balancing reconstruction and sparsity, which might explain the generally lower MSE / higher reconstruction. Though we shouldn’t really compare between models and with different levels of sparsity, as we’re likely to be at different locations on the Pareto frontier.
One final note is that I’m excited to see whether performance on the first 128 tokens actually improves in SAEs trained on activations from > 128 token forward passes (since maybe the SAE becomes better in general).
Yep, as you say, @Logan Riggs figured out what’s going on here: you evaluated your reconstruction loss on contexts of length 128, whereas I evaluated on contexts of arbitrary length. When I restrict to context length 128, I’m able to replicate your results.
Here’s Logan’s plot for one of your dictionaries (not sure which)
and here’s my replication of Logan’s plot for your layer 1 dictionary
Interestingly, this does not happen for my dictionaries! Here’s the same plot but for my layer 1 residual stream output dictionary for pythia-70m-deduped
(Note that all three plots have a different y-axis scale.)
Why the difference? I’m not really sure. Two guesses:
The model: GPT2-small uses learned positional embeddings whereas Pythia models use rotary embeddings
The training: I train my autoencoders on variable-length sequences up to length 128; left padding is used to pad shorter sequences up to length 128. Maybe this makes a difference somehow.
In terms of standardization of which metrics to report, I’m torn. On one hand, for the task your dictionaries were trained on (reconstruction activations taken from length 128 sequences), they’re performing well and this should be reflected in the metrics. On the other hand, people should be aware that if they just plug your autoencoders into GPT2-small and start doing inference on inputs found in the wild, things will go off the rails pretty quickly. Maybe the answer is that CE diff should be reported both for sequences of the same length used in training and for arbitrary-length sequences?
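One way to operationalize that proposal: report the CE diff at the training context length alongside a few longer lengths, so the out-of-distribution degradation is visible. Here `ce_diff_at_length` is a hypothetical stand-in for the real eval loop:

```python
def report_ce_diffs(ce_diff_at_length, train_len=128, long_lens=(256, 512, 1024)):
    """Collect CE diff at the training length and at longer contexts."""
    report = {f"ce_diff@{train_len} (train length)": ce_diff_at_length(train_len)}
    for n in long_lens:
        report[f"ce_diff@{n}"] = ce_diff_at_length(n)
    return report

# Toy stand-in mimicking a dictionary that degrades past its training length:
toy = lambda n: 0.1 if n <= 128 else 0.1 * (n / 128)
print(report_ce_diffs(toy))
```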
The fact that Pythia generalizes to longer sequences but GPT-2 doesn’t isn’t very surprising to me—getting long context generalization to work is a key motivation for rotary, e.g. the original paper https://arxiv.org/abs/2104.09864

I think the learned positional embeddings combined with training on only short sequences is likely to be the issue. Changing either would suffice.

Makes sense. Will set off some runs with longer context sizes and track this in the future.
This comment is about why we were getting different MSE numbers. The answer is (mostly) benign—a matter of different scale factors. My parallel comment, which discusses why we were getting different CE diff numbers is the more important one.
When you compute MSE loss between some activations $x$ and their reconstruction $\hat{x}$, you divide by the variance of $x$, as estimated from the data in a batch. I’ll note that this doesn’t seem like a great choice to me. Looking at the resulting training loss:

$$\frac{\|x - \hat{x}\|_2^2}{\operatorname{Var}(x)} + \lambda \|f\|_1$$
where $f$ is the encoding of $x$ by the autoencoder and $\lambda$ is the L1 regularization constant, we see that if you scale $x$ by some constant $\alpha$, this will have no effect on the first term, but will scale the second term by $\alpha$. So if activations generically become larger in later layers, this will mean that the sparsity term becomes automatically more strongly weighted.
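A quick numeric check of this, under the assumption that scaling the input by $\alpha$ scales the reconstruction and the feature activations by $\alpha$ as well (true, e.g., whenever the SAE's ReLU on/off pattern is unchanged); all tensors below are synthetic:

```python
import torch

torch.manual_seed(0)
x = torch.randn(256, 16)
x_hat = x + 0.1 * torch.randn(256, 16)       # imperfect reconstruction
f = torch.relu(torch.randn(256, 64))         # fake feature activations
lam = 1e-3

def loss_terms(x, x_hat, f):
    # Variance-normalized MSE term and L1 sparsity term, reported separately.
    mse_term = (x - x_hat).pow(2).sum(-1).mean() / x.var()
    l1_term = lam * f.abs().sum(-1).mean()
    return mse_term, l1_term

alpha = 10.0
m1, s1 = loss_terms(x, x_hat, f)
m2, s2 = loss_terms(alpha * x, alpha * x_hat, alpha * f)
# The reconstruction term is unchanged; the sparsity term grows by alpha.
```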
I think a more principled choice would be something like
$$\|x - \hat{x}\|_2 + \lambda \|f\|_1$$
where we’re no longer normalizing by the variance, and are also using $\sqrt{\mathrm{MSE}}$ instead of MSE. (This is what the `dictionary_learning` repo does.) When you scale $x$ by a constant $\alpha$, this entire expression scales by a factor of $\alpha$, so that the balance between reconstruction and sparsity remains the same. (On the other hand, this will mean that you might need to scale the learning rate by $1/\alpha$, so perhaps it would be reasonable to divide the whole expression through by $\|x\|_2$? I’m not sure.)
Also, one other thing I noticed: we both computed MSE by taking the mean of the squared differences over both the batch dimension and the activation dimension. But this isn’t quite what MSE usually means; really we should be summing over the activation dimension and taking the mean over the batch dimension. That means that both of our MSEs are erroneously divided by a factor of the hidden dimension (768 for you and 512 for me).
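The two conventions differ by exactly a factor of the hidden dimension $d$ (synthetic tensors below):

```python
import torch

torch.manual_seed(0)
d = 768
x = torch.randn(1024, d)
x_hat = x + torch.randn(1024, d)

mse_all_dims = (x - x_hat).pow(2).mean()       # mean over batch AND feature dims
mse_usual = (x - x_hat).pow(2).sum(-1).mean()  # sum over features, mean over batch
# mse_usual == d * mse_all_dims
```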
This constant factor isn’t a huge deal, but it does mean that:
- The MSE losses that we’re reporting are deceptively low, at least for the usual interpretation of “mean squared error”.
- If we decide to fix this, we’ll need to both scale up our L1 regularization penalty by a factor of the hidden dimension (and maybe also scale down the learning rate).
This is a good lesson on how MSE isn’t naturally easy to interpret and we should maybe just be reporting percent variance explained. But if we are going to report MSE (which I have been), I think we should probably report it according to the usual definition.