Thanks for the comments! I am also surprised that SAEs trained on these vision models seem to require so little data, especially as I would have expected the complexity of CLIP’s vision representations to be comparable to that of its text representations (after all, we can generate an image from a text prompt and then use a captioning model to recover the text, suggesting that most or all of the information in the text is also present in the image).
With regard to the model loss, I used the text template “A photo of a {label}.”, where {label} is the ImageNet class label (this is the template used in the original CLIP paper). These text prompts were paired with the corresponding batch of images and passed jointly through the full CLIP model (both the text and vision encoders), using the contrastive loss function that CLIP was originally trained with. I used this loss calculation (with this template) to measure both the original model loss and the model loss with the SAE inserted into the forward pass.
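For anyone wanting to reproduce this kind of measurement, here is a minimal sketch of the symmetric contrastive loss described in the CLIP paper, written in plain Python so it has no dependencies. It assumes the image and text embeddings have already been produced by CLIP's two encoders (with or without the SAE spliced into the vision model), that image i is paired with prompt i, and that the logit scale is 100. The helper names `make_prompts` and `clip_contrastive_loss` are hypothetical, and this is an illustration rather than the evaluation code used for the results discussed here.

```python
import math


def make_prompts(labels):
    """Fill the CLIP-paper template for each ImageNet class label."""
    return [f"A photo of a {label}." for label in labels]


def _normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def _log_softmax(row):
    """Numerically stable log-softmax of a list of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def clip_contrastive_loss(image_embs, text_embs, logit_scale=100.0):
    """Symmetric cross-entropy over the image-text similarity matrix.

    image_embs[i] and text_embs[i] are assumed to be a matching pair.
    logit_scale=100.0 is an assumed temperature, not a value from the post.
    """
    imgs = [_normalize(v) for v in image_embs]
    txts = [_normalize(v) for v in text_embs]
    n = len(imgs)
    # logits[i][j] = scaled cosine similarity between image i and text j
    logits = [
        [logit_scale * sum(a * b for a, b in zip(imgs[i], txts[j])) for j in range(n)]
        for i in range(n)
    ]
    # Image -> text: each row should assign high probability to its own caption
    loss_img = -sum(_log_softmax(logits[i])[i] for i in range(n)) / n
    # Text -> image: each column should assign high probability to its own image
    cols = [[logits[i][j] for i in range(n)] for j in range(n)]
    loss_txt = -sum(_log_softmax(cols[j])[j] for j in range(n)) / n
    return (loss_img + loss_txt) / 2
```

Evaluating this once on embeddings from the unmodified model and once on embeddings computed with the SAE reconstruction substituted for the original activations gives the two losses being compared.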
I also agree completely with your explanation for the reduction in loss. My tentative explanation goes something like this:
Many of the ImageNet classes are very similar (e.g., 118 classes are of dogs and 18 are of primates). A model such as CLIP, trained on a much broader dataset, may struggle to pick up the subtle differences between dog breeds and between primate species. These classes alone may account for a large share of the loss when evaluated on ImageNet.
CLIP’s representations of many of these classes will likely be very similar,[1] with only a small subspace of the residual stream separating them. When the SAE is inserted into the forward pass, it introduces some random error into the model’s activations, so these representations will on average drift apart from each other, separating slightly. On average, this will decrease the contrastive loss when restricted to ImageNet (but not on a much larger dataset, where the activations will not be clustered in this way).
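The "drift apart" step of this argument can be illustrated with a toy sketch. It models SAE reconstruction error as independent isotropic Gaussian noise (an assumption; real SAE error is not isotropic) and uses a hypothetical tight cluster of 10 embeddings standing in for the dog classes. It only shows that independent perturbations lower the mean pairwise cosine similarity of a tight cluster; it does not by itself demonstrate the claimed reduction in contrastive loss.

```python
import itertools
import math
import random


def cosine(a, b):
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def mean_pairwise_cosine(vectors):
    """Average cosine similarity over all unordered pairs (needs at least 2 vectors)."""
    pairs = list(itertools.combinations(vectors, 2))
    return sum(cosine(a, b) for a, b in pairs) / len(pairs)


def add_isotropic_noise(vectors, sigma, rng):
    """Assumed stand-in for SAE error: independent Gaussian noise on every coordinate."""
    return [[x + rng.gauss(0.0, sigma) for x in v] for v in vectors]


rng = random.Random(0)
dim = 16
# Hypothetical tight cluster: every "dog class" embedding points in nearly the same direction.
base = [1.0] + [0.0] * (dim - 1)
cluster = [[x + rng.gauss(0.0, 0.01) for x in base] for _ in range(10)]

before = mean_pairwise_cosine(cluster)
after = mean_pairwise_cosine(add_isotropic_noise(cluster, sigma=0.5, rng=rng))
print(f"mean pairwise cosine before noise: {before:.3f}, after noise: {after:.3f}")
```

Running the same `mean_pairwise_cosine` on real class-mean embeddings (dog classes versus other classes) would be one way to check the premise that these representations are tightly clustered in the first place.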
That was a very hand-wavy explanation, but I think I can formalise it with some maths if people are unconvinced.
I have some data suggesting this is the case even from the perspective of SAE features. The dog SAE features have much higher label entropy (their highest-activating images mix many dog breeds) than features for non-dog classes, suggesting the SAE features struggle to separate the dog breeds.
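A label-entropy measure of this kind can be sketched as follows, assuming it is the Shannon entropy (in bits) of the ImageNet labels among a feature's top-k activating images. The name `label_entropy`, the default k=16 and the base-2 logarithm are assumptions for illustration, not the exact settings behind the data mentioned above.

```python
import math
from collections import Counter


def label_entropy(activations, labels, k=16):
    """Shannon entropy (bits) of class labels among a feature's top-k activating images.

    activations[i] is the SAE feature's activation on image i,
    and labels[i] is the ImageNet class of image i.
    """
    # Indices of the k images on which the feature fires most strongly
    top = sorted(range(len(activations)), key=lambda i: activations[i], reverse=True)[:k]
    counts = Counter(labels[i] for i in top)
    total = sum(counts.values())
    # High entropy: the top images are spread over many classes
    return -sum((c / total) * math.log2(c / total) for c in counts.values())


acts = [0.9, 0.8, 0.7, 0.6, 0.1]
dog_like = ["beagle", "basset", "whippet", "pug", "goldfish"]
print(label_entropy(acts, dog_like, k=4))  # four different breeds in the top 4
```

A feature whose top four images are four different breeds scores 2 bits, while a feature whose top images all share one class scores 0, so a consistently higher score for dog features matches the pattern described.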