Thanks a lot for this post, I found it very helpful.
There exists a single direction which contains all linearly available information

Previous work has found that, in most datasets, linearly available information can be removed with a single rank-one ablation by ablating along the difference of the means of the two classes.
The specific thing that you measure may be more a fact about linear algebra than a fact about LLMs or CCS.
For example, let’s construct data which definitely has two linearly independent dimensions that are each predictive of whether a point is positive or negative. I’m assuming here that positive/negative exactly corresponds to true/false for convenience (i.e., that all the original statements happened to be true), but I don’t think it should matter to the argument.
# Initialize the dataset: 100 points per class in 3 dimensions
# (the third dimension is left as pure Gaussian noise)
import numpy as np

size = (100, 3)
p = np.random.normal(size=size)
n = np.random.normal(size=size)
# In the first and second dimensions, the points take values -1, -.5, 0, .5, 1
# The distributions are idiosyncratic for each, but all are predictive of the truth
# label
standard = [-1, -.5, 0, .5, 1]
p[:, 0] = np.random.choice(standard, size=(100,), replace=True, p=[0, .1, .2, .3, .4])
n[:, 0] = np.random.choice(standard, size=(100,), replace=True, p=[.6, .2, .1, .1, 0])
p[:, 1] = np.random.choice(standard, size=(100,), replace=True, p=[0, .05, .05, .1, .8])
n[:, 1] = np.random.choice(standard, size=(100,), replace=True, p=[.3, .3, .3, .1, 0])
Then we can plot the data. For the unprojected data plotted in 3-d, the points are linearly classifiable with reasonable accuracy in both those dimensions.
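(The figures themselves aren’t reproduced here; for reference, a minimal matplotlib sketch of the 3-d scatter. The seed and labels are my own choices, and the data generation is repeated so the snippet runs on its own.)

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import matplotlib.pyplot as plt

# Regenerate the dataset from above (fixed seed is an arbitrary choice)
rng = np.random.default_rng(0)
standard = [-1, -.5, 0, .5, 1]
p = rng.normal(size=(100, 3))
n = rng.normal(size=(100, 3))
p[:, 0] = rng.choice(standard, size=100, p=[0, .1, .2, .3, .4])
n[:, 0] = rng.choice(standard, size=100, p=[.6, .2, .1, .1, 0])
p[:, 1] = rng.choice(standard, size=100, p=[0, .05, .05, .1, .8])
n[:, 1] = rng.choice(standard, size=100, p=[.3, .3, .3, .1, 0])

# 3-d scatter, one color per class
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.scatter(*p.T, label="positive")
ax.scatter(*n.T, label="negative")
ax.legend()
fig.savefig("unprojected.png")
```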
But then we perform the mean projection operation given here (in case I have any bugs!)
def project(p, n):
    # compute the class means in each dimension
    p_mean = np.mean(p, axis=0)
    n_mean = np.mean(n, axis=0)
    # find the mean-difference direction and normalize it
    delta = p_mean - n_mean
    unit_delta = delta / np.linalg.norm(delta)
    # project out the component along that direction
    p_proj = p - np.expand_dims(np.inner(p, unit_delta), 1) * unit_delta
    n_proj = n - np.expand_dims(np.inner(n, unit_delta), 1) * unit_delta
    return p_proj, n_proj
And after projection there is no way to get a linear classifier that has decent accuracy.
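(To check that claim numerically rather than just visually, here is a self-contained sketch of my own: it regenerates the dataset, applies the projection, and fits a crude least-squares linear classifier, standing in for logistic regression, before and after. The exact numbers depend on the random draw, but training accuracy reliably drops to near chance.)

```python
import numpy as np

# Regenerate the dataset from above (fixed seed is an arbitrary choice)
rng = np.random.default_rng(0)
standard = [-1, -.5, 0, .5, 1]
p = rng.normal(size=(100, 3))
n = rng.normal(size=(100, 3))
p[:, 0] = rng.choice(standard, size=100, p=[0, .1, .2, .3, .4])
n[:, 0] = rng.choice(standard, size=100, p=[.6, .2, .1, .1, 0])
p[:, 1] = rng.choice(standard, size=100, p=[0, .05, .05, .1, .8])
n[:, 1] = rng.choice(standard, size=100, p=[.3, .3, .3, .1, 0])

def project(p, n):
    # mean ablation, same operation as project() above
    delta = np.mean(p, axis=0) - np.mean(n, axis=0)
    u = delta / np.linalg.norm(delta)
    return p - np.outer(p @ u, u), n - np.outer(n @ u, u)

def linear_accuracy(p, n):
    # Least-squares linear classifier with labels +/-1 and an intercept
    # column; returns training accuracy
    X = np.hstack([np.vstack([p, n]), np.ones((200, 1))])
    y = np.concatenate([np.ones(100), -np.ones(100)])
    w, *_ = np.linalg.lstsq(X, y, rcond=None)
    return np.mean(np.where(X @ w >= 0, 1.0, -1.0) == y)

acc_before = linear_accuracy(p, n)
acc_after = linear_accuracy(*project(p, n))
print(acc_before, acc_after)  # high accuracy before, near 0.5 after
```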
Or looking at the projection onto the 2-d plane to make things easier to see:
Note also that this is all with unnormalized raw data; doing the same thing with normalized data gives a very similar result, with this as the unprojected data:
and these projected figures:
FWIW, I’m like 80% sure that Alex Mennen’s comment gives the mathematical intuition behind these visualizations, but it wasn’t totally clear to me, so I’m posting these in case they are clearer to some others as well.
I agree, there’s nothing specific to neural network activations here. In particular, the visual intuition doesn’t rely on the shape of the data: if you merely translate the two datasets until they have the same mean (which is weaker than mean ablation), you will already have a hard time finding a good linear classifier.
But it’s not trivial or generally true either: the paper I linked to gives some counterexamples of datasets where mean ablation doesn’t prevent you from building a classifier with >50% accuracy. The rough idea is that the mean is sensitive to outliers, but outliers don’t matter if you want to produce high-accuracy classifiers. Therefore, what you want to ablate along is something more like the median.
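(To make the failure mode concrete, here is a minimal synthetic counterexample in the same style as the code above. This is my own construction, not one taken from the paper: a single outlier drags the class means so that the mean-difference direction points well away from the true signal direction, and ablating it leaves the classes almost perfectly separable even though their means are now equal.)

```python
import numpy as np

# Two classes in 2-d. Dimension 0 carries the real signal (+1 vs -1);
# dimension 1 is zero everywhere except for one extreme positive outlier.
p = np.zeros((100, 2)); p[:, 0] = 1.0
n = np.zeros((100, 2)); n[:, 0] = -1.0
p[0, 1] = 200.0  # a single outlier dominates the positive-class mean

# Mean ablation, same operation as project() above
delta = p.mean(axis=0) - n.mean(axis=0)  # (2, 2): dragged toward dim 1
u = delta / np.linalg.norm(delta)
p_proj = p - np.outer(p @ u, u)
n_proj = n - np.outer(n @ u, u)

# The fixed direction (1, -1) still separates everything except the outlier
w = np.array([1.0, -1.0])
acc = (np.sum(p_proj @ w > 0) + np.sum(n_proj @ w < 0)) / 200
print(acc)  # 0.995: mean ablation did not remove the linear information
```

Note that after the projection the two class means are exactly equal, so the surviving accuracy comes entirely from the shape of the distributions, not from any mean difference.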