Sparse Autoencoders Find Highly Interpretable Directions in Language Models
This is a linkpost for Sparse Autoencoders Find Highly Interpretable Directions in Language Models
We use a scalable and unsupervised method called Sparse Autoencoders to find interpretable, monosemantic features in real LLMs (Pythia-70M/410M) for both the residual stream and MLPs. We showcase monosemantic features, feature replacement for Indirect Object Identification (IOI), and use OpenAI’s automatic interpretation protocol to demonstrate a significant improvement in interpretability.
Paper Overview
Sparse Autoencoders & Superposition
To reverse engineer a neural network, we’d like to first break it down into smaller units (features) that can be analysed in isolation. Individual neurons can be useful units, but neurons are often polysemantic, activating for several unrelated types of feature, so looking at neurons alone is insufficient. Moreover, for some types of network activation, such as the residual stream of a transformer, there is little reason to expect features to align with the neuron basis, so neurons are not even a good place to start.
Toy Models of Superposition investigates why polysemanticity might arise and hypothesises that it may result from models learning more distinct features than there are dimensions in the layer, taking advantage of the fact that features are sparse, each one only being active a small proportion of the time. This suggests that we may be able to recover the network’s features by finding a set of directions in activation space such that each activation vector can be reconstructed from a sparse linear combination of these directions.
We attempt to reconstruct these hypothesised network features by training linear autoencoders on model activation vectors. We apply a sparsity penalty to the hidden-layer activations (the embedding) and tie the weights between the encoder and decoder, training each model on 10M to 50M activation vectors. For more detail on the methods used, see the paper.
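A minimal NumPy sketch of the forward pass and training objective described above. It assumes a ReLU encoder with a learned bias, unit-norm dictionary rows, and an L1 norm as the sparsity penalty; these choices, and the names `sae_forward`, `sae_loss`, and `l1_coeff`, are illustrative assumptions rather than the paper's exact code. Tied weights mean the same matrix `W` is used to encode (as `W.T`) and to decode.

```python
import numpy as np

def sae_forward(x, W, b):
    """Encode activations into a sparse code and decode with tied weights.

    x: (batch, d_model) activation vectors
    W: (n_features, d_model) dictionary; row i is feature direction f_i
    b: (n_features,) encoder bias
    """
    c = np.maximum(x @ W.T + b, 0.0)  # ReLU encoder (assumed): sparse, non-negative code
    x_hat = c @ W                     # tied decoder: reuse W instead of a separate matrix
    return c, x_hat

def sae_loss(x, W, b, l1_coeff=1e-3):
    """Mean squared reconstruction error plus an L1 penalty on the code (assumed form)."""
    c, x_hat = sae_forward(x, W, b)
    reconstruction = np.mean(np.sum((x - x_hat) ** 2, axis=1))
    sparsity = np.mean(np.sum(np.abs(c), axis=1))
    return float(reconstruction + l1_coeff * sparsity)

# Toy example: an overcomplete dictionary (32 features for 8 dimensions).
rng = np.random.default_rng(0)
d_model, n_features = 8, 32
W = rng.normal(size=(n_features, d_model))
W /= np.linalg.norm(W, axis=1, keepdims=True)  # unit-norm feature directions (assumed)
b = np.zeros(n_features)
x = rng.normal(size=(4, d_model))

c, x_hat = sae_forward(x, W, b)
print(c.shape, x_hat.shape)  # (4, 32) (4, 8)
print(sae_loss(x, W, b))
```

Training would minimise `sae_loss` over batches of stored activations with any gradient-based optimiser; the overcomplete shape (more features than dimensions) is what allows the dictionary to represent more features than the layer has neurons.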
Automatic Interpretation
We use the same automatic interpretation technique that OpenAI used to interpret the neurons in GPT-2 to analyse our features, as well as the directions found by alternative decomposition methods. This was demonstrated in a previous post, but we now extend these results across all 6 layers of Pythia-70M, showing a clear improvement over all baselines in all but the final layers. Case studies later in the paper suggest that the features are still meaningful in these later layers, but that automatic interpretation struggles to perform well there.
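The scoring step of OpenAI's protocol is what makes the comparison quantitative: a language model writes an explanation from a feature's top-activating examples, a simulator model predicts per-token activations from the explanation alone, and the explanation is scored by the correlation between simulated and true activations. The sketch below covers only that final correlation step; the LLM calls are omitted, and the function name and activation values are hypothetical.

```python
import numpy as np

def correlation_score(true_acts, simulated_acts):
    """Score an explanation by how well simulated activations track real ones.

    true_acts: a feature's actual per-token activations on some text
    simulated_acts: activations predicted for the same tokens by a model
        that only sees the natural-language explanation
    """
    true_acts = np.asarray(true_acts, dtype=float)
    simulated_acts = np.asarray(simulated_acts, dtype=float)
    if true_acts.std() == 0 or simulated_acts.std() == 0:
        return 0.0  # correlation is undefined for constant inputs; treat as no signal
    return float(np.corrcoef(true_acts, simulated_acts)[0, 1])

# Hypothetical values: the simulator predicts on its own scale (which correlation
# ignores), but it also fires on the last token, where the feature is silent.
true_acts = [0.0, 0.0, 3.2, 0.0, 1.1, 0.0]
simulated = [0.0, 0.0, 9.0, 0.0, 3.0, 4.0]
print(correlation_score(true_acts, simulated))
```

Because correlation is scale-invariant, the simulator does not need to match the feature's units, only the pattern of where it fires.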
IOI Feature Identification
We are able to use less-than-rank-one ablations to precisely edit activations to restore uncorrupted behaviour on the IOI task. With normal activation patching, patches occur at a module-wide level, while here we perform interventions of the form

$$x' = \bar{x} + \sum_{i \in F} (c_i - \bar{c}_i)\, f_i$$

where $\bar{x}$ is the embedding of the corrupted datapoint, $F$ is the set of patched features, $f_i$ is the dictionary direction of feature $i$, and $c_i$ and $\bar{c}_i$ are the activations of feature $f_i$ on the clean and corrupted datapoint respectively.
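The intervention translates directly into code. A minimal NumPy sketch, assuming the dictionary activations on the clean and corrupted datapoints have already been computed by the autoencoder; the function name, argument names, and toy data are hypothetical.

```python
import numpy as np

def patch_features(x_corrupt, c_clean, c_corrupt, D, patched):
    """Move only the features in `patched` from their corrupted-run activation
    to their clean-run activation:
        x_new = x_corrupt + sum_{i in patched} (c_clean[i] - c_corrupt[i]) * D[i]

    x_corrupt: (d_model,) activation on the corrupted datapoint
    c_clean, c_corrupt: (n_features,) dictionary activations on each datapoint
    D: (n_features, d_model) dictionary; row i is the direction of feature i
    patched: indices of the features to patch (the set of patched features)
    """
    x_new = np.array(x_corrupt, dtype=float)  # copy so the input is not modified
    for i in patched:
        x_new += (c_clean[i] - c_corrupt[i]) * D[i]
    return x_new

# Toy data (hypothetical): 6 features in a 4-dimensional activation space.
rng = np.random.default_rng(1)
D = rng.normal(size=(6, 4))
c_clean = np.maximum(rng.normal(size=6), 0.0)
c_corrupt = np.maximum(rng.normal(size=6), 0.0)
x_corrupt = c_corrupt @ D  # assume a perfect reconstruction for the demo

x_patched = patch_features(x_corrupt, c_clean, c_corrupt, D, patched=[2, 5])
```

Patching every feature turns a perfectly reconstructed corrupted activation into the clean reconstruction; patching a small set changes only those directions, which is what makes the edit finer-grained than a module-wide patch.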
We show that our features are better able to precisely reconstruct the data than other activation decomposition methods (like PCA), and moreover that the fine-grainedness of our edits increases with dictionary sparsity. Unfortunately, as our autoencoders are not able to perfectly reconstruct the data, they have a positive minimum KL-divergence from the base model, while PCA does not.
Dictionary Features are Highly Monosemantic & Causal
(Left) Histogram of activations for a specific dictionary feature. The majority of activations are for apostrophes (in blue); the y-axis is the number of datapoints that activate in each bin. (Right) Histogram of the drop in logits (i.e., how much less the LLM predicts a specific token) when ablating this dictionary feature direction.
This is in contrast to the residual stream basis:
which appears highly polysemantic (i.e., it responds to many unrelated meanings). More examples can be found in Appendix E. We’ve also found many context features (e.g., [medical/Biology/Stack Exchange/German]-context), some of which are shown in a previous post, so this is an existence proof against concerns that this method only finds token-level features.
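The right-hand histogram for the apostrophe feature comes from ablating a feature direction and measuring the change in logits. A minimal sketch of that measurement under a strong simplification: the ablated activation is read directly by a linear unembedding `W_U`, so the drop is exactly the feature's activation times its direct logit contribution. In the real model the ablated activation would be passed through the remaining layers; all names and values here are hypothetical.

```python
import numpy as np

def logit_drop_from_ablation(x, c, D, W_U, i):
    """Drop in every token's logit when dictionary feature i is ablated.

    x: (d_model,) activation; c: (n_features,) its dictionary code
    D: (n_features, d_model) dictionary; W_U: (d_model, vocab) unembedding
    Simplifying assumption: logits = x @ W_U (no later layers or layer norm).
    """
    x_ablated = x - c[i] * D[i]       # remove only feature i's contribution
    return x @ W_U - x_ablated @ W_U  # positive entries: tokens the feature promoted

# Toy data (hypothetical): 5 features, 4 dimensions, 10-token vocabulary.
rng = np.random.default_rng(2)
D = rng.normal(size=(5, 4))
W_U = rng.normal(size=(4, 10))
c = np.array([0.0, 1.5, 0.0, 0.3, 0.0])
x = c @ D
drop = logit_drop_from_ablation(x, c, D, W_U, i=1)
print(np.argsort(-drop)[:3])  # token ids whose logits fall the most
```

Collecting `drop` over many datapoints where the feature is active, and histogramming it per token, gives a plot of the same kind as the right panel.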
Automatic Circuit Discovery
The previous section was on a dictionary feature’s relationship to the input tokens and its effect on the logits. We can also look at the relationships between features themselves.
Layer 5 is the last layer in Pythia-70M, and this feature directly unembeds into various forms of the closing parenthesis. We can view the previous layers as calculating “What are all the reasons one might predict a closing parenthesis?”.
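The post does not spell out the search procedure here, so the following is only one simple ablation-based version, under stated assumptions: to find which earlier-layer features drive a target feature (such as the layer-5 closing-parenthesis feature), ablate each upstream feature in turn, measure the drop in the target feature's activation, and keep the largest. Repeating the step on the top-ranked features works backwards through earlier layers. `target_fn` and the toy weights are hypothetical.

```python
import numpy as np

def rank_upstream_features(x, c, D, target_fn, top_k=3):
    """Rank upstream dictionary features by how much ablating each one
    lowers a downstream target feature's activation.

    x: (d_model,) upstream activation; c: (n_features,) its dictionary code
    D: (n_features, d_model) upstream dictionary
    target_fn: hypothetical stand-in mapping an upstream activation to the
        target feature's activation (in practice: run the remaining layers,
        then encode with the downstream dictionary)
    """
    base = target_fn(x)
    effects = np.array([base - target_fn(x - c[i] * D[i]) for i in range(len(c))])
    order = np.argsort(-effects)[:top_k]  # largest drops first
    return [(int(i), float(effects[i])) for i in order]

# Toy example: the target feature reads a fixed direction through a ReLU.
w = np.array([0.5, -1.0, 2.0, 1.0])

def target_fn(v):
    return max(float(w @ v), 0.0)

D = np.eye(4)
c = np.array([1.0, 2.0, 0.0, 3.0])
x = c @ D
print(rank_upstream_features(x, c, D, target_fn))
# [(3, 1.5), (0, 0.5), (2, 0.0)]
```

Feature 1 gets a negative score because removing it raises the target, so a signed ranking like this separates features that promote the target from those that suppress it.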
Conclusion
Sparse autoencoders are a scalable, unsupervised approach to disentangling language model features from superposition. We have demonstrated that the dictionary features they learn are more interpretable as measured by automatic interpretation, are better for performing precise model steering, and are more monosemantic than comparable methods.
The ability to find these dictionary features gives us a new, fully unsupervised tool to investigate model behaviour, allows us to make targeted edits, and can be trained using a manageable amount of computing power.
An ambitious dream in the field of interpretability is enumerative safety: the ability to understand the full set of computations that a model applies. If this were achieved, it could allow us to create models for which we have strong guarantees that the model is not able to perform certain dangerous actions, such as deception or advanced bioengineering. While this is still remote, dictionary learning hopefully marks a small step towards making it possible.
In summary, sparse autoencoders bring a new tool to the interpretability and editing of language models, which we hope others can build upon. The potential for innovations and applications is vast, and we’re excited to see what happens next.
Bonus Section: Did We Find All the Features?
No.
In general, we get a reconstruction loss, and if that’s 0, then we’ve perfectly reconstructed, e.g., layer 4 with our sparse autoencoder. But what does a reconstruction loss of 0.01 mean compared to 0.0001?
We can ground this out as the difference in perplexity (a measure of prediction loss) on some dataset. This better measures functional equivalence (i.e., whether the model has the same loss on the same data with and without the autoencoder’s reconstruction). Unreleased, preliminary results for GPT-2 (small), layer 4, on a subset of OpenWebText:
A difference in perplexity of 2.6 when training directly on KL-divergence[1] is quite small, especially for 4 months of effort among 3 main researchers. The two possibilities are:
1. People better at maths/ML/sparse dictionary learning than us can get it to ~0 perplexity difference.
2. A subset of features aren’t linearly represented.
If (2) is the case, then we’ll have a dataset of datapoints that aren’t linearly represented, which we can study![2] This would show that superposition only explains a subset of features, and would provide concrete counterexamples to the linear part of the hypothesis.
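Perplexity is the exponential of the mean per-token negative log-likelihood, so the comparison means running the same dataset twice: once through the unmodified model and once with, e.g., layer 4’s activations replaced by the autoencoder’s reconstruction. A sketch of the final arithmetic, assuming the log-probabilities of the correct next tokens have already been collected from both runs; the function names and numbers are hypothetical.

```python
import numpy as np

def perplexity(token_logprobs):
    """exp of the mean negative log-likelihood of the correct next tokens."""
    return float(np.exp(-np.mean(token_logprobs)))

def perplexity_difference(logprobs_original, logprobs_spliced):
    """Perplexity increase when a layer is replaced by its autoencoder reconstruction."""
    return perplexity(logprobs_spliced) - perplexity(logprobs_original)

# Hypothetical per-token log-probabilities from the two runs.
original = np.log([0.25, 0.25, 0.25])
spliced = np.log([0.20, 0.20, 0.20])
print(perplexity(original), perplexity(spliced))  # ≈ 4.0 and ≈ 5.0
print(perplexity_difference(original, spliced))   # ≈ 1.0
```

Keeping the per-token values (rather than only the averages) also makes it easy to sort datapoints by how much worse the spliced model does on them.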
We would like to give two big caveats though:
1. We don’t have a perfect monosemanticity metric, so even with 0 reconstruction loss we couldn’t claim that each feature is monosemantic, although a sparser code is partial evidence for that.
2. What if every additional 1,000 features only halves the remaining reconstruction loss, so that we’re really infinitely many features away from perfect reconstruction?
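The second caveat can be made concrete with a little arithmetic. Suppose, purely hypothetically, that the reconstruction loss starts at $L_0$ and every additional 1,000 features halves what remains:

$$L(n) = L_0 \cdot 2^{-n/1000}$$

Then $L(n) > 0$ for every finite number of features $n$, and reaching a target loss $\varepsilon$ requires

$$n = 1000 \log_2\!\left(\frac{L_0}{\varepsilon}\right)$$

features, which grows without bound as $\varepsilon \to 0$. For example, cutting the loss by a factor of 1,000 would take about $1000 \log_2 1000 \approx 9{,}966$ features.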
Come Work With Us
We are currently discussing research in the #unsupervised-interp channel (under Interpretability) in the EleutherAI Discord server. If you’re a researcher and have directions you’d like to apply sparse autoencoders to, feel free to message Logan on Discord (loganriggs) or LW & we can chat!
For specific questions on sections (we’re all on discord as well):
1. Hoagy: autointerp & MLP results
2. Aidan: IOI Feature Identification
3. Logan: Monosemantic features & Auto-circuits
[1] KL-divergence is calculated by taking the original LLM’s output, then reconstructing, e.g., layer 4 with the autoencoder to get a different output, then computing the KL-divergence between these two outputs. In practice, we found that training on KL-divergence plus reconstruction (and sparsity) converged to lower perplexity.
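A sketch of the KL term described in this footnote, assuming next-token logits have been collected from the unmodified model and from the model with one layer replaced by its reconstruction. The KL direction (original relative to spliced) is an assumption, since the footnote does not specify it, and the function names are hypothetical. In the combined objective, this term would be added to the reconstruction and sparsity terms with some weighting the post does not specify.

```python
import numpy as np

def log_softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def kl_original_vs_spliced(logits_original, logits_spliced):
    """Mean over positions of KL(P_original || P_spliced).

    logits_*: (n_positions, vocab) next-token logits from the unmodified model
    and from the model with one layer replaced by its autoencoder reconstruction.
    The KL direction is an assumption; the footnote does not specify it.
    """
    logp = log_softmax(np.asarray(logits_original, dtype=float))
    logq = log_softmax(np.asarray(logits_spliced, dtype=float))
    return float(np.mean(np.sum(np.exp(logp) * (logp - logq), axis=-1)))

orig = np.array([[0.0, 0.0]])              # P = [0.5, 0.5]
spliced = np.array([[np.log(3.0), 0.0]])   # Q = [0.75, 0.25]
print(kl_original_vs_spliced(orig, spliced))  # 0.5 * ln(4/3) ≈ 0.1438
```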
[2] These datapoints can be found by selecting those with the largest perplexity difference.
Did you try searching for similar ideas to your work in the broader academic literature? There seems to be lots of closely related work that you’d find interesting. For example:
Elite BackProp: Training Sparse Interpretable Neurons. They train CNNs to have “class-wise activation sparsity.” They claim their method achieves “high degrees of activation sparsity with no accuracy loss” and “can assist in understanding the reasoning behind a CNN.”
Accelerating Convolutional Neural Networks via Activation Map Compression. They “propose a three-stage compression and acceleration pipeline that sparsifies, quantizes, and entropy encodes activation maps of Convolutional Neural Networks.” The sparsification step adds an L1 penalty to the activations in the network, which they do at finetuning time. The work just examines accuracy, not interpretability.
Enhancing Adversarial Defense by k-Winners-Take-All. Proposes the k-Winners-Take-All activation function, which keeps only the k largest activations and sets all other activations to 0. This is a drop-in replacement during neural network training, and they find it improves adversarial robustness in image classification. How Can We Be So Dense? The Benefits of Using Highly Sparse Representations also uses the k-Winners-Take-All activation function, among other sparsification techniques.
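k-Winners-Take-All is simple enough to state in a few lines. A NumPy sketch of the forward pass as described (keep the k largest activations, zero the rest), applied along the last axis; the function name is hypothetical, and the cited papers’ training details (for example, how gradients flow through the selection) are not reproduced.

```python
import numpy as np

def k_winners_take_all(activations, k):
    """Keep the k largest entries along the last axis and set the rest to 0."""
    a = np.asarray(activations, dtype=float)
    if k <= 0:
        return np.zeros_like(a)
    if k >= a.shape[-1]:
        return a.copy()
    top = np.argpartition(a, -k, axis=-1)[..., -k:]  # indices of the k largest (unordered)
    out = np.zeros_like(a)
    np.put_along_axis(out, top, np.take_along_axis(a, top, axis=-1), axis=-1)
    return out

print(k_winners_take_all([3.0, -1.0, 5.0, 0.0, 2.0], k=2))  # [3. 0. 5. 0. 0.]
```

`argpartition` avoids a full sort, which matters when the layer is wide and k is small.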
The Neural LASSO: Local Linear Sparsity for Interpretable Explanations. Adds an L1 penalty to the gradient wrt the input. The intuition is to make the final output have a “sparse local explanation” (where “local explanation” = input gradient).
Adaptively Sparse Transformers. They replace softmax with α-entmax, “a differentiable generalization of softmax that allows low-scoring words to receive precisely zero weight.” They claim “improve[d] interpretability and [attention] head diversity” and also that “at no cost in accuracy, sparsity in attention heads helps to uncover different head specializations.”
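α-entmax has a closed form for α = 2, where it reduces to sparsemax; general α needs an iterative solver (e.g., bisection), which is not shown. A sketch of that special case for a single 1-D score vector, to show how low-scoring entries receive exactly zero weight; the function name is hypothetical.

```python
import numpy as np

def sparsemax(z):
    """Sparsemax: Euclidean projection of a 1-D score vector onto the probability
    simplex. It equals alpha-entmax with alpha = 2; low scores get exactly zero."""
    z = np.asarray(z, dtype=float)
    z_sorted = np.sort(z)[::-1]              # scores in descending order
    cumsum = np.cumsum(z_sorted)
    ks = np.arange(1, z.size + 1)
    in_support = 1 + ks * z_sorted > cumsum  # True for a prefix of the sorted scores
    k = ks[in_support][-1]                   # size of the support
    tau = (cumsum[k - 1] - 1) / k            # threshold subtracted from every score
    return np.maximum(z - tau, 0.0)

print(sparsemax([2.0, 1.5, 0.0]).tolist())  # [0.75, 0.25, 0.0]
```

Softmax would give the third entry a small positive weight; sparsemax assigns it exactly zero, which is the property behind the attention-head interpretability claim.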
Interpretable Neural Predictions with Differentiable Binary Variables. They train two neural networks. One “selects a rationale (i.e. a short and informative part of the input text)”, and the other “classifies… from the words in the rationale alone.”
I ask because your paper doesn’t seem to have a related work section, and most of your citations in the intro are from other safety research teams (e.g., Anthropic, OpenAI, CAIS, and Redwood).
Hi Scott, thanks for this!
Yes, I did do a fair bit of literature searching (though maybe not enough, tbf), but it was very focused on sparse coding and approaches to learning decompositions of model activation spaces, rather than approaches to learning models which are monosemantic by default, which I’ve never had much confidence in; it seems that there’s not a huge amount beyond Yun et al’s work, at least as far as I’ve seen.
Still, I’d seen almost none of these, which suggests a big hole in my knowledge, and in the paper I’ll go through and add a lot more background on attempts to make more interpretable models.
Awesome work! I like the autoencoder approach a lot.
Cool work! I really like the ACDC on the parenthesis feature part, I’d love to see more work like that, and work digging into exactly how things compose with each other in terms of the weights.
I’ve had trouble figuring out a weight-based approach due to the non-linearity and would appreciate your thoughts actually.
We can learn a dictionary of features at the residual stream (R_d) & another mid-MLP (MLP_d), but, AFAIK, you can’t straightforwardly multiply the features from R_d by W_in and find the matching features in MLP_d, due to the nonlinearity.
I do think you could find Residual features that are sufficient to activate the MLP features[1], but not all linear combinations from just the weights.
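The obstacle shows up even in a toy example. With hypothetical weights and a GELU as the MLP nonlinearity (an assumption here, using the common tanh approximation), the MLP pre-activations decompose exactly into per-feature contributions `c_i * (f_i @ W_in)`, but the post-GELU activations do not, so MLP features cannot in general be read off residual features through the weights alone.

```python
import numpy as np

def gelu(x):
    """tanh approximation of GELU (assumed MLP nonlinearity)."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

# Hypothetical toy weights: two residual-stream features and a 2-neuron MLP input layer.
W_in = np.array([[1.0, -1.0],
                 [0.5,  2.0]])   # maps d_model=2 -> d_mlp=2 via x @ W_in
f_a = np.array([1.0, 0.0])       # residual feature A direction
f_b = np.array([0.0, 1.0])       # residual feature B direction
x = 2.0 * f_a + 1.0 * f_b        # residual activation built from both features

pre = x @ W_in
pre_sum = 2.0 * (f_a @ W_in) + 1.0 * (f_b @ W_in)
post = gelu(pre)
post_sum = gelu(2.0 * (f_a @ W_in)) + gelu(1.0 * (f_b @ W_in))

print(np.allclose(pre, pre_sum))    # True: pre-activations decompose feature by feature
print(np.allclose(post, post_sum))  # False: the GELU breaks the decomposition
```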
Using a dataset-based method, you could find causal features in practice (the ACDC portion of the paper was a first attempt at that), and would be interested in an activation*gradient method here (though I’m largely ignorant).
Specifically, I think you should scale the residual stream activations by their in-distribution max-activating examples.
Did you ever try out independent component analysis? There’s a scikit-learn implementation even. If you haven’t, I’m strongly tempted to throw an undergrad at it (in a RL setting where it makes sense to look for features that are coherent across time).
EDIT: Nevermind, it’s in the paper. And also I guess in the figure if I was paying closer attention :P
Hi Charlie, yep, it’s in the paper, but I should say that we did not find a working CUDA-compatible version and used the scikit-learn version you mention. This meant that the data volumes used are somewhat limited: still on the order of a million examples, but 10-50x less than went into the autoencoders.
It’s not clear whether the extra data would provide much signal since it can’t learn an overcomplete basis and so has no way of learning rare features but it might be able to outperform our ICA baseline presented here, so if you wanted to give someone a project of making that available, I’d be interested to see it!