I wonder if it would be possible to do SAE feature amplification / ablation, at least for residual stream features, by inserting a “mostly empty” layer. E.g. for feature ablation, you could set the W_O and b_O params of the attention heads of your inserted layer to 0 so that the attention heads don’t change anything, and then approximate the constant / clamping intervention from the blog post via the MLP weights. If the activation function used for the transformer is the same one used for the SAE, it should be possible to do a perfect approximation using only one of the MLP neurons; even if not, it should be possible to very closely approximate any commonly-used activation function using any other commonly-used activation function with some clever stacking.
This would of course be horribly inefficient from a compute perspective (each forward pass would take (n+k)/n times as long, where n is the original number of layers the model had and k is the number of distinct layers in which you’re trying to do SAE operations on the residual stream), but I think vLLM would handle “Llama but with one extra layer” without requiring any tricky inference code changes, and plausibly this would still be more efficient than resampling.
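To make the one-neuron claim concrete, here is a minimal sketch of the inserted layer for a single ReLU SAE feature. Everything in it is an assumption for illustration: the names (`make_mlp_neuron`, `inserted_layer`, `clamp_value`), the toy 3-dimensional residual stream, a plain ReLU MLP with input and output biases, and no normalization in front of the MLP. Real Llama blocks use RMSNorm and a gated SiLU MLP without biases, so this would not carry over unchanged; that is where the "clever stacking" would have to come in.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def sae_feature(x, w_enc, b_enc):
    """Activation of one ReLU SAE feature on residual vector x (assumed SAE form)."""
    return max(0.0, dot(w_enc, x) + b_enc)

def make_mlp_neuron(w_enc, b_enc, d_dec, clamp_value=0.0):
    """Hypothetical helper: weights for a one-neuron ReLU MLP that clamps the feature.

    The neuron recomputes the feature activation f(x), writes -f(x) * d_dec to
    remove it, and the output bias adds clamp_value * d_dec back in.
    clamp_value=0.0 is plain ablation.
    """
    return {
        "w_in": list(w_enc),
        "b_in": b_enc,
        "w_out": [-d for d in d_dec],
        "b_out": [clamp_value * d for d in d_dec],
    }

def inserted_layer(x, mlp):
    """Residual update of the 'mostly empty' layer."""
    # Attention contributes nothing because W_O and b_O are set to 0.
    attn_out = [0.0] * len(x)
    hidden = max(0.0, dot(mlp["w_in"], x) + mlp["b_in"])
    mlp_out = [hidden * w + b for w, b in zip(mlp["w_out"], mlp["b_out"])]
    return [xi + a + m for xi, a, m in zip(x, attn_out, mlp_out)]

# Toy numbers, chosen only so the arithmetic is easy to follow.
x = [1.0, 2.0, -1.0]
w_enc, b_enc = [0.5, 0.5, 0.0], -0.5
d_dec = [1.0, 0.0, 0.0]

f = sae_feature(x, w_enc, b_enc)  # 0.5 + 1.0 - 0.5 = 1.0
ablated = inserted_layer(x, make_mlp_neuron(w_enc, b_enc, d_dec))
clamped = inserted_layer(x, make_mlp_neuron(w_enc, b_enc, d_dec, clamp_value=3.0))
```

Here `ablated` equals `x - f * d_dec` and `clamped` equals `x + (3.0 - f) * d_dec`, which is what the direct intervention on the residual stream would produce. When the feature is inactive, the ablation neuron outputs zero and the layer is a no-op.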