Cool work! As with Arthur, I’m pretty surprised by how much easier vision seems to be than text for interp (in line with previous results). It makes sense why feature visualization and adversarial attacks work better with continuous inputs, but if it is true that you need fewer datapoints to recover concepts of comparable complexity, I wonder if it’s a statement about image datasets or about vision in general (e.g. “abstract” concepts are more useful for prediction, since the n-gram/skip n-gram/syntactical feature baseline is much weaker).
I think the most interesting result to me is the one where the loss went down (!!):
> Note that the model with the SAE attains a lower loss than the original model. It is not clear to me why this is the case. In fact, the model with the SAE gets a lower loss than the original model within 40,000 training tokens.
My guess is this happens because CLIP wasn’t trained on ImageNet, but instead on a much larger dataset drawn from a different distribution. A lot of the SAE residual probably consists of features that are useful on the larger dataset but not on ImageNet. If you extract the directions of variation on ImageNet rather than on OAI’s 400M image-text pair dataset, it makes sense that reconstructing inputs using only these directions leads to better performance on the dataset you found those directions on.
I’m not sure how you computed the contrastive loss here. Is it the standard contrastive loss, but on image pairs instead of image/text pairs (using the SAE’d ViT for both representations), or did you use the contextless class label as the text input (SAE’ing only the ViT, not the text encoder)? Either way, this might add additional distributional shift.
(And I could be misunderstanding what you did entirely; perhaps you actually looked at the contrastive loss on the original dataset somehow, in which case the explanation I gave above doesn’t apply.)
Also, have you looked at the dot product of each of the SAE directions/SAE-reconstructed representations with the ImageNet labels fed through the text encoder?
Ah yes! I tried doing exactly this to produce a sort of ‘logit lens’ to explain the SAE features. In particular I tried the following.
1. Take an SAE feature encoder direction and map it directly into the multimodal space to get an embedding.
2. Pass each of the ImageNet text prompts “A photo of a {label}.” through the CLIP text model to generate the multimodal embedding for each ImageNet class.
3. Calculate the cosine similarities between the SAE embedding and the ImageNet class embeddings, and pass these through a softmax to get a probability distribution.
4. Look at the ImageNet labels with high probability; these should give some explanation of what the SAE feature is representing.
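If I follow, the steps above amount to something like the sketch below. All tensors here are random stand-ins for the real SAE encoder direction and CLIP class embeddings, and the softmax temperature is an assumed value, not one taken from the actual experiment:

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for the real quantities: one SAE encoder direction already
# mapped into the shared multimodal space, and the CLIP text embeddings
# of the 1000 "A photo of a {label}." prompts.
d_model, n_classes = 512, 1000
sae_embedding = rng.normal(size=d_model)
class_embeddings = rng.normal(size=(n_classes, d_model))

def cosine_sims(v, M):
    """Cosine similarity between a vector and each row of a matrix."""
    v = v / np.linalg.norm(v)
    M = M / np.linalg.norm(M, axis=1, keepdims=True)
    return M @ v

def softmax(x, temperature=1.0):
    z = x / temperature
    z = z - z.max()                     # numerical stability
    e = np.exp(z)
    return e / e.sum()

sims = cosine_sims(sae_embedding, class_embeddings)   # (1000,)
probs = softmax(sims, temperature=0.01)               # assumed CLIP-like low temperature
top5 = np.argsort(probs)[::-1][:5]                    # candidate label "explanations"
```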
Surprisingly, this did not work at all! I only spent a small amount of time trying to get it to work (<1 day), so I’m planning to try again. If I remember correctly, I also tried the same analysis with the decoder feature vector, and tried shifting by the decoder bias vector too; neither seemed to provide good ImageNet class explanations of the SAE features. I will try this again and let you know how it goes!
Huh, that’s indeed somewhat surprising if the SAE features are capturing the things that matter to CLIP (in that they reduce loss) and only those things, as opposed to “salient directions of variation in the data”. I’m curious exactly what “failing to work” means; here I think the negative result (and the exact details of said result) is arguably more interesting than a positive result would be.
Thanks for the comments! I am also surprised that SAEs trained on these vision models seem to require so little data, especially as I would have thought the complexity of CLIP’s representations for vision would be comparable to its complexity for text (after all, we can generate an image from a text prompt and then use a captioning model to recover the text, suggesting most or all of the information in the text is also present in the image).
With regards to the model loss, I used the text template “A photo of a {label}.”, where {label} is the ImageNet text label (this was the template used in the original CLIP paper). These text prompts were used alongside the associated batch of images and passed jointly into the full CLIP model (text and vision models) using the original contrastive loss function that CLIP was trained on. I used this loss calculation (with this template) to measure both the original model loss and the model loss with the SAE inserted during the forward pass.
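If I’ve understood the setup, that’s the standard symmetric contrastive (InfoNCE) loss from the CLIP paper. A minimal sketch with random stand-in embeddings (the batch size, embedding dimension, and temperature are illustrative, not the values actually used):

```python
import numpy as np

rng = np.random.default_rng(0)

def clip_contrastive_loss(img_emb, txt_emb, temperature=0.07):
    """Symmetric InfoNCE loss as in the CLIP paper: the i-th image
    and the i-th text are each other's positives."""
    img = img_emb / np.linalg.norm(img_emb, axis=1, keepdims=True)
    txt = txt_emb / np.linalg.norm(txt_emb, axis=1, keepdims=True)
    logits = img @ txt.T / temperature        # (batch, batch) similarity matrix
    idx = np.arange(len(logits))

    def cross_entropy(l):
        l = l - l.max(axis=1, keepdims=True)              # numerical stability
        log_probs = l - np.log(np.exp(l).sum(axis=1, keepdims=True))
        return -log_probs[idx, idx].mean()                # diagonal = correct pairs

    # average of image->text and text->image cross-entropies
    return 0.5 * (cross_entropy(logits) + cross_entropy(logits.T))

# Toy batch: 8 random (unmatched) image/text embedding pairs, so the loss
# should be near log(8); a perfectly matched batch gives a much lower loss.
batch, d = 8, 512
img_emb = rng.normal(size=(batch, d))
txt_emb = rng.normal(size=(batch, d))
loss_random = clip_contrastive_loss(img_emb, txt_emb)
loss_matched = clip_contrastive_loss(img_emb, img_emb)
```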
I also agree completely with your explanation for the reduction in loss. My tentative explanation goes something like this:
Many of the ImageNet classes are very similar (e.g. 118 classes are dog breeds and 18 are primate species). A model such as CLIP that is trained on a much larger dataset may struggle to differentiate the subtle differences between dog breeds or between primate species, and these classes alone may account for a large chunk of the loss when evaluating on ImageNet.
CLIP’s representations of many of these classes will likely be very similar,[1] using only a small subspace of the residual stream to separate them. When the SAE is included in the forward pass, some random error is introduced into the model’s activations, so these representations will on average drift apart from each other, separating slightly. On average this decreases the contrastive loss when restricted to ImageNet (but not on a much larger dataset, where the activations will not be clustered in this way).
That was a very hand-wavy explanation but I think I can formalise it with some maths if people are unconvinced by it.
I have some data suggesting this is the case even from the perspective of SAE features: the dog SAE features have much higher label entropy (mixing many dog breeds among the highest-activating images) than the features for non-dog classes, suggesting the SAE features struggle to separate the dog breeds.
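A sketch of this label-entropy measurement (the function, data, and `top_k` value are hypothetical stand-ins, not the actual analysis code): construct one feature that fires on a single class and one that mixes several classes, and compare the entropy of the label distribution over each feature's top-activating images.

```python
import numpy as np

rng = np.random.default_rng(0)

def label_entropy(feature_acts, labels, top_k=16):
    """Entropy (in bits) of the class labels of the top_k
    highest-activating images for one SAE feature."""
    top = np.argsort(feature_acts)[::-1][:top_k]
    _, counts = np.unique(labels[top], return_counts=True)
    p = counts / counts.sum()
    return float(-(p * np.log2(p)).sum())

# Synthetic stand-in data: 1000 images over 10 classes
n_images = 1000
labels = rng.integers(0, 10, size=n_images)

# A feature that fires almost exclusively on class 3 (low entropy)...
clean_acts = rng.random(n_images) + 5.0 * (labels == 3)
# ...versus one that mixes three classes, like the dog features described above
mixed_acts = rng.random(n_images) + 5.0 * np.isin(labels, [1, 4, 7])

clean_entropy = label_entropy(clean_acts, labels)
mixed_entropy = label_entropy(mixed_acts, labels)
```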