MSE Losses were in the WandB report (screenshot below).
I’ve loaded in your weights for one SAE and I get very bad performance (high L0, high L1, and bad MSE Loss) at first.
It turns out that this is because my forward pass uses a tied decoder bias which is subtracted from the initial activations and added as part of the decoder forward pass. AFAICT, you don’t do this.
To verify this, I added the decoder bias to the activations of your SAE prior to running a forward pass with my code (to effectively remove the decoder bias subtraction from my method) and got reasonable results.
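To make the check above concrete, here is a minimal sketch of a forward pass with a tied decoder bias, plus the "add the decoder bias to the input first" trick. It is a toy in plain Python and not the code from either repo: the class name TiedBiasSAE, the subtract_bias flag, and the random weights are all assumptions made up for illustration.

```python
import random

def matvec(W, v):
    """Multiply matrix W (a list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def add(u, v):
    return [a + b for a, b in zip(u, v)]

def sub(u, v):
    return [a - b for a, b in zip(u, v)]

def relu(v):
    return [max(0.0, a) for a in v]

class TiedBiasSAE:
    """Toy sparse autoencoder with a tied decoder bias (hypothetical names)."""

    def __init__(self, W_enc, b_enc, W_dec, b_dec):
        self.W_enc, self.b_enc = W_enc, b_enc
        self.W_dec, self.b_dec = W_dec, b_dec

    def encode(self, x, subtract_bias=True):
        # Tied bias: b_dec is subtracted from the input before encoding.
        centered = sub(x, self.b_dec) if subtract_bias else x
        return relu(add(matvec(self.W_enc, centered), self.b_enc))

    def decode(self, f):
        # The same b_dec is added back after decoding.
        return add(matvec(self.W_dec, f), self.b_dec)

    def forward(self, x, subtract_bias=True):
        return self.decode(self.encode(x, subtract_bias))

# Random toy weights; the sizes are arbitrary.
random.seed(0)
d_in, d_hid = 4, 6
rand_vec = lambda n: [random.gauss(0, 1) for _ in range(n)]
W_enc = [rand_vec(d_in) for _ in range(d_hid)]
W_dec = [rand_vec(d_hid) for _ in range(d_in)]
b_enc, b_dec, x = rand_vec(d_hid), rand_vec(d_in), rand_vec(d_in)
sae = TiedBiasSAE(W_enc, b_enc, W_dec, b_dec)

# Shifting the input by +b_dec cancels the subtraction inside encode(),
# so the tied forward pass behaves like one with no input subtraction.
shifted = sae.forward(add(x, b_dec))
untied = sae.forward(x, subtract_bias=False)
max_gap = max(abs(a - b) for a, b in zip(shifted, untied))
print(max_gap)  # tiny; only floating-point rounding separates the two
```

If the other implementation already subtracts the bias itself, this shift double-counts it, which is one way a mismatch like the one described here could show up.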
I’ve also screenshotted the Towards Monosemanticity results, which describe the tied decoder bias, below.
I’d be pretty interested in knowing if my SAEs seem good now based on your evals :) Hopefully this was the only issue.
My SAEs also have a tied decoder bias which is subtracted from the original activations. Here’s the relevant code in dictionary.py.
Note that I checked that our SAEs have the same input-output behavior in my linked colab notebook. I think I’m a bit confused why subtracting off the decoder bias had to be done explicitly in your code. Maybe you used dictionary.encoder and dictionary.decoder instead of dictionary.encode and dictionary.decode? (Sorry, I know this is confusing.) ETA: Simple things I tried based on the hypothesis “one of us needs to shift our inputs by +/- the decoder bias” only made things worse, so I’m pretty sure that you had just initially converted my dictionaries into your infrastructure in a way that messed up the initial decoder bias, and therefore had to hand-correct it.
I note that the MSE Loss you reported for my dictionary actually is noticeably better than any of the MSE losses I reported for my residual stream dictionaries! Which layer was this? Seems like something to dig into.
Ahhh I see. Sorry, I was way too hasty to jump at this as the explanation. Your code does use the tied decoder bias (and yeah, it was a little harder to read because of how your module is structured). It is strange how assuming that bug seemed to help on some of the SAEs, but when I ran my evals over all your residual stream SAEs, the correction only worked for some and not others, and it certainly didn’t seem like a good explanation once I’d run it on more than one.
I’ve been talking to Logan Riggs, who says he was able to load in my SAEs and saw reconstruction performance fairly similar to mine, but that outside of the context length of 128 tokens, performance markedly decreases. He also mentioned your eval code uses very long prompts whereas mine limits to 128 tokens, so this may be the main cause of the difference. Logan mentioned you had discussed this with him, so I’m guessing you’ve got more details on this than I have? I think I’ll build some evals specifically to look at this in the future.
Scientifically, I am fairly surprised about the token length effect and want to try training on activations from much longer context sizes now. I have noticed (anecdotally) that the number of features I get sometimes increases over the prompt, so an SAE trained on activations from shorter prompts is plausibly going to have a much easier time balancing reconstruction and sparsity, which might explain the generally lower MSE / higher reconstruction. Though we shouldn’t really compare between models and with different levels of sparsity, as we’re likely to be at different locations on the Pareto frontier.
One final note is that I’m excited to see whether performance on the first 128 tokens actually improves in SAEs trained on activations from > 128 token forward passes (since maybe the SAE becomes better in general).
Yep, as you say, @Logan Riggs figured out what’s going on here: you evaluated your reconstruction loss on contexts of length 128, whereas I evaluated on contexts of arbitrary length. When I restrict to context length 128, I’m able to replicate your results.
Here’s Logan’s plot for one of your dictionaries (not sure which)
and here’s my replication of Logan’s plot for your layer 1 dictionary
Interestingly, this does not happen for my dictionaries! Here’s the same plot but for my layer 1 residual stream output dictionary for pythia-70m-deduped
(Note that all three plots have a different y-axis scale.)
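For anyone wanting to reproduce this kind of plot, here is a rough sketch of the bookkeeping: average the per-token reconstruction error at each position, then compare positions inside and outside the training context. The function names and the train_ctx parameter are hypothetical, and the numbers in the demo are synthetic placeholders, not measurements from any of the dictionaries discussed here.

```python
def loss_by_position(sq_errors):
    """Mean squared reconstruction error at each token position.

    sq_errors is a list of sequences, each a list of per-token errors.
    Sequences may differ in length; each position is averaged only over
    the sequences that are long enough to include it.
    """
    max_len = max(len(seq) for seq in sq_errors)
    means = []
    for pos in range(max_len):
        vals = [seq[pos] for seq in sq_errors if len(seq) > pos]
        means.append(sum(vals) / len(vals))
    return means

def inside_vs_outside(means, train_ctx=128):
    """Average the per-position curve before and after the training context."""
    avg = lambda v: sum(v) / len(v) if v else float("nan")
    return avg(means[:train_ctx]), avg(means[train_ctx:])

# Synthetic errors only: 1.0 inside the first 128 positions, 3.0 beyond.
fake_errors = [
    [1.0 if pos < 128 else 3.0 for pos in range(length)]
    for length in (100, 200, 256)
]
curve = loss_by_position(fake_errors)
inside, outside = inside_vs_outside(curve, train_ctx=128)
print(inside, outside)  # 1.0 3.0 for this synthetic input
```

Plotting curve against position gives the same kind of picture as the plots above; reporting inside and outside separately is one way to make the context-length dependence visible in a single pair of numbers.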
Why the difference? I’m not really sure. Two guesses:
The model: GPT2-small uses learned positional embeddings whereas Pythia models use rotary embeddings
The training: I train my autoencoders on variable-length sequences up to length 128; left padding is used to pad shorter sequences up to length 128. Maybe this makes a difference somehow.
In terms of standardization of which metrics to report, I’m torn. On one hand, for the task your dictionaries were trained on (reconstructing activations taken from length 128 sequences), they’re performing well and this should be reflected in the metrics. On the other hand, people should be aware that if they just plug your autoencoders into GPT2-small and start doing inference on inputs found in the wild, things will go off the rails pretty quickly. Maybe the answer is that CE diff should be reported both for sequences of the same length used in training and for arbitrary-length sequences?
The fact that Pythia generalizes to longer sequences but GPT-2 doesn’t isn’t very surprising to me—getting long context generalization to work is a key motivation for rotary, e.g. the original paper https://arxiv.org/abs/2104.09864
I think the learned positional embeddings combined with training on only short sequences is likely to be the issue. Changing either would suffice.
Makes sense. Will set off some runs with longer context sizes and track this in the future.
This comment is about why we were getting different MSE numbers. The answer is (mostly) benign: a matter of different scale factors. My parallel comment, which discusses why we were getting different CE diff numbers, is the more important one.
When you compute MSE loss between some activations x and their reconstruction $\hat{x}$, you divide by the variance of x, as estimated from the data in a batch. I’ll note that this doesn’t seem like a great choice to me. Looking at the resulting training loss:
$$\frac{\|x - \hat{x}\|_2^2}{\mathrm{Var}(x)} + \lambda \|f\|_1$$
where f is the encoding of x by the autoencoder and λ is the L1 regularization constant, we see that if you scale x by some constant α, this will have no effect on the first term, but will scale the second term by α. So if activations generically become larger in later layers, this will mean that the sparsity term becomes automatically more strongly weighted.
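Spelling the scaling argument out, under the assumptions that α > 0 and that the reconstruction and the encoding scale along with the input (i.e. x → αx gives x̂ → αx̂ and f → αf, which is only approximately true for an autoencoder with biases):

```latex
\frac{\|\alpha x - \alpha \hat{x}\|_2^2}{\mathrm{Var}(\alpha x)} + \lambda \|\alpha f\|_1
  = \frac{\alpha^2 \|x - \hat{x}\|_2^2}{\alpha^2 \, \mathrm{Var}(x)} + \alpha \lambda \|f\|_1
  = \frac{\|x - \hat{x}\|_2^2}{\mathrm{Var}(x)} + (\alpha \lambda) \|f\|_1
```

So the reconstruction term is unchanged while the effective L1 coefficient goes from λ to αλ.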
I think a more principled choice would be something like
$$\|x - \hat{x}\|_2 + \lambda \|f\|_1$$
where we’re no longer normalizing by the variance, and are also using sqrt(MSE) instead of MSE. (This is what the dictionary_learning repo does.) When you scale x by a constant α, this entire expression scales by a factor of α, so that the balance between reconstruction and sparsity remains the same. (On the other hand, this will mean that you might need to scale the learning rate by 1/α, so perhaps it would be reasonable to divide through this expression by $\|x\|_2$? I’m not sure.)
Also, one other thing I noticed: something which we both did was to compute MSE by taking the mean over the squared difference over the batch dimension and the activation dimension. But this isn’t quite what MSE usually means; really we should be summing over the activation dimension and taking the mean over the batch dimension. That means that both of our MSEs are erroneously divided by a factor of the hidden dimension (768 for you and 512 for me).
This constant factor isn’t a huge deal, but it does mean that:
The MSE losses that we’re reporting are deceptively low, at least for the usual interpretation of “mean squared error”
If we decide to fix this, we’ll both need to scale up our L1 regularization penalty by a factor of the hidden dimension (and maybe also scale down the learning rate).
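Here is a small sketch of the two conventions side by side, to show where the factor of the hidden dimension comes from. It is plain Python on a made-up batch (2 samples, 4 activation dimensions); the function names are hypothetical and not taken from either codebase.

```python
def mse_mean_over_everything(x, x_hat):
    """Mean of squared differences over both batch and activation dimensions."""
    n_elements = len(x) * len(x[0])
    total = sum(
        (a - b) ** 2
        for row, row_hat in zip(x, x_hat)
        for a, b in zip(row, row_hat)
    )
    return total / n_elements

def mse_sum_over_activations(x, x_hat):
    """Sum over the activation dimension, then mean over the batch."""
    per_sample = [
        sum((a - b) ** 2 for a, b in zip(row, row_hat))
        for row, row_hat in zip(x, x_hat)
    ]
    return sum(per_sample) / len(per_sample)

# Made-up batch of 2 samples with d = 4 activation dimensions.
x = [[1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 0.0, 0.0]]
x_hat = [[0.0, 2.0, 3.0, 4.0], [1.0, 0.0, 0.0, 2.0]]
d = len(x[0])

per_element = mse_mean_over_everything(x, x_hat)  # 6 / 8 = 0.75
per_sample = mse_sum_over_activations(x, x_hat)   # 6 / 2 = 3.0
print(per_element, per_sample, per_sample / per_element)  # 0.75 3.0 4.0
```

The two always differ by exactly d, so switching conventions multiplies the reconstruction term by d, and the L1 coefficient would need the same factor to keep the trade-off where it was.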
This is a good lesson on how MSE isn’t naturally easy to interpret and we should maybe just be reporting percent variance explained. But if we are going to report MSE (which I have been), I think we should probably report it according to the usual definition.