I’ve trained some sparse MLPs with 20K neurons on a 4L TinyStories model with ReLU activations and no layernorm, and I took a look at them after reading this post. For varying integer S, I applied an L1 penalty of 2^S to the average of the activations per token, which seems pretty close to an L1 of 2^S/20,000 on the sum of the activations per token (a rough sketch of this setup follows the links below). Your L1 of 2×10^−4 with 12K neurons is sort of like S=2 in my setup. After reading your post, I checked the cosine similarity between the encoder/decoder directions of the original MLP neurons and the sparse MLP neurons for varying values of S (make sure to scroll down once you click one of the links!):
S=3
https://plotly-figs.s3.amazonaws.com/sparse_mlp_L1_2exp3
S=4
https://plotly-figs.s3.amazonaws.com/sparse_mlp_L1_2exp4
S=5
https://plotly-figs.s3.amazonaws.com/sparse_mlp_L1_2exp5
S=6
https://plotly-figs.s3.amazonaws.com/sparse_mlp_L1_2exp6
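For concreteness, here’s a minimal sketch of the penalty setup (hypothetical variable names; the real training loop differs in details like batching and the reconstruction target). Penalizing the per-token mean activation with coefficient 2^S differs from penalizing the per-token sum with 2^S/20,000 only by the constant factor of 20K neurons:

```python
import torch

def sparse_mlp_loss(sparse_out, mlp_out, acts, S, n_neurons=20_000):
    # sparse_out / mlp_out: sparse MLP output and original MLP output for a batch of tokens.
    # acts: (n_tokens, n_neurons) post-ReLU activations of the sparse MLP.
    recon_loss = (sparse_out - mlp_out).pow(2).mean()
    # L1 of 2**S on the average activation per token (abs is a no-op after ReLU)...
    l1_on_mean = (2 ** S) * acts.abs().mean(dim=-1).mean()
    # ...which equals an L1 of 2**S / n_neurons on the per-token sum:
    #   (2 ** S / n_neurons) * acts.abs().sum(dim=-1).mean()
    return recon_loss + l1_on_mean
```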
I think the behavior you’re pointing at is clearly there at lower L1s on layers other than layer 0 (not sure what’s up with layer 0), and it roughly decreases with higher L1 values: it’s still there a bit at S=5 and almost not there at S=6. The non-dead sparse neurons look almost all interpretable at S=5 and S=6.
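(For concreteness, the cosine-similarity check and the dead-neuron count are roughly the computations below; a minimal sketch with hypothetical tensor names, not the exact code behind the plots.)

```python
import torch
import torch.nn.functional as F

def max_cos_sim_to_original(sparse_dirs, original_dirs):
    # sparse_dirs:   (n_sparse, d_model) encoder rows (or decoder columns) of the sparse MLP
    # original_dirs: (n_orig,   d_model) the corresponding directions of the original MLP
    cos = F.normalize(sparse_dirs, dim=-1) @ F.normalize(original_dirs, dim=-1).T
    # For each sparse neuron: cosine similarity to the closest original neuron.
    return cos.max(dim=-1).values

def dead_fraction(acts, eps=0.0):
    # acts: (n_tokens, n_sparse) activations over a large sample of tokens;
    # call a neuron "dead" if it never fires above eps on that sample.
    return (~(acts > eps).any(dim=0)).float().mean().item()
```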
Original val loss of the model: 1.128 ≈ 1.13.
Loss with each layer’s MLP zero-ablated: [3.72, 1.84, 1.56, 2.07].
S=6 loss recovered per layer
Layer 0: 1 - (1.24 - 1.13)/(3.72 - 1.13) ≈ 96% of loss recovered
Layer 1: 1 - (1.18 - 1.13)/(1.84 - 1.13) ≈ 93% of loss recovered
Layer 2: 1 - (1.21 - 1.13)/(1.56 - 1.13) ≈ 81% of loss recovered
Layer 3: 1 - (1.26 - 1.13)/(2.07 - 1.13) ≈ 86% of loss recovered
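(These are just 1 - (loss with the sparse MLP spliced in - original loss)/(zero-ablation loss - original loss); quick check with the unrounded original loss:)

```python
original = 1.128
zero_ablated = [3.72, 1.84, 1.56, 2.07]  # loss with each layer's MLP zero-ablated
with_sparse  = [1.24, 1.18, 1.21, 1.26]  # loss with the S=6 sparse MLP spliced in

for layer, (z, s) in enumerate(zip(zero_ablated, with_sparse)):
    print(f"Layer {layer}: {1 - (s - original) / (z - original):.0%} of loss recovered")
# -> 96%, 93%, 81%, 86% for layers 0-3
```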
Compare to the 79% loss recovered by Anthropic’s A/1 autoencoder, which has 4K features and a pretty different setup.
(Also, I was going to focus on the S=5 MLPs for layers 1 and 2, but now I think I might stick with S=6 instead. This is a little tricky because I wouldn’t be surprised if TinyStories MLP neurons are interpretable at higher rates than those of other models.)
Basically, I think sparse MLPs aren’t a dead end, and you probably just want a higher L1.