Comments

Ah yes! I tried doing exactly this to produce a sort of 'logit lens' to explain the SAE features. In particular I tried the following.

  • Take an SAE feature encoder direction and map it directly to the multimodal space to get an embedding.
  • Pass each of the ImageNet text prompts “A photo of a {label}.” through the CLIP text model to generate the multimodal embeddings for each ImageNet class.
  • Calculate the cosine similarities between the SAE embedding and the ImageNet class embeddings. Pass this through a softmax to get a probability distribution.
  • Look at the ImageNet labels with a high probability - this should give some explanation as to what the SAE feature is representing.
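
For concreteness, the steps above can be sketched roughly as follows. This is a minimal NumPy sketch rather than the exact code I ran: it assumes the SAE direction has already been mapped into the multimodal space and that the class embeddings have already been produced by the CLIP text model, so random arrays stand in for both. The function names and the `scale` parameter are illustrative assumptions.

```python
import numpy as np

def class_probabilities(feature_embedding, class_embeddings, scale=1.0):
    """Softmax over cosine similarities between one SAE feature and every class.

    feature_embedding: (d,) SAE direction already projected into the multimodal space.
    class_embeddings:  (n_classes, d) text embeddings of "A photo of a {label}.".
    scale:             assumed multiplier on the similarities before the softmax.
    """
    f = feature_embedding / np.linalg.norm(feature_embedding)
    c = class_embeddings / np.linalg.norm(class_embeddings, axis=1, keepdims=True)
    logits = scale * (c @ f)          # cosine similarities, optionally rescaled
    logits = logits - logits.max()    # subtract the max for numerical stability
    probs = np.exp(logits)
    return probs / probs.sum()

def top_labels(probs, labels, k=5):
    """Return the k most probable (label, probability) pairs."""
    order = np.argsort(probs)[::-1][:k]
    return [(labels[i], float(probs[i])) for i in order]

# Toy stand-in data: four "classes" in an 8-dimensional space (not real CLIP output).
rng = np.random.default_rng(0)
labels = ["tench", "goldfish", "great white shark", "tiger shark"]
class_embeddings = rng.normal(size=(4, 8))
# A feature direction that is almost exactly the "goldfish" embedding.
feature_embedding = class_embeddings[1] + 0.01 * rng.normal(size=8)

probs = class_probabilities(feature_embedding, class_embeddings, scale=100.0)
print(top_labels(probs, labels, k=2))
```

One thing to keep in mind with this setup: cosine similarities lie in [-1, 1], so an unscaled softmax over 1000 classes is close to uniform. CLIP itself multiplies the similarities by a learned logit scale, which is why `scale` is exposed here.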

Surprisingly, this did not work at all! I only spent a small amount of time trying to get this to work (<1 day), so I'm planning to try again. If I remember correctly, I also tried the same analysis with the decoder feature vector, and tried shifting by the decoder bias vector too; neither of these seemed to provide good ImageNet class explanations of the SAE features. I will try doing this again and I can let you know how it goes!

Thanks for the comments! I am also surprised that SAEs trained on these vision models seem to require so little data, especially as I would have thought the complexity of CLIP's representations for vision would be comparable to the complexity for text (after all, we can generate an image from a text prompt and then use a captioning model to recover the text, suggesting most or all of the information in the text is also present in the image).

With regards to the model loss, I used the text template “A photo of a {label}.”, where {label} is the ImageNet text label (this was the template used in the original CLIP paper). These text prompts were used alongside the associated batch of images and passed jointly into the full CLIP model (text and vision models) using the original contrastive loss function that CLIP was trained on. I used this loss calculation (with this template) to measure both the original model loss and the model loss with the SAE inserted during the forward pass.
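
As a rough illustration of the loss described above, here is a minimal NumPy sketch of a CLIP-style symmetric contrastive loss. It assumes the image and text embeddings are already computed and paired by index (one prompt per image), and the default `logit_scale` of 100 is an assumption rather than a value read from the model. To compare the two settings, the image embeddings would be computed once with the original forward pass and once with the SAE inserted.

```python
import numpy as np

def log_softmax(x, axis):
    """Numerically stable log-softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def clip_contrastive_loss(image_emb, text_emb, logit_scale=100.0):
    """Symmetric cross-entropy over the image-text similarity matrix.

    image_emb, text_emb: (batch, d) arrays where row i of each is a matched pair.
    logit_scale:         assumed temperature multiplier on the cosine similarities.
    """
    image_emb = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    text_emb = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    logits = logit_scale * image_emb @ text_emb.T   # (batch, batch) similarities
    idx = np.arange(logits.shape[0])
    loss_i2t = -log_softmax(logits, axis=1)[idx, idx].mean()  # image -> text
    loss_t2i = -log_softmax(logits, axis=0)[idx, idx].mean()  # text -> image
    return 0.5 * (loss_i2t + loss_t2i)

# Toy check with stand-in embeddings: matched pairs versus deliberately mismatched pairs.
matched = np.eye(4)
aligned_loss = clip_contrastive_loss(matched, matched)
shuffled_loss = clip_contrastive_loss(matched, matched[[1, 2, 3, 0]])
print(aligned_loss, shuffled_loss)
```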

I also agree completely with your explanation for the reduction in loss. My tentative explanation goes something like this:

  • Many of the ImageNet classes are very similar (e.g., 118 classes are of dogs and 18 are of primates). A model such as CLIP that is trained on a much larger dataset may struggle to differentiate the subtle differences in dog breeds and primate species. These classes alone may provide a large chunk of the loss when evaluated on ImageNet.
  • CLIP's representations of many of these classes will likely be very similar,[1] using only a small subspace of the residual stream to separate these classes. When the SAE is included during the forward pass, some random error is introduced into the model's activations and so these representations will on average drift apart from each other, separating slightly. This on average will decrease the contrastive loss when restricted to ImageNet (but not on a much larger dataset where the activations will not be clustered in this way).

That was a very hand-wavy explanation but I think I can formalise it with some maths if people are unconvinced by it.

  1. ^

    I have some data to suggest this is the case even from the perspective of SAE features. The dog SAE features have much higher label entropy (mixing many dog species in the highest activating images) compared to other non-dog classes, suggesting the SAE features struggle to separate the dog species.
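
    A minimal sketch of one way to compute such a label entropy, assuming it is the Shannon entropy of the ImageNet labels among a feature's highest-activating images (the helper name and the toy labels are illustrative, not my actual data):

```python
from collections import Counter

import numpy as np

def label_entropy(labels):
    """Shannon entropy (in nats) of the label distribution in a list of labels."""
    counts = np.array(list(Counter(labels).values()), dtype=float)
    p = counts / counts.sum()
    return float(-(p * np.log(p)).sum())

# Toy stand-ins for the labels of a feature's top-activating images.
dog_feature = ["beagle", "basset", "bloodhound", "foxhound"]   # mixes several breeds
fish_feature = ["goldfish", "goldfish", "goldfish", "goldfish"]  # a single class

print(label_entropy(dog_feature), label_entropy(fish_feature))
```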

Thanks for the feedback! Yeah, I was also surprised SAEs seem to work on ViTs pretty much straight out of the box (I didn't even need to play around with the hyperparameters too much)! As I mentioned in the post, I think it would be really interesting to train on a much larger (more typical) dataset, similar to the dataset the CLIP model was trained on.

I also agree that I probably should have emphasised the "guess the image" game as a result rather than an aside. I'll bear that in mind for future posts!