You can remove GPT2’s LayerNorm by fine-tuning for an hour

This work was produced at Apollo Research, based on initial research done at MATS.
Edit: arXiv version available at https://arxiv.org/abs/2409.13710

LayerNorm is annoying for mechanistic interpretability research (“[...] reason #78 for why interpretability researchers hate LayerNorm” – Anthropic, 2023).

Here’s a Hugging Face link to a GPT2-small model without any LayerNorm.

The final model is only slightly worse than a GPT2 with LayerNorm[1]:

| Dataset | Original GPT2 | Fine-tuned GPT2 with LayerNorm | Fine-tuned GPT2 without LayerNorm |
| OpenWebText (ce_loss) | 3.095 | 2.989 | 3.014 (+0.025) |
| ThePile (ce_loss) | 2.856 | 2.880 | 2.926 (+0.046) |
| HellaSwag (accuracy) | 29.56% | 29.82% | 29.54% |

I fine-tuned GPT2-small on OpenWebText while slowly removing its LayerNorm layers, waiting for the loss to go back down after each removal:

Introduction

LayerNorm (LN) is a component in Transformer models that normalizes embedding vectors to have constant length. Specifically, it subtracts the mean and divides by the standard deviation, both taken over the hidden dimension, and then applies a learned scale and bias. It was originally introduced to stabilize and speed up training of models (as a replacement for batch normalization). It is active during both training and inference.
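To make this concrete, here is a minimal NumPy sketch of a GPT2-style LayerNorm over the hidden dimension. γ and β are the learned scale and bias, and ε = 1e-5 is assumed as the default. The checks show why the std division makes LN non-linear. Scaling the input leaves the output (almost exactly) unchanged, and LN of a sum is not the sum of LNs.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    """LayerNorm over the last (hidden) dimension."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased variance, as in torch.nn.LayerNorm
    return (x - mean) / np.sqrt(var + eps) * gamma + beta

rng = np.random.default_rng(0)
d = 8
gamma, beta = np.ones(d), np.zeros(d)
x, y = rng.normal(size=d), rng.normal(size=d)

# Scale invariance: LN(2x) is (up to eps) LN(x); a linear map would give 2 * LN(x).
print(np.allclose(layer_norm(2 * x, gamma, beta), layer_norm(x, gamma, beta), atol=1e-4))
# Additivity fails too: LN(x + y) differs from LN(x) + LN(y).
print(np.allclose(layer_norm(x + y, gamma, beta),
                  layer_norm(x, gamma, beta) + layer_norm(y, gamma, beta)))
```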

The division by the standard deviation (std) makes LN a non-linear operation. This hinders interpretability in a variety of ways, from annoyances to inaccuracies, such as

  • attributing residual stream directions to logit effects (e.g. SAE features, direct logit attribution),[2]

  • making Attribution Patching annoying to apply, or

  • being difficult to handle in Apollo’s LIB method.
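The toy NumPy sketch below illustrates the first point with direct logit attribution through a final LN. It uses random weights, ignores ε, and all names are hypothetical. Freezing the std splits the logit into per-component terms. However, each term is divided by the std of the whole residual stream, so a component's attribution shifts when an unrelated component changes.

```python
import numpy as np

rng = np.random.default_rng(1)
d, n_components = 16, 4
components = rng.normal(size=(n_components, d))  # outputs written to the residual stream
w_u = rng.normal(size=d)                          # unembedding direction for one token
gamma, beta = rng.normal(size=d), rng.normal(size=d)

def center(v):
    """Subtract the mean over the hidden dimension (a linear operation)."""
    return v - v.mean(axis=-1, keepdims=True)

resid = components.sum(axis=0)
std = resid.std()  # depends on all components together
logit = (center(resid) / std * gamma + beta) @ w_u

# With the std frozen, the logit is a sum of per-component terms plus a constant.
contribs = (center(components) / std * gamma) @ w_u
print(np.isclose(contribs.sum() + beta @ w_u, logit))

# Changing component 0 changes the std, and therefore component 1's attribution.
components_2 = components.copy()
components_2[0] *= 3.0
std_2 = components_2.sum(axis=0).std()
contribs_2 = (center(components_2) / std_2 * gamma) @ w_u
print(contribs[1], contribs_2[1])  # component 1 itself is unchanged
```

With a fixed std (no LN), the denominator is a constant. Each component's contribution then depends only on that component.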

In the Docstring circuit analysis we seriously considered whether the model might be using LN in its algorithm. This post even shows that LN can be used as the sole non-linearity to solve non-linear classification problems (see also this related work).

Recently, with progress in Sparse Dictionary Learning, agendas (e.g. this one) imagine decomposing networks into sets of sparsely connected components (SAEs, Transcoders, etc.). A core difficulty in “putting it all together” is that the interactions between different components often route through LayerNorm, whose effect we do not understand.

Motivation

It would be pretty neat to have an LLM that still works (speaks English etc.) with fewer or no LN layers. One option would be to train a model without LN from scratch (done for tiny models, e.g. TinyModel), but this is very hard or impossible for larger models (hearsay is that you need a low learning rate and to be very careful).

Taking an existing model and removing the LN layers however seems doable if LN isn’t implementing some important computation.[3] That is, LN “does its thing” and the model has learned to “deal with it”, but it’s not irreplaceable. A reason to be optimistic is that the spread of standard deviations across different samples isn’t that large, so maybe replacing the LN-computed standard deviation with a fixed number might kinda work.

Method

I take GPT2-small, fine-tune it on OpenWebText, and remove LNs one-by-one while fine-tuning.

The only non-linear operation in a LN layer is the division by the standard deviation (std) of the embedding vectors; the remaining operations can be absorbed into later weight matrices (see the fold_ln option in TransformerLens; also discussed in this appendix). Thus I mainly focus on the std part here.
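Here is a minimal NumPy sketch of why an LN with a fixed std can be removed. Centering is multiplication by the matrix C = I − 11ᵀ/d. So centering, division by a constant s, the scale γ and the bias β are all affine, and they fold into the next linear layer. The shapes and the value of s are illustrative assumptions. With the real per-token std, s would depend on x and this folding would not be exact.

```python
import numpy as np

rng = np.random.default_rng(2)
d, d_out, s = 8, 5, 3.7  # s stands in for a fixed "average std" (illustrative value)
gamma, beta = rng.normal(size=d), rng.normal(size=d)
W, b = rng.normal(size=(d, d_out)), rng.normal(size=d_out)  # the following linear layer

def fixed_std_ln_then_linear(x):
    """Dummy LN with a fixed std s, followed by x @ W + b."""
    return ((x - x.mean(-1, keepdims=True)) / s * gamma + beta) @ W + b

# Fold centering (matrix C), the 1/s factor and gamma into W; fold beta into b.
C = np.eye(d) - np.ones((d, d)) / d
W_folded = C @ np.diag(gamma / s) @ W
b_folded = beta @ W + b

x = rng.normal(size=(10, d))
print(np.allclose(fixed_std_ln_then_linear(x), x @ W_folded + b_folded))
```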

My general strategy is to “remove” an LN layer (this makes the loss go up), and then to train the model for some time (on the original training data) until the loss is back near the baseline. For this “remove” step I do the following:

  • Calculate the average std on the dataset (I used a quite small sample, 16 prompts), separately for position 0 and position > 0

  • Replace the std calculation with the average std (position > 0), sometimes with special treatments for

    • Position 0: use the position 0 average instead

    • EOT tokens: When the input is an end-of-text token I also use the position 0 average std (because norms on EOT tokens tend to be larger)[4]
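The first step could look roughly like the NumPy sketch below. It assumes the inputs to one LN have been cached as an array of shape (n_prompts, seq_len, d_model). The helper name and the synthetic data (position 0 given an inflated norm) are hypothetical, not the code used here.

```python
import numpy as np

def average_stds(resid):
    """Hypothetical helper.

    resid: (n_prompts, seq_len, d_model) cached inputs to one LN.
    Returns (average_std, bos_std): the mean per-token std over positions > 0,
    and the mean per-token std at position 0.
    """
    stds = resid.std(axis=-1)         # per-token std over the hidden dimension
    average_std = stds[:, 1:].mean()  # positions > 0
    bos_std = stds[:, 0].mean()       # position 0
    return float(average_std), float(bos_std)

# Synthetic example: 16 prompts, with position 0 having a much larger norm.
rng = np.random.default_rng(4)
resid = rng.normal(size=(16, 32, 64))
resid[:, 0] *= 10.0
average_std, bos_std = average_stds(resid)
print(average_std, bos_std)
```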

Whenever I do the replacement the loss jumps up, from the baseline of 3.0 up to 3.5, sometimes even around 5.0. After 10-100 iterations (learning rate 6e-4 and batch size approx. 488 as recommended here) the loss typically goes down to between 3.0 and 3.1. However, if I’m not careful and change too much at once, the loss can jump very high (around 8.0), and in those cases it usually never recovers. Thus I want to avoid making too big of a change at once.

Here’s the recipe I empirically found to work. After every step, train for 50-200 iterations or until the loss is close to baseline.

  • Fine-tune for 300 iterations without any changes (GPT2 without any fine-tuning does surprisingly badly on OpenWebText)

  • Remove ln_2 (the LN before the MLP block), replacing the std with an average value. I use the “Position 0” special treatment but not the “EOT tokens” special treatment for ln_2.

  • Remove ln_1 for the q and k vectors. As for ln_2, I only use the “Position 0” special treatment.

    • I’ve also experimented with removing these LNs one-by-one (remove in a layer, train for 10 iterations before removing it in the next layer, …). That is plausibly the strictly better strategy, but in 2 out of 3 successful fine-tunes I didn’t need to do this.

  • Remove ln_1 for the v vector. Use both the “Position 0” and “EOT tokens” special treatments.

  • Remove ln_f, the final LN. Again I only use the “Position 0” special treatment.

  • Remove the “EOT tokens” special treatment.

  • Remove the “Position 0” special treatment.
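The recipe can be written down as a schedule. The sketch below is one hypothetical reading of it. The stage labels are assumptions. So is the stopping rule: at least 50 and at most 200 iterations per stage, stopping once the loss is within `tol` of baseline. The toy `apply_stage`/`train_step` functions stand in for the real model and training loop.

```python
# Hypothetical stage labels as (label, fixed_iters); None means "train until recovered".
SCHEDULE = [
    ("warmup", 300),  # fine-tune with all LNs intact
    ("remove ln_2 [pos0]", None),
    ("remove ln_1 for q, k [pos0]", None),
    ("remove ln_1 for v [pos0 + EOT]", None),
    ("remove ln_f [pos0]", None),
    ("drop EOT special treatment", None),
    ("drop pos0 special treatment", None),
]

def run_schedule(schedule, apply_stage, train_step, baseline,
                 tol=0.05, min_iters=50, max_iters=200):
    """Apply each stage, then train until the loss is within tol of baseline
    (at least min_iters, at most max_iters iterations). Returns a log."""
    log = []
    for stage, fixed_iters in schedule:
        apply_stage(stage)
        if fixed_iters is not None:
            for _ in range(fixed_iters):
                loss = train_step()
            iters = fixed_iters
        else:
            for iters in range(1, max_iters + 1):
                loss = train_step()
                if iters >= min_iters and loss <= baseline + tol:
                    break
        log.append((stage, iters, loss))
    return log

# Toy stand-ins: each removal bumps the loss by 0.5, and training decays it toward 3.0.
state = {"loss": 3.3}

def apply_stage(stage):
    if stage != "warmup":
        state["loss"] += 0.5

def train_step():
    state["loss"] = 3.0 + 0.9 * (state["loss"] - 3.0)
    return state["loss"]

log = run_schedule(SCHEDULE, apply_stage, train_step, baseline=3.0)
for stage, iters, loss in log:
    print(f"{stage:32s} {iters:4d} iters  loss={loss:.3f}")
```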

I considered scaling individual LNs down slowly (e.g. interpolate between the actual calculated std and the average std) but I never ended up needing this, and did not really explore it.

In general I observed that

  • There seems to be some mechanism where the first interventions are harder than the later ones. Whichever LN I remove first, it tends to cause a rather big loss-bump, while later removals sometimes are barely noticeable.

    • For example, in all of my runs where I remove ln_f at the end, it causes only a small increase in loss. If I remove it at the beginning it causes a much larger loss bump.

  • Removing ln_2 was usually easier than removing ln_1. This, combined with the point above, is why I remove ln_2 before ln_1.

  • Removing ln_1 for the q and k vectors was not too bad, but the v vector was much harder. I only managed to remove its LN once I implemented the “EOT tokens” special treatment.

  • I tentatively feel like there is always “1 big loss spike” somewhere: in some earlier runs I remember a big spike when removing ln_1, while in other experiments (like the screenshot at the top) the spike happens when I remove the “EOT tokens” special treatment.

    • This suggests a hypothesis along the lines of “there is something the model needs to learn to stop using LayerNorm, and once it has done that the rest is easy”. This could be e.g. “dampen / throw away the confidence neuron” or something about how EOT tokens are handled, but I don’t have a good guess yet.

Implementation

I implement everything based on the NanoGPT repository. I replace the standard deviation calculation in the LN by a fixed number (set to the average standard deviation). This number is fixed, but it is degenerate with the LN scale (self.weight) which is learnable.

# Fixed std replaces the per-token std computed by a normal LayerNorm
std = self.average_std if std_type == "avg" else self.bos_std
return (x - x.mean(-1, keepdim=True)) / std * self.weight + self.bias

I have two different average stds, self.average_std (average of std over all tokens except for position 0) and self.bos_std (average of stds at position 0). Initially, after replacing the real LN with this “dummy LN” I use the following policy for choosing which std to divide by:

  • Use self.average_std if token position is > 0 [and token is not end-of-text (EOT)]

  • Use self.bos_std if token position is 0 [or current token is EOT]

The EOT rule in [brackets] is used only for the LN before the attention v vector. Pseudocode for a simplified version of my implementation:

x_v = torch.empty_like(x)  # LN output fed to the v projection, filled per token
x_v[eot_mask] = self.ln_1(x, std_type="eot")[eot_mask]
x_v[~eot_mask] = self.ln_1(x, std_type="avg")[~eot_mask]

I will upload the full code when I have more time next week (email me if you’d like it earlier).
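Until the full code is up, the policy above can be sketched as a small NumPy class. The class and flag names are hypothetical, whereas the implementation described above is built on NanoGPT in PyTorch. The `use_pos0` and `use_eot` flags correspond to the two special treatments, so the last two recipe steps amount to switching them off.

```python
import numpy as np

class DummyLayerNorm:
    """LN with the per-token std replaced by fixed averages (NumPy sketch)."""

    def __init__(self, weight, bias, average_std, bos_std, use_pos0=True, use_eot=False):
        self.weight, self.bias = weight, bias
        self.average_std, self.bos_std = average_std, bos_std
        self.use_pos0, self.use_eot = use_pos0, use_eot

    def __call__(self, x, is_eot):
        # x: (batch, seq, d_model); is_eot: (batch, seq) boolean mask of EOT tokens
        use_bos = np.zeros(is_eot.shape, dtype=bool)
        if self.use_pos0:
            use_bos[:, 0] = True       # "Position 0" special treatment
        if self.use_eot:
            use_bos |= is_eot          # "EOT tokens" special treatment
        std = np.where(use_bos, self.bos_std, self.average_std)[..., None]
        return (x - x.mean(-1, keepdims=True)) / std * self.weight + self.bias
```

For example, the LN before the v vector would use `use_eot=True`, while `ln_2` would keep the default `use_eot=False`.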

Results

Well, the models train, the loss is low, and the models still speak English.

v1 model: Manually interrupt & resume training whenever loss went down enough:

v2 model (+ vanilla GPT2 fine-tuning comparison in green) with scheduled LN removals:

I’m currently training a third version employing the gradual removal technique to more of the components, and will update the Hugging Face repository (use the v1 and v2 revisions for models presented here).

Here I compare the cross-entropy losses on OpenWebText and ThePile, as well as the score on the HellaSwag benchmark (using Andrej Karpathy’s implementation). I use two baselines: the original GPT2 model straight from Hugging Face, and a second version that I fine-tuned on OpenWebText for 1000 iterations to be more comparable to the noLN ones (which were fine-tuned on that data).

| Dataset | Original GPT2 | GPT2 with LN (fine-tuned) | GPT2 noLN v1 | GPT2 noLN v2 |
| OpenWebText (ce_loss) | 3.095 | 2.989 | 3.139 | 3.014 (+0.025) |
| ThePile (ce_loss) | 2.856 | 2.880 | 3.057 | 2.926 (+0.046) |
| HellaSwag (accuracy) | 29.56% | 29.82% | 29.17% | 29.54% |

Some sample generations with model.generate:

GPT2:

As the last leaf fell from the tree, John realized that it was empty. He took the leaf and turned it over to his wife, who told him that it was still there and that he would have to go to the church to find it. John went to the church, and found that it was empty. He said, “I am going to the church and I am going to find the rest of the leaves, and I am going to look for them and find out where they

As the last leaf fell from the tree, John realized that the tree had been torn down. As he turned his head, the other trees started to fall. “Come on,” John said, “we’re going to get out of here!” The next tree was a wildflower. “How is it?” John asked, “do you see any other way?” “It’s a good thing,” the other trees replied.

GPT2_noLN:

As the last leaf fell from the tree, John realized that he was going to have to use it. “We’ve all been there. We’ve all been there,” he said. “It’s been a long time. But, it’s good to be back. It’s good to be back.” “It’s good to be back. It’s good to be back

As the last leaf fell from the tree, John realized that a large amount of weight had been lifted from him. “I had a little panic attack. I was afraid that I could not walk,” he said. “I felt like my legs were going to break.” John has since gone back to the tree. “I have to tell you that I’m sorry I did that, but I don’t know if that will ever happen,” he said.

Residual stream norms

Previously I noticed that the norm of the residual stream looks very different at position 0 compared to other positions. Does this behavior still exist in the noLN model? No. The noLN model (solid lines) does not show the difference in norms we saw in the original models (dashed lines):

Discussion

Faithfulness to the original model

I expect that the new no-LayerNorm (noLN) model will not have the exact same internal mechanisms as the original model. To some extent I expect lots of similarities (as the new model just had an hour to train), but since the norms changed (see above) and the loss changed, I expect differences.

My goal with this model is more like “have a toy model almost as good as GPT2 but easier to interpret” to replace vanilla GPT2 in interpretability research. Models like GPT2 (and Pythia etc.) are useful not because they are the models we ultimately care about (GPT4, GPT5, …) but because they let us generally explore how LLM internals work. For this purpose, I don’t mind if GPT2-noLN differs from GPT2.

A question for the future is whether we want to apply this LN removal method before interpreting GPT4 and other “production models”. This depends on how similar the internals are, and I am currently uncertain about this. I am primarily concerned with the former use-case.

Does the noLN model generalize worse?

I noticed that the noLN performance hit was worse on The Pile than OpenWebText. This might be a coincidence, but it could also suggest that removing LN hurts generalization. While LN was originally introduced for training stability purposes, it may have a side effect on generalization. I have not evaluated the models on more datasets and leave this question for future research. Edit 30th Aug 2024: On the other hand this makes the OpenWebText and ThePile losses of the no-LN model more similar than those of the original model.

Appendix

Representing the no-LayerNorm model in GPT2LMHeadModel

I replace LayerNorms with DummyLayerNorms that use a fixed std, rather than computing the actual variance of each sample. This is equivalent to removing the LN, because the remaining LN operations can be folded into the following layers (e.g. TransformerLens mostly does this; TL does not fold in the centering operation, but I do). I perform this folding for all ln_1 and ln_2 layers. Thus I obtain weights with which GPT2 can run without any LayerNorm layers (all LNs replaced by nn.Identity).

To make the model available on Hugging Face without trust_remote_code I want to package it into the GPT2LMHeadModel class. Thus I want to “neuter” the LNs in GPT2LMHeadModel such that the model just works with my “LNs are identities”-weights. I do this by setting ln_eps (epsilon) to a very high value (1e12), and setting the ln weights (gamma) to a corresponding value (1e6). I set the biases to 0. This leaves the centering operation but this doesn’t matter as I also fold a centering operation into the following layers, thus the LNs can be removed without further changes.
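Here is a quick NumPy check of this trick, with illustrative values. The `layer_norm` helper mirrors the usual LN formula, and in Hugging Face's GPT2Config the ε is the `layer_norm_epsilon` field. With ε = 1e12, the denominator is ≈ 1e6 as long as the variance is far below 1e12. A weight of 1e6 then cancels it, and the LN reduces to (almost exactly) centering.

```python
import numpy as np

def layer_norm(x, weight, bias, eps):
    mean = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * weight + bias

rng = np.random.default_rng(3)
d = 768
x = rng.normal(scale=50.0, size=(4, d))  # large-ish residual-stream values (illustrative)

# "Neutered" LN: the huge epsilon dominates the variance, and the weight cancels sqrt(eps).
out = layer_norm(x, weight=np.full(d, 1e6), bias=np.zeros(d), eps=1e12)
centered = x - x.mean(-1, keepdims=True)
print(np.abs(out - centered).max())  # tiny compared to the entries of x
```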

There’s one exception to this: the final layer norm. GPT2LMHeadModel uses (a) tied embedding and unembedding weights, and (b) no unembedding bias. Thus it is impossible to fold the final LayerNorm, which includes a (diagonal) weight matrix and a bias, into the other weights here. I still “neuter” the normalizing function of the LayerNorm as above, so ln_f just represents a simple linear layer (with bias) before the unembedding.

Which order to remove LayerNorms in

Two considerations inform which LayerNorms I want to remove first:

  • LayerNorm is there to help us train the model (stabilizing /​ speeding up training). While we are only fine-tuning, I expect that the remaining LNs still help us fine-tune the weights after removing the first LNs.

  • Empirically, removing some LNs first makes the loss explode, but removing the same LNs after first removing other LNs and fine-tuning is fine. This makes some sense: we expect the model to adapt and remove all functionality that relies on LNs. So maybe the first removal is the most important one (after which the model reorganizes its computation), and later removals no longer require as large a reorganization.

So I expect the order in which the LNs are removed to matter.

A meta-choice is whether

  • to remove all LNs of a type first (for all layers), and later remove LNs of different types, or

  • to remove all LNs in a layer (of all types), and then move on to the next layer.

I went with the first option for ease of implementation and have not tried the second. However, the second option seems more principled to me, once we understand which order is optimal.
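The two options differ only in the order of (LN type, layer) pairs. Below is a small sketch using hypothetical labels for the four kinds of removal (ln_2, ln_1 for q/k, ln_1 for v, ln_f) and GPT2-small's 12 layers.

```python
LN_TYPES = ["ln_2", "ln_1_qk", "ln_1_v", "ln_f"]  # hypothetical labels
N_LAYERS = 12

def type_major(types=LN_TYPES, n_layers=N_LAYERS):
    """Option 1: remove one LN type in every layer before moving to the next type."""
    order = []
    for t in types:
        if t == "ln_f":
            order.append((t, None))  # there is only one final LN
        else:
            order.extend((t, layer) for layer in range(n_layers))
    return order

def layer_major(types=LN_TYPES, n_layers=N_LAYERS):
    """Option 2: remove every LN type in a layer before moving to the next layer."""
    order = [(t, layer) for layer in range(n_layers) for t in types if t != "ln_f"]
    if "ln_f" in types:
        order.append(("ln_f", None))
    return order

print(type_major()[:3])   # [('ln_2', 0), ('ln_2', 1), ('ln_2', 2)]
print(layer_major()[:3])  # [('ln_2', 0), ('ln_1_qk', 0), ('ln_1_v', 0)]
```

Both orders remove the same 37 LNs. Only the sequence changes.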

Which kinds of LayerNorms to remove first

I tried out a few combinations, such as first removing ln_f and then removing ln_2 and ln_1, or vice versa. I haven’t done a systematic sweep of all options, and the current method is just what felt right after a couple of tries. It seems to work well enough though.

Here’s an example of removing ln_f first. The loss reaches a very high level, and even after 400 iterations only goes down to 3.138, so this seems to be a worse choice. Note, however, that in this run I didn’t do “warm-up” iterations (training for a couple of iterations with LN to reach a good loss on OpenWebText; GPT2 loaded directly from Hugging Face does badly for the first ~10 iterations).

Which layer to remove LayerNorms in first

Here I remove ln_2 in the different layers in different orders:

  • Normal = remove in layer 0, then layer 1, …

  • Reverse = remove in layer 11, then layer 10, …

I start removing LNs at iteration 300, and remove another LN every 10 iterations.

The loss differs during the process (this is expected, some LNs are possibly more important than others) but evens out at the end.
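For reference, the sketch below shows how this experiment's schedule maps iterations to layers. The function name is hypothetical. The start iteration, spacing and layer count are the values stated above. With 12 layers, the last ln_2 is removed at iteration 410 in both orders.

```python
def ln2_layer_removed_at(iteration, order="normal", start=300, every=10, n_layers=12):
    """Which layer's ln_2 is removed at this iteration (None if no removal happens)."""
    if iteration < start or (iteration - start) % every != 0:
        return None
    k = (iteration - start) // every  # k-th removal, counting from 0
    if k >= n_layers:
        return None
    return k if order == "normal" else n_layers - 1 - k

print([ln2_layer_removed_at(i, "normal") for i in (300, 310, 410)])   # [0, 1, 11]
print([ln2_layer_removed_at(i, "reverse") for i in (300, 310, 410)])  # [11, 10, 0]
```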

Data-reuse and seeds

In my initial tests I used lots of snapshots, and accidentally retrained the model on the same first couple of batches of OpenWebText (fixed seed) every time. I have the impression that this worked slightly better than my later full-pipeline runs, which never reused data. I might investigate this in the future.

Infohazards

I am not worried that publishing this work accelerates capabilities progress over alignment progress, because (a) this is a pretty obvious idea, (b) it applies only to inference (not training), and (c) it only speeds up inference by a very small amount (likely not even worth the loss increase).

Acknowledgements

Thanks to Bilal Chughtai, Neel Nanda, and Rudolf Laine for comments and feedback on the draft. The nanoGPT repository and accompanying video by Andrej Karpathy were very helpful, allowing me to get a working prototype in a day!

  1. ^

    The GPT2 paper claims a loss of log(16)=2.77 on their training dataset (non-public webtext). I guess that must be an easier dataset. In any case, I fine-tune both models on OpenWebText for a total of 1000 iterations (~500k rows, ~500M tokens) to give a fairer comparison.

  2. ^

    For this case, only the final layer norm matters.

  3. ^

    This paper discusses confidence regularization as one possibly important aspect.

  4. ^

    I wonder how much of this effect is “just divide by a larger number” vs. actually dividing by the correct average. After all, the position 0 average shouldn’t be a great match for the EOT token average. [In this dataset the first position is not an EOT token. Neel Nanda / TransformerLens recommends this for short prompts (see here for a discussion), but we don’t do it for the full dataset.]