I’ve been leveraging your code to speed up implementation of my own new formulation of neuron masks. I noticed a bug:
def running_mean_tensor(old_mean, new_value, n):
    return old_mean + (new_value - old_mean) / n

def get_sae_means(mean_tokens, total_batches, batch_size, per_token_mask=False):
    for sae in saes:
        sae.mean_ablation = torch.zeros(sae.cfg.d_sae).float().to(device)
    with tqdm(total=total_batches*batch_size, desc="Mean Accum Progress") as pbar:
        for i in range(total_batches):
            for j in range(batch_size):
                with torch.no_grad():
                    _ = model.run_with_hooks(
                        mean_tokens[i, j],
                        return_type="logits",
                        fwd_hooks=build_hooks_list(mean_tokens[i, j], cache_sae_activations=True)
                    )
                    for sae in saes:
                        sae.mean_ablation = running_mean_tensor(sae.mean_ablation, sae.feature_acts, i+1)
                    cleanup_cuda()
                pbar.update(1)
            if i >= total_batches:
                break

get_sae_means(corr_tokens, 40, 16)
The running mean calculation is only correct if n is the total number of samples seen so far, but i + 1 is the 1-indexed batch number we're on. That argument should be i * batch_size + j + 1. I ran a little test: below is a histogram from 1k runs of taking 104 random normal samples with batch_size=8, comparing the difference between the true mean and the final running mean as calculated by the running_mean_tensor function. The expected difference looks like zero, but with a fairly large variance, definitely larger than the standard error of the mean estimate, which is ~1/10 (standard_normal_sdev / sqrt(n) = 1 / sqrt(104) ≈ 0.1). I'm not sure how much adding a random error like this to the logit diffs affects the accuracy of the estimates.
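In case it's useful, here's a simplified NumPy sketch of that check (a stand-in for the actual SAE tensors and test script, with names of my own choosing), comparing the buggy i + 1 counter against i * batch_size + j + 1:

```python
import numpy as np

rng = np.random.default_rng(0)

def running_mean(old_mean, new_value, n):
    # Same update rule as running_mean_tensor: unbiased only if n is the
    # total number of samples folded in so far.
    return old_mean + (new_value - old_mean) / n

total_batches, batch_size = 13, 8  # 13 * 8 = 104 samples, as in the test above
errs_buggy, errs_fixed = [], []

for _ in range(1000):  # 1k runs
    samples = rng.standard_normal((total_batches, batch_size))
    mean_buggy, mean_fixed = 0.0, 0.0
    for i in range(total_batches):
        for j in range(batch_size):
            x = samples[i, j]
            mean_buggy = running_mean(mean_buggy, x, i + 1)                   # batch counter (wrong)
            mean_fixed = running_mean(mean_fixed, x, i * batch_size + j + 1)  # sample counter (right)
    errs_buggy.append(mean_buggy - samples.mean())
    errs_fixed.append(mean_fixed - samples.mean())

# The buggy counter heavily over-weights the most recent samples, so its error
# spread is much wider than the ~0.1 standard error of a 104-sample mean; the
# fixed counter recovers the true mean up to float error.
print("buggy counter, std of error:", np.std(errs_buggy))
print("fixed counter, std of error:", np.std(errs_fixed))
```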
Really exciting stuff! I've been working on an alternate formulation of circuit discovery in the now-traditional fixed-problem setting, and have been brainstorming unsupervised circuit discovery in the same spiritual vein as this work, though it's much less developed. You've laid the groundwork for a really exciting research direction here!
I have a few questions about the component definition and optimization. What does it mean when you say you define C components P_c? Do you randomly partition the parameter vector into C partitions and assign each partition to a P_c, with zeros elsewhere? Or do you divide each weight by C, setting w_{c,l,i,j} = w_{l,i,j} / C (plus some small ε?)?
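Just to make that question concrete, here's a toy sketch of the two schemes I'm imagining for a single weight matrix (the names, and both options, are my guesses rather than anything from the paper):

```python
import torch

torch.manual_seed(0)
W = torch.randn(16, 16)  # one layer's weight matrix, standing in for the full parameter vector
C = 4                    # number of components

# Option 1: random hard partition -- each weight is assigned to exactly one
# component P_c, with zeros elsewhere, so sum_c P_c == W.
assignment = torch.randint(0, C, W.shape)
partition_components = [torch.where(assignment == c, W, torch.zeros_like(W)) for c in range(C)]

# Option 2: uniform split -- every component starts at W / C (plus a small eps
# to break symmetry), so sum_c P_c ~= W but no weight belongs to any one component.
eps = 1e-3 * torch.randn(C, *W.shape)
uniform_components = [W / C + eps[c] for c in range(C)]

# Both schemes reconstruct W (exactly, or up to the eps noise).
assert torch.allclose(sum(partition_components), W)
assert torch.allclose(sum(uniform_components), W, atol=1e-2)
```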
Assuming something like that is going on, I can definitely believe this has been tricky to optimize on larger, more complex networks! I wonder if more informed priors might help? As in, using other methods to suggest at least some proportion of the candidate components? Have you considered or tried anything like that?