I’ve been looking into your proposed solution (inspired by @Charlie Steiner’s comment). For small models (Pythia-70M has d_model=512) w/ 2k features, the naive calculation doesn’t take long, so it’s viable for initial testing, & algorithmic improvements can be stacked on later.
There are a few design choices to make regardless of the optimal solution:
1. Cos-sim of the closest neighbor only, or the average over all vectors?
2. If closest neighbor, should it be the unique closest neighbor? (I’ve used the Hungarian algorithm before to compute this; rough sketch after this list.) If not, we penalize features that are close (or more “central”) to many other features more than others.
3. Per batch, only a subset of features activate. Should the cos-sim be computed only on the features that activate? The orthogonality regularizer would be trading off against L1 & MSE, so it might be too strong if it’s calculated on all features.
4. Gradient question: are there still gradient updates on the decoder weights of feature vectors that didn’t activate?
5. Loss function: do we penalize high cos-sim more? There’s also a base random cos-sim of ~0.2 for the 500x2k vectors.
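As a concrete illustration of the “unique closest neighbor” option in (2), here’s a rough sketch using scipy’s Hungarian-algorithm implementation (the function name & shapes are mine, not from any existing code):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def unique_closest_neighbors(W):
    """Match each feature to a *distinct* partner so that total cos-sim is
    maximized (Hungarian algorithm). W: (d_model, n_features) with unit-norm
    columns, so dot products are cosine similarities."""
    sims = W.T @ W                              # (n_features, n_features) pairwise cos-sims
    np.fill_diagonal(sims, -1e9)                # forbid matching a feature with itself
    rows, cols = linear_sum_assignment(-sims)   # minimize negative sim = maximize sim
    return cols, sims[rows, cols]               # partner index & cos-sim for each feature
```

The non-unique version is just a row-wise max over `sims`, which is what I’d default to.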
I’m currently thinking: cos-sim of the closest neighbor only, not unique, & only on features that activate (we can also run ablations to check). For the loss function, we could use a modified sigmoid:
$$\frac{1}{1+e^{-10(x-0.5)}}$$
This keeps the loss between 0 & 1, penalizes high cos-sim more, & penalizes low cos-sim less.
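To make that concrete, here’s a rough PyTorch sketch of what the penalty term could look like (the names `W_dec`, `active_idx`, etc. are placeholders, not from an actual implementation): for each feature that activated in the batch, take its max cos-sim against all other decoder vectors & squash it through the modified sigmoid.

```python
import torch

def orthogonality_penalty(W_dec: torch.Tensor, active_idx: torch.Tensor, eps: float = 1e-8):
    """Sketch: mean modified-sigmoid penalty on the (non-unique) nearest-neighbor
    cos-sim of each *activated* feature's decoder vector.

    W_dec: (d_model, n_features) decoder matrix, columns = feature vectors.
    active_idx: 1-D tensor of indices of features that fired in this batch.
    """
    # Normalize columns so dot products are cosine similarities.
    W = W_dec / (W_dec.norm(dim=0, keepdim=True) + eps)

    # Cos-sim of each activated feature vs. every feature: (n_active, n_features).
    sims = W[:, active_idx].T @ W

    # Exclude each feature's similarity with itself before taking the max.
    sims[torch.arange(active_idx.shape[0]), active_idx] = -1.0

    closest = sims.max(dim=1).values                 # nearest-neighbor cos-sim per active feature
    penalty = torch.sigmoid(10.0 * (closest - 0.5))  # 1 / (1 + e^{-10(x - 0.5)})
    return penalty.mean()
```

Note that because the penalty is computed against the full (non-detached) decoder matrix, gradients also reach the decoder columns of features that didn’t activate but are someone’s nearest neighbor, which is relevant to the gradient question above.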
Metrics:
During training, we can periodically check the max mean cos-sim (MMCS): the average over features of each feature’s (non-unique) nearest-neighbor cos-sim. Alternatively, we could push the full histogram (histograms are nice, but harder to compare across runs in wandb). I’d like to see a normal (w/o orthogonality regularizer) training run’s histogram to help set the hyperparams for the loss function.
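A minimal sketch of that metric, under the same placeholder naming as above (the per-feature values are also what you’d log as the histogram):

```python
import torch

@torch.no_grad()
def max_mean_cos_sim(W_dec: torch.Tensor, eps: float = 1e-8) -> float:
    """MMCS: mean over features of each feature's highest cos-sim with any
    *other* feature (its non-unique nearest neighbor)."""
    W = W_dec / (W_dec.norm(dim=0, keepdim=True) + eps)
    sims = W.T @ W                         # (n_features, n_features) pairwise cos-sims
    sims.fill_diagonal_(-1.0)              # ignore self-similarity
    per_feature = sims.max(dim=1).values   # also the values to histogram
    return per_feature.mean().item()
```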
Algorithmic Improvements:
The wiki pages for Closest Pair of Points (h/t Oam) & Nearest Neighbor Search seem relevant if one computes nearest neighbors to build an index, as Charlie suggested.
Faiss seems SOTA, AFAIK, for fast nearest neighbors on a GPU, although: “adding or searching a large number of vectors should be done in batches. This is not done automatically yet. Typical batch sizes are powers of two around 8192, see this example.”
I believe this is for GPU-memory constraints.
I had trouble installing it using
conda install pytorch::faiss-gpu
but it works if you do
conda install -c pytorch -c nvidia faiss-gpu=1.7.4 mkl=2021 blas=1.0=mkl
I was also unsuccessful installing it w/ just pip (w/o conda), & conda is their officially supported way to install, from here.
An additional note: for our case the cosine similarity is just the dot product, since all feature vectors are normalized by default, so an exact inner-product search is all we need (rough sketch below).
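A minimal (untested, shapes illustrative) sketch of what the Faiss GPU lookup could look like for our case:

```python
import faiss
import numpy as np

d_model, n_features = 512, 2048

# Placeholder for the (already unit-norm) decoder feature vectors, one per row.
features = np.random.randn(n_features, d_model).astype("float32")
features /= np.linalg.norm(features, axis=1, keepdims=True)

index = faiss.IndexFlatIP(d_model)                # exact inner product = cos-sim here
res = faiss.StandardGpuResources()
gpu_index = faiss.index_cpu_to_gpu(res, 0, index)

gpu_index.add(features)                    # large adds/searches should be batched (~8192)
sims, ids = gpu_index.search(features, 2)  # k=2: the top hit is the vector itself
nearest_cos_sim = sims[:, 1]               # non-unique nearest-neighbor cos-sim per feature
```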
I’m currently ignoring the algorithmic improvements due to the additional complexity, but they should be doable if this produces good results.
Testing it with Pythia-70M and few enough features to permit the naive calculation sounds like a great approach to start with.
Closest neighbour rather than average over all sounds sensible. I’m not certain what you mean by unique vs non-unique. If you’re referring to situations where there may be several equally close closest neighbours then I think we can just take the mean cos-sim of those neighbours, so they all impact on the loss but the magnitude of the loss stays within the normal range.
Only on features that activate also sounds sensible, but the decoder weights of neurons that didn’t activate would need to be allowed to update if they were the closest neighbours of neurons that did activate. Otherwise we could get situations where, e.g., one neuron (neuron A) has encoder and decoder weights both pointing in sensible directions to capture a feature, while another neuron has decoder weights aligned with neuron A but encoder weights occupying a remote region of activation space, so it rarely activates; if we don’t allow its decoder weights to update, they stay in that direction and block neuron A.
Yes, I think we want to penalise high cos-sim more. The modified sigmoid flattens out as x → 1, but I think the purple function below does what we want.
Training with a negative orthogonality regulariser could be an option. I think vanilla SAEs already have plenty of geometrically aligned features (e.g. see @jacobcd52’s comment below). Depending on the purpose, another option to intentionally generate feature combinatorics could be to simply add together some of the features learnt by a vanilla SAE. If the individual features weren’t combinations, then their sums certainly would be.
I’ll be very interested to see results and am happy to help with interpreting them etc. Also more than happy to have a look at any code.
Additionally, we can train w/ a negative orthogonality regularizer for the purpose of intentionally generating feature combinatorics. In other words, we train the features to point in more similar directions, to at least generate for-sure examples of feature combination (rough sketch below).
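For completeness, a hedged sketch of how the regularizer might slot into the overall SAE objective (coefficient & variable names are made up, and `orthogonality_penalty` is the earlier PyTorch sketch); flipping the sign of `ortho_coef` gives the negative-regularizer variant described above:

```python
def sae_loss(x, x_hat, latents, W_dec, active_idx, l1_coef, ortho_coef):
    """Hypothetical overall objective combining MSE, L1, & the orthogonality term."""
    recon_loss = (x_hat - x).pow(2).mean()            # MSE reconstruction term
    sparsity_loss = latents.abs().sum(dim=-1).mean()  # L1 sparsity term
    ortho_loss = orthogonality_penalty(W_dec, active_idx)
    # ortho_coef > 0 pushes nearest-neighbor features apart; ortho_coef < 0 pulls
    # them together, intentionally generating feature combinations as above.
    return recon_loss + l1_coef * sparsity_loss + ortho_coef * ortho_loss
```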