Sparse Autoencoders: Future Work

Mostly my own writing, except for the ‘Better Training Methods’ section which was written by @Aidan Ewart.

We made a lot of progress in 4 months working on Sparse Autoencoders, an unsupervised method to scalably find monosemantic features in LLMs, but there’s still plenty of work to do. Below I (Logan) give research ideas, as well as my current, half-baked thoughts on how to pursue them.

Find All the Circuits!

  1. Truth/​Deception/​Sycophancy/​Train-Test distinction/​[In-context Learning/​internal Optimization]

    1. Find features relevant for these tasks. Do they generalize better than baselines?

    2. For internal optimization, can we narrow this down to a circuit (using something like causal scrubbing) and retarget the search?

  2. Understand RLHF

    1. Find features for preference/​reward models that make the reward large or very negative.

    2. Compare features of models before & after RLHF

  3. Adversarial Attacks

    1. What features activate on adversarial attacks? What features feed into those?

    2. Develop adversarial attacks, but only search over dictionary features

  4. Circuits Across Time

    1. Using a model w/ lots of checkpoints like Pythia, we can see how features & circuits form across training for given datapoints.

  5. Circuits Across Scale

    1. Pythia models are trained on the same data, in the same order, but range in size from 70M params to 12B, letting us study how features & circuits change with scale.

  6. Turn LLMs into code

    1. Link to very rough draft of the idea I (Logan) wrote in two days

  7. Mechanistic Anomaly Detection

    1. If distribution X has features A, B, C activate, and distribution Y has features B, C, D, you may be able to use this discrete property to get a better ROC curve than strictly continuous methods (see the sketch after this list).

    2. How do the different operationalizations of distance between discrete features compare against each other?

  8. Activation Engineering

    1. Use feature directions found by the dictionary instead of examples. I predict this will generalize better, but it would be good to compare against current methods.

    2. One open problem is which token in the sequence to add the vector to. Maybe it makes sense to only add the [female] direction to tokens that are [names]. Dictionary features in previous layers may help you automatically pick the right tokens, e.g. a feature that activates on [names].

  9. Fun Stuff

    1. Othello/​Chess/​Motor commands—Find features that relate to actions that a model is able to do. Can we find a corner piece feature, a knight feature, a “move here” feature?
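
To make the anomaly-detection idea in 7.1 concrete, here’s a minimal sketch that scores datapoints by how far their set of active dictionary features is from a trusted distribution’s sets, using Jaccard distance. The threshold and choice of distance are placeholder assumptions, not something we’ve tested.

```python
import numpy as np

def active_feature_set(feature_acts, threshold=0.0):
    """Indices of dictionary features whose activation exceeds a threshold."""
    return set(np.flatnonzero(feature_acts > threshold))

def jaccard_distance(a, b):
    """1 - |A intersect B| / |A union B| between two sets of active feature indices."""
    return 1.0 - len(a & b) / max(len(a | b), 1)

def anomaly_score(feature_acts, reference_sets):
    """Distance from a datapoint's active-feature set to the nearest reference set."""
    s = active_feature_set(feature_acts)
    return min(jaccard_distance(s, r) for r in reference_sets)

# Usage sketch (names are placeholders):
#   reference_sets = [active_feature_set(a) for a in acts_from_distribution_X]
#   scores = [anomaly_score(a, reference_sets) for a in test_acts]
#   sklearn.metrics.roc_auc_score(test_labels, scores)  # compare vs continuous baselines
```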

There are three ways to find features AFAIK:
1. Which input tokens activate it?

2. What output logits are causally downstream from it?

3. Which intermediate features cause it/​are caused by it?

1) Input Tokens

When finding the input tokens, you may run into outlier dimensions that activate highly for most tokens (predominantly the first token), so you need to account for that.
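
As a rough sketch (assuming you’ve already cached residual-stream activations and your autoencoder exposes an `.encode` method; the names are placeholders), finding the top-activating tokens for one feature while zeroing out the first position might look like:

```python
import torch

@torch.no_grad()
def top_activating_tokens(tokens, residual_acts, autoencoder, feature_idx, k=20):
    """Return the k (batch, position, token id, activation) locations where one
    dictionary feature fires hardest.

    tokens:        (batch, seq) token ids
    residual_acts: (batch, seq, d_model) cached residual-stream activations
    autoencoder:   assumed to expose .encode mapping d_model -> n_features
    """
    feature_acts = autoencoder.encode(residual_acts)[..., feature_idx]  # (batch, seq)
    feature_acts[:, 0] = 0.0  # ignore the first position, which outlier dims dominate
    vals, idxs = feature_acts.flatten().topk(k)
    batch_idx = idxs // tokens.shape[1]
    pos_idx = idxs % tokens.shape[1]
    return [(int(b), int(p), int(tokens[b, p]), float(v))
            for b, p, v in zip(batch_idx, pos_idx, vals)]
```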

2) Output Logits

For output logits, if you have a dataset task (e.g. predicting stereotypical gender), you can remove each feature one at a time, and sort by greatest effect. This also extends to substituting features between two distributions and finding the smallest substitution to go from one to the other. For example,

  1. “I’m Jane, and I’m a [female]”

  2. “I’m Dave, and I’m a [male]”

Suppose at the token “Jane”, the model activates two features, A & B ([1,1,0]), and at “Dave” it activates features B & C ([0,1,1]). Then we can look for the smallest substitution that makes the Jane prompt complete as “ male”. If A is the “female” feature, then ablating it (setting it to zero) should make the model assign equal probability to male/female. Adding the female feature to Dave’s activations and subtracting the male direction should make the Dave prompt complete as “ female”.[1]
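
A minimal sketch of that substitution, assuming the autoencoder exposes `.encode` and a `.decoder` matrix of shape (n_features, d_model) (both assumptions about your implementation); you’d patch the edited vector back into the forward pass, e.g. with hooks, and compare the “ female” vs “ male” logits:

```python
def swap_feature_directions(resid, autoencoder, remove_idx, add_idx, add_strength=None):
    """Edit a residual-stream vector by swapping one dictionary feature for another.

    resid:       (d_model,) residual-stream activation at the token of interest
    autoencoder: assumed to expose .encode (d_model -> n_features) and a
                 .decoder weight matrix of shape (n_features, d_model)
    """
    acts = autoencoder.encode(resid)                 # feature activations at this token
    remove_dir = autoencoder.decoder[remove_idx]     # e.g. the "female" feature direction
    add_dir = autoencoder.decoder[add_idx]           # e.g. the "male" feature direction
    strength = acts[remove_idx] if add_strength is None else add_strength
    # Ablate the removed feature's contribution and add the other direction in at the
    # same strength; patch the result back into the forward pass and compare logits.
    return resid - acts[remove_idx] * remove_dir + strength * add_dir
```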

3) Intermediate Features

Say we’re looking at layer 5, feature 783, which activates ~10 on average over 20 datapoints. We can ablate each feature in layer 4, one at a time, and see which ablation makes those 20 datapoints’ activations go down the most. This generally results in features that make a lot of sense, e.g. the feature for [acronyms after “(”] is affected when you ablate the previous layer’s feature for acronyms & the one for “(”. Other times, it’s largely the same feature at both layers, since this is the residual stream.[2]

This can be extended to dictionaries trained on the output of MLP & attention layers. Additionally, one could do a weight-based approach going from the residual stream to the MLP layer, which may allow predicting beforehand what a feature is from the weights alone, e.g. “This feature is just 0.5*(acronyms feature) + 2.3*(open-parenthesis feature).”
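
Here’s a hedged sketch of the layer-4 → layer-5 ablation described above; `run_from_layer_4` is a hypothetical helper that continues the forward pass from an (edited) layer-4 residual, and the autoencoder attribute names are assumptions:

```python
import torch

@torch.no_grad()
def upstream_feature_effects(resid_l4, autoencoder_l4, autoencoder_l5,
                             run_from_layer_4, target_feature, top_k=10):
    """For one layer-5 dictionary feature, rank layer-4 features by how much
    ablating them drops the target feature's activation.

    resid_l4:         (n_points, d_model) layer-4 residual acts on datapoints where
                      the target layer-5 feature is active
    run_from_layer_4: hypothetical helper that continues the forward pass from an
                      (edited) layer-4 residual and returns the layer-5 residual
    """
    baseline = autoencoder_l5.encode(run_from_layer_4(resid_l4))[:, target_feature]
    acts_l4 = autoencoder_l4.encode(resid_l4)                 # (n_points, n_features)
    drops = []
    for f in range(acts_l4.shape[1]):                         # slow but simple sketch
        # Ablate feature f: remove its contribution from the layer-4 residual.
        edited = resid_l4 - acts_l4[:, f:f + 1] * autoencoder_l4.decoder[f]
        ablated = autoencoder_l5.encode(run_from_layer_4(edited))[:, target_feature]
        drops.append((baseline - ablated).mean().item())
    return sorted(enumerate(drops), key=lambda x: -x[1])[:top_k]
```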

Prompt Feature Diff

If I want to understand the effect of few-shot prompts, I can take the 0-shot prompt:
“The initials of Logan Riggs are”, and see which features activate for those ~6 tokens. Then add in few-shot prompts before, and see the different features that activate for those ~6 tokens. In general, this can be applied to:
Feature diff between Features in [prompt] & Features in [prompt] given [Pre-prompt]

With examples being:

[few-shot prompts/​Chain-of-thought/​adversarial prompts/​soft prompts][prompt]

(though I don’t know how to extend this to appending “Let’s think step-by-step”)
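
A small sketch of this, assuming a hypothetical helper `get_feature_acts` that maps token ids to per-token dictionary activations; it returns which features appear or disappear on the prompt’s tokens once a pre-prompt is prepended:

```python
import torch

@torch.no_grad()
def prompt_feature_diff(get_feature_acts, prompt_tokens, pre_prompt_tokens,
                        threshold=0.0):
    """Compare which dictionary features fire on the prompt's tokens with and
    without a pre-prompt in front of it.

    get_feature_acts: hypothetical helper, 1D token ids -> (seq, n_features) activations
    """
    n = prompt_tokens.shape[0]
    acts_zero_shot = get_feature_acts(prompt_tokens)                 # (n, n_features)
    acts_prefixed = get_feature_acts(
        torch.cat([pre_prompt_tokens, prompt_tokens]))[-n:]          # same n tokens
    on_zero = set(torch.nonzero(acts_zero_shot > threshold)[:, 1].tolist())
    on_prefixed = set(torch.nonzero(acts_prefixed > threshold)[:, 1].tolist())
    return {"gained": on_prefixed - on_zero, "lost": on_zero - on_prefixed}
```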

Useful related work is Causal Scrubbing.

ACDC

Automatic Circuit DisCovery (ACDC) is a simple technique: to find what’s relevant for X, just remove everything upstream of it, one at a time, and see what breaks. Then apply it recursively. We use a similar technique in our paper, but only on the residual stream. Dictionaries (the decoder part of autoencoders) can also be trained on the output of MLP & attention units; we’ve in fact done this before and the results appear quite interpretable!

We can apply this technique to connect features found in the residual stream to the MLP & attn units. Ideally, we could do a more weight-based method, such as connecting the features learned in the residual stream to the MLP. This may straightforwardly work going from the residual stream to the MLP_out dictionary. If not, it may work with dictionaries trained on the neurons of MLP (ie the activations post non-linearity).

For attention units, I have a half-baked thought of connecting residual stream directions at one layer w/​ another layer (or Attn_out) using the QK & OV circuits for a given attention head, but haven’t thought very much about this.
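
For reference, here’s a skeleton of the recursive idea, written against placeholder helpers (`ablate_and_eval`, `upstream_of`) rather than any particular codebase; the threshold for “what breaks” is an arbitrary assumption:

```python
def discover_circuit(metric, ablate_and_eval, upstream_of, start_nodes,
                     threshold=0.05):
    """ACDC-flavoured recursion: keep upstream nodes whose ablation hurts the metric.

    metric:          baseline score of the behaviour we care about
    ablate_and_eval: hypothetical helper, node -> metric with that node ablated
    upstream_of:     hypothetical helper, node -> list of candidate parent nodes
    """
    circuit = set()
    frontier = list(start_nodes)
    while frontier:
        node = frontier.pop()
        for parent in upstream_of(node):
            if parent in circuit:
                continue
            if metric - ablate_and_eval(parent) > threshold:  # removing it breaks things
                circuit.add(parent)
                frontier.append(parent)                       # recurse on the kept node
    return circuit
```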

Better Sparse Autoencoders

I think we are quite close to finding all the features for one layer in GPT2 small. Perfecting this will help find more accurate and predictive circuits. This includes driving reconstruction & perplexity-difference down, better training methods, and better, less-Goodhart-able interpretability methods.

Reconstruction, Sparsity, & Perplexity-Diff

Reconstruction loss—How well the autoencoder reconstructs the activations of e.g. layer 6 of the model.

Sparsity—How many features activate per datapoint, on average? (ie the L0 norm of the latent activations)

Perplexity-diff—When you run the LLM on a dataset, you get some prediction loss (which can be converted to perplexity). You can then run the LLM on the same dataset, but with e.g. layer 6 replaced by the autoencoder’s reconstruction, and get a different prediction loss. Subtract the two. If the difference is ~0, that is strong evidence that the autoencoder is functionally equivalent to the original model.
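
A minimal sketch of computing these three metrics, assuming you’ve cached the relevant layer’s activations and separately measured the LM loss with and without the autoencoder’s reconstruction patched in (the patching itself is left to whatever hook machinery you use):

```python
import math
import torch
import torch.nn.functional as F

@torch.no_grad()
def sae_metrics(resid, autoencoder, loss_original, loss_with_replacement):
    """Reconstruction loss, L0 sparsity, and perplexity-diff for one layer's SAE.

    resid:                 (n_tokens, d_model) cached residual-stream activations
    loss_original:         LM loss (mean nats/token) of the unmodified model
    loss_with_replacement: LM loss with this layer's activations swapped for the
                           autoencoder's reconstruction (obtained via hooks)
    """
    feature_acts = autoencoder.encode(resid)
    recon = autoencoder.decode(feature_acts)
    reconstruction_loss = F.mse_loss(recon, resid).item()
    l0_sparsity = (feature_acts > 0).float().sum(-1).mean().item()  # features/datapoint
    perplexity_diff = math.exp(loss_with_replacement) - math.exp(loss_original)
    return reconstruction_loss, l0_sparsity, perplexity_diff
```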

Typically, we plot unexplained variance (ie reconstruction loss that takes variance into account) vs sparsity, where we would want solutions in the bottom-left corner: perfectly explaining the data w/ minimal sparsity (features/datapoint). We have seen evidence (by hand and via GPT-autointerp) that sparser solutions are more monosemantic. Until we have better interp methods, driving down these 3 metrics is a useful proxy.

[Figure: FVU vs sparsity across training. As the model sees more data, it moves towards the lower-left corner. This is for Pythia-410M, which has a residual dimension of 768. The model can achieve near 100% variance explained when using 600-800 sparsity (features/dimension), but the features learned there are polysemantic.]

One effective method, not included in our paper, is directly optimizing for minimal KL-divergence in addition to reconstruction & sparsity. This drove perplexity-difference down, at similar sparsity, at the cost of some reconstruction loss.
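
A hedged sketch of what such an objective could look like; the coefficients and the exact form of the KL term here are placeholder assumptions, not the settings we used:

```python
import torch.nn.functional as F

def sae_loss_with_kl(resid, feature_acts, recon, logits_original, logits_patched,
                     l1_coeff=1e-3, kl_coeff=1.0):
    """Reconstruction + sparsity + KL training objective (coefficients are placeholders).

    resid:           (n_tokens, d_model) original activations at the chosen layer
    feature_acts:    (n_tokens, n_features) autoencoder latents for `resid`
    recon:           (n_tokens, d_model) reconstruction decoded from `feature_acts`
    logits_original: (n_tokens, vocab) logits from the unmodified forward pass
    logits_patched:  logits when the layer's activation is replaced by `recon`
                     (so gradients flow from the KL term back into the autoencoder)
    """
    mse = F.mse_loss(recon, resid)
    l1 = feature_acts.abs().sum(-1).mean()
    # KL(original || patched): pushes the patched model's next-token distribution
    # to match the original model's.
    kl = F.kl_div(F.log_softmax(logits_patched, dim=-1),
                  F.log_softmax(logits_original, dim=-1),
                  log_target=True, reduction="batchmean")
    return mse + l1_coeff * l1 + kl_coeff * kl
```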

Better Training Methods

In their work, Yun et al. use an iterative method: FISTA to find sparse codes for the activations, while optimising the dictionary to lower MSE with respect to those codes. We used autoencoders because we think they better reflect what the model might be computing, but it is possible that methods like the one Yun et al. use will result in a better dictionary.

Possible options here include using Yun et al.’s method, pre-training a dictionary as an autoencoder and further optimising it using FISTA, or simply using FISTA with a pre-trained dictionary to reduce MSE.
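
For concreteness, here’s a small sketch of FISTA-style sparse coding against a fixed dictionary (minimising MSE plus an L1 penalty on the codes); the step count and sparsity coefficient are placeholders:

```python
import torch

@torch.no_grad()
def fista_codes(x, dictionary, l1_coeff=1e-3, n_steps=100):
    """FISTA for min_c ||x - c @ D||^2 + l1_coeff * ||c||_1, with D held fixed.

    x:          (batch, d_model) activations to sparsely code
    dictionary: (n_features, d_model) dictionary / decoder matrix D
    """
    # Lipschitz constant of the gradient of the quadratic term.
    L = 2.0 * torch.linalg.matrix_norm(dictionary, ord=2).item() ** 2
    step = 1.0 / L
    c = torch.zeros(x.shape[0], dictionary.shape[0], device=x.device, dtype=x.dtype)
    y, c_prev, t = c.clone(), c.clone(), 1.0
    for _ in range(n_steps):
        grad = 2.0 * (y @ dictionary - x) @ dictionary.T      # gradient of the MSE term
        c = torch.nn.functional.softshrink(y - step * grad, lambd=step * l1_coeff)
        t_next = (1.0 + (1.0 + 4.0 * t * t) ** 0.5) / 2.0
        y = c + ((t - 1.0) / t_next) * (c - c_prev)           # Nesterov-style momentum
        c_prev, t = c, t_next
    return c
```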

We could also find different methods of decomposing activations, using nonlinear autoencoders or VAEs with sparse priors. This is a very interesting line of work which might result in a better understanding of how transformers can represent information nonlinearly. We’ve faced convergence issues trying to train more powerful decompositional tools (both linear & not), but these can be helped by using softplus activations during training. Also, it seems that the link between sparsity and monosemanticity might break down very quickly as you apply more and more complex methods, and perhaps there is an alternative form of regularisation (instead of sparsity) which would work better for stronger autoencoders.

Better Interp Methods

How do we know we found good features? We can’t just ask for 0 reconstruction loss & 0 perplexity-diff, because the original model itself achieves that (as does the identity function)! That’s why we have sparsity, but is 20 features/datapoint better than 60 features/datapoint? And how should this scale with model size or layer?

It’d be good to have a clean, objective measure of interpretability. You could do a subjective measure of 10 randomly selected features, but that’s noisy!

I have some preliminary work on making a monosemanticity measure that I can share shortly, but no good results yet!

Our previous proxies for the “right hyperparams for feature goodness” have come from toy models, specifically MMCS (mean max cosine similarity), ie how similar the features of two separately trained dictionaries are (if two dictionaries learned similar features, then these are “realer” features...maybe), and the number of dead features. Check the toy model results for more details, both Lee’s original work & update and our open-sourced replication.
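
MMCS is cheap to compute between two learned dictionaries; here’s a minimal sketch, assuming each dictionary is represented by its decoder matrix:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def mean_max_cosine_similarity(dict_a, dict_b):
    """MMCS: for each feature in dict_a, take its best cosine match in dict_b, then average.

    dict_a, dict_b: (n_features, d_model) decoder matrices from two separately
                    trained dictionaries (n_features may differ between them).
    """
    a = F.normalize(dict_a, dim=-1)
    b = F.normalize(dict_b, dim=-1)
    cos = a @ b.T                       # (n_a, n_b) pairwise cosine similarities
    return cos.max(dim=-1).values.mean().item()
```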

Come Work With Us

We are currently discussing research in the #unsupervised-interp channel (under Interpretability) in the EleutherAI Discord server. If you’re a researcher and have directions you’d like to apply sparse autoencoders to, feel free to message me on Discord (loganriggs) or LW & we can chat!

  1. ^

    Now that I write it out, though, I think you could just find the features that make both distributions “neutral”, and add those directions.

  2. ^

    One can verify this by checking the cosine similarity between two features at different layers. If they have high cosine sim, then they’re pointing in very similar directions and will be decoded by future layers/​unembedded in the same way.