As one of the authors of llava-gemma, I'm thrilled to see our model is useful for alignment research. I think the multi-modal aspect of SAE research is understudied, and I am very glad to see this. I will be keeping an eye out for future developments!
Shan Chen, Jack Gallifant, Kuleen Sasse, Danielle Bitterman[1]
Please read this as a work in progress: we are colleagues sharing it at a lab meeting (https://www.bittermanlab.org) to help motivate potential parallel research.
TL;DR:
Introduction
The pursuit of universal and interpretable features has long captivated researchers in AI, with Sparse Autoencoders (SAEs) emerging as a promising tool for extracting meaningful representations. Universality, in this context, refers to the ability of features to transcend domains, languages, modalities, model architectures, sizes, and training strategies. Recent advances have shed light on key properties of these representations, including their dataset-dependent nature, their relationship with the granularity of training data, and their transferability across tasks. Notably, studies such as Kissane et al. (2024) and Kutsyk et al. (2024) have demonstrated the intriguing ability of features to transfer from base models to fine-tuned models, and others have hinted at their generalization across layers (Ghilardi et al. 2024). However, one critical question remains underexplored: can features trained in unimodal contexts (e.g., text-only or image-only models) effectively generalize to multimodal systems?
In this work, we focus on bridging this "modality gap" by investigating the applicability of SAE-derived features in multimodal settings. Specifically, we explore LLaVA (Liu et al. 2024), a popular multimodal model that integrates vision and language tasks. Leveraging the CIFAR-100 dataset, which provides a challenging fine-grained classification task, we assess the transferability and interpretability of features learned from base models in this multimodal context. Through a detailed layer-wise analysis, we investigate the semantic evolution of tokens and evaluate the utility of these features in downstream classification tasks.
While previous work has largely focused on the unimodal-to-unimodal transfer of features, our experiments aim to answer whether features extracted from base models can effectively bridge the gap to multimodal applications. This exploration aligns with ongoing efforts to understand how features encode information, how transferable they are across different contexts, and how they can be interpreted when applied to diverse tasks.
This write-up details our exploratory experiments, including:
Our findings contribute to advancing the interpretability and universality of features in large models, paving the way for more robust, explainable, and cross-modal AI systems.
Some Background on LLaVA:
LLaVA (Liu et al. 2023) is a multimodal framework that integrates vision and language tasks. By combining a Vision Encoder and a Language Model, LLaVA processes both image and textual inputs to generate coherent and contextually appropriate language-based outputs.
A visual representation of the LLaVA model architecture from (Liu et al. 2023). This diagram illustrates the flow of information from image input through the Vision Encoder, projection layer, and into the Language Model, culminating in generating text outputs.
Key Components
Vision Encoder:
A pretrained image encoder (a CLIP ViT in LLaVA 1.5) converts the input image into a sequence of visual embeddings, which a projection layer maps into the Language Model's token embedding space.
Language Model:
An autoregressive language model (Gemma-2B in the Intel/llava-gemma-2b variant used in this work) generates text conditioned on the projected image tokens and the textual instruction.
Token Structure:
The projected image tokens and the tokenized instruction are concatenated into a single unified token sequence, so the Language Model attends over both modalities with the same machinery it uses for text alone.
Output:
The final output of the LLaVA model is a text-based response that reflects both the visual content of the input image and the language instructions provided. This enables a wide range of applications, from answering questions about an image to generating detailed image captions.
Training:
LLaVA’s multimodal alignment is realized during visual instruction tuning, the fine-tuning of the Language Model using multimodal instruction-following data, where each textual instruction is paired with corresponding visual inputs. During this process, the model learns to interpret visual data in conjunction with textual context, which aligns visual features with language features.
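To make this concrete, here is a minimal sketch (not the authors' code) of how an image-plus-instruction query could be run through the Intel/llava-gemma-2b checkpoint with Hugging Face transformers; the exact prompt and chat-template handling are assumptions that may differ slightly from the model card, and the image URL is hypothetical.

```python
# Minimal sketch: querying a LLaVA-style model with an image plus a text
# instruction via Hugging Face transformers.
import requests
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

checkpoint = "Intel/llava-gemma-2b"
processor = AutoProcessor.from_pretrained(checkpoint)
model = LlavaForConditionalGeneration.from_pretrained(checkpoint)

# The "<image>" placeholder marks where the projected image tokens are spliced
# into the unified token sequence; chat-template details may vary by checkpoint.
prompt = processor.tokenizer.apply_chat_template(
    [{"role": "user", "content": "<image>\nWhat is shown in this image?"}],
    tokenize=False,
    add_generation_prompt=True,
)

url = "https://example.com/dolphin.jpg"  # hypothetical image URL
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=prompt, images=image, return_tensors="pt")
generate_ids = model.generate(**inputs, max_new_tokens=50)
print(processor.batch_decode(generate_ids, skip_special_tokens=True)[0])
```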
Evaluating SAE Transferability with LLaVA
LLaVA’s architecture provides an ideal testbed for evaluating the transferability of SAEs. By leveraging its unified token space and multimodal alignment, we can assess how well unimodal features extracted by SAEs adapt to multimodal contexts. Specifically, LLaVA’s ability to process and integrate image and text tokens allows us to analyze the semantic evolution of SAE-derived features across its layers, offering insights into their utility and generalization capabilities in multimodal scenarios.
In this study, we utilize the Intel Gemma-2B LLaVA 1.5-based model (Intel/llava-gemma-2b) as the foundation for our experiments. For feature extraction, we incorporate pre-trained SAEs from jbloom/Gemma-2b-Residual-Stream-SAEs, trained on the Gemma-1-2B model. Each SAE has 16,384 features (an expansion factor of 8 over the 2048-dimensional residual stream) and is designed to capture sparse, interpretable representations.
Our analysis focuses on evaluating the layer-wise integration of these features within LLaVA to determine their effectiveness in bridging unimodal-to-multimodal gaps. Specifically, we assess their impact on semantic alignment and classification performance. We hypothesized that the text-trained SAE features would still be meaningful when applied to LLaVA's activations.
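As a sketch of how text-trained SAE features can be read off LLaVA's language-model residual stream, the snippet below (reusing `model` and `inputs` from the generation sketch above) registers a forward hook on one Gemma decoder layer and encodes the captured hidden states with a standard ReLU SAE. The placeholder weights, attribute path, and manual encoder are assumptions for illustration; in practice the released SAE parameters for the matching layer would be loaded (e.g., via the SAELens library).

```python
# Sketch: capture LLaVA's residual stream at one layer and encode it with a
# (text-trained) sparse autoencoder. Weights below are random placeholders.
import torch

LAYER = 10  # one of the layers with a released residual-stream SAE (0, 6, 10, 12)
captured = {}

def save_resid(_module, _inputs, output):
    # Gemma decoder layers return a tuple; output[0] is the post-layer hidden
    # state of shape (batch, seq_len, d_model) -- the residual stream.
    captured["resid"] = output[0].detach()

# Attribute path for the Gemma LM inside LLaVA; may vary across transformers versions.
handle = model.language_model.model.layers[LAYER].register_forward_hook(save_resid)
with torch.no_grad():
    model(**inputs)  # `inputs` from the generation sketch above
handle.remove()

d_model, d_sae = 2048, 16384
# Random placeholders; load the released SAE parameters for real experiments.
W_enc = torch.randn(d_model, d_sae) / d_model**0.5
b_enc = torch.zeros(d_sae)
b_dec = torch.zeros(d_model)

def sae_encode(x):
    # Standard ReLU SAE encoder: subtract decoder bias, project up, rectify.
    return torch.relu((x - b_dec) @ W_enc + b_enc)

features = sae_encode(captured["resid"])
print(features.shape)  # (batch, seq_len, 16384) sparse feature activations
```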
Experimental Design
Dataset
We used the CIFAR-100 dataset (Krizhevsky et al. 2009), which comprises 100 fine-grained classes of small 32×32 images; our analysis uses 100 images per class (10,000 images in total).
Features and Evaluation
Is there any signal?
We implemented the outlined procedure and analyzed the retrieved features to evaluate whether meaningful features could be identified through this transfer method. As a first step, we conducted a preliminary cleaning pass to refine the feature set before examining the retrieved features and their auto-interpretability explanations in detail.
The objective of the cleaning process was to eliminate features that appeared to be disproportionately represented across instances, which could introduce noise, diminish interpretability, or indicate unaligned or non-transferable features. Considering the CIFAR-100 dataset, which comprises 100 labels with 100 instances per label, the expected maximum occurrence of any feature under uniform distribution is approximately 100. To address potential anomalies, a higher threshold of 1000 occurrences was selected as the cutoff for identifying and excluding overrepresented features. This conservative threshold ensured that dominant, potentially less informative features were removed while retaining those likely to contribute meaningfully to the analysis.
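A minimal sketch of this cleaning step, assuming the retrieved features are stored as one list of active SAE feature IDs per image (the variable names and toy data are hypothetical):

```python
# Sketch: drop SAE features that fire across too many CIFAR-100 instances.
from collections import Counter

MAX_OCCURRENCES = 1000  # conservative cutoff; uniform expectation is ~100

# Hypothetical input: one list of active feature IDs per image (10,000 in total).
features_per_image = [[3, 17, 42], [3, 99, 17], [3, 8]]  # toy placeholder

counts = Counter(fid for feats in features_per_image for fid in set(feats))
overrepresented = {fid for fid, c in counts.items() if c > MAX_OCCURRENCES}

cleaned = [[fid for fid in feats if fid not in overrepresented]
           for feats in features_per_image]
print(f"Removed {len(overrepresented)} overrepresented features")
```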
After cleaning, we examined the retrieved features across different model layers (0–12 of 19 layers). A clear trend emerged: deeper layers exhibited increasingly useful features.
Below, we provide examples of retrieved features from both high-performing and underperforming classes, demonstrating the range of interpretability outcomes:
1. Dolphin 🐬
Layer 0
Layer 6
Layer 10
Layer 12
Layer 12-it
2. Skyscraper 🏙️
Layer 0
Layer 6
Layer 10
Layer 12
Layer 12-it
3. Boy 👦
Layer 0
Layer 6
Layer 10
Layer 12
Layer 12-it
4. Cloud ☁️
Layer 0
Layer 6
Layer 10
Layer 12
Layer 12-it
Classification Analysis
Building on the feature extraction process, we shifted focus to an equally critical question: Could the extracted features meaningfully classify CIFAR-100 labels? Specifically, we aimed to determine whether these features could reliably distinguish between diverse categories such as "dolphin" and "skyscraper." Additionally, we investigated how choices like binarization and layer selection influenced the robustness and effectiveness of the classification process.
Here, we outline our methodology, key findings, and their broader implications.
Classification Setup
We implemented a linear classification pipeline to evaluate the retrieved features' predictive utility. Features were collected from multiple layers of the model and underwent the following preparation steps:
Feature Pooling:
Features were aggregated along the token dimension using two strategies: summing each feature's activation over all tokens, and keeping only the top-1 (maximum) activation across tokens.
Activation Transformation:
We explored the impact of activation scaling on performance, comparing raw activation values with binarized activations that only record whether a feature fired (both options appear in the sketch at the end of this setup).
Layer Evaluation
Features were extracted from Layers 6, 10, and 17 of the model. A linear classifier was trained using the features of each layer, and performance was assessed with Macro F1 scores. This ensured a balanced evaluation across all CIFAR-100 categories, allowing us to identify robustness, efficiency, and interpretability trends across different configurations.
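Below is a minimal sketch of this pipeline under stated assumptions: per-image SAE activations of shape (n_images, n_tokens, 16384) for one layer, a scikit-learn logistic-regression probe standing in for the linear classifier, and toy random inputs in place of the real extracted features.

```python
# Sketch of the linear-probe pipeline: pool SAE activations over tokens,
# optionally binarize, fit a linear classifier, and report Macro F1.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

# Toy stand-ins for the real extracted features (100 classes x 10 images here).
rng = np.random.default_rng(0)
labels = np.repeat(np.arange(100), 10)
activations = rng.random((1000, 2, 16384), dtype=np.float32)
activations *= activations > 0.99  # crude sparsity to mimic SAE activations

def pool(acts, strategy="sum"):
    # "sum": total activation of each feature over all tokens.
    # "max": strongest single-token activation per feature (top-1).
    return acts.sum(axis=1) if strategy == "sum" else acts.max(axis=1)

def prepare(acts, strategy="sum", binarize=True):
    x = pool(acts, strategy)
    return (x > 0).astype(np.float32) if binarize else x

X = prepare(activations, strategy="sum", binarize=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, labels, test_size=0.5, stratify=labels, random_state=0
)
clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
print("Macro F1:", f1_score(y_test, clf.predict(X_test), average="macro"))
```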
Classification Findings
Performance Summary
So, in a way, we actually nearly recovered the full ViT performance here!
1. How Many Features Do We Need?
We tested a range of feature selection methods, from summing activations over all tokens to taking only the top-1 activation per token.
What We Found:
Takeaway: Retaining a larger set of features preserves more discriminative information for CIFAR-100, though the size of this effect may differ across layers.
2. Which Layer Performs Best?
We tested features from Layers 6, 10, and 17 to see which part of the model provided the best representations.
What We Found:
Layer 10 Superiority: Features from Layer 10 consistently achieved the highest Macro F1 scores, balancing generalization and specificity.
Takeaway: Mid-level features (Layer 10) offered the best trade-off for CIFAR-100 classification.
3. To Binarize or Not to Binarize?
We compared binarized activations, which cap each feature to a simple fired/not-fired value, with non-binarized ones. The idea is that binarization reduces noise and keeps things simple.
What We Found:
Binarized vs. Non-Binarized: Binarized features outperformed non-binarized counterparts, particularly with smaller feature sets.
Takeaway: Binarization improves performance, especially under limited feature budgets.
4. Data Efficiency: How Much Training Data Do We Need?
We tested how well the features worked when we varied the amount of training data, from small splits (1:9 train:test) to larger splits (5:5); a minimal version of this sweep is sketched after the takeaway below.
What We Found:
Layer 17 Limitations: Performance for Layer 17 improved with increased data but lagged under low-data conditions.
Takeaway: Binarized middle-layer features (e.g., Layer 10) were the most data-efficient option.
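A minimal sketch of this data-efficiency sweep, reusing the `prepare` helper and toy inputs from the classification sketch above (the split ratios follow the text; everything else is illustrative):

```python
# Sketch: vary the train/test ratio and track Macro F1 for one layer's features.
for train_frac in (0.1, 0.2, 0.3, 0.4, 0.5):  # 1:9 up to 5:5 splits
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, labels, train_size=train_frac, stratify=labels, random_state=0
    )
    clf = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
    score = f1_score(y_te, clf.predict(X_te), average="macro")
    print(f"train fraction {train_frac:.1f}: Macro F1 = {score:.3f}")
```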
Big Picture Insights
So, what did we learn from all this? Here are the three big takeaways:
What’s Next?
These findings open up exciting transfer learning and feature design possibilities in multimodal systems. We’ve shown that thoughtful feature selection and transformation can make a big difference even with simple linear classifiers.
For future work, we’re interested in exploring:
Whether combining features from multiple layers offers additional performance gains: could a hybrid approach outperform the best single-layer features?
The authors acknowledge financial support from the Google PhD Fellowship (SC), the Woods Foundation (DB, SC, JG), the NIH (R01CA294033 (SC, JG, DB) and U54CA274516-01A1 (SC, DB)), and the American Cancer Society and American Society for Radiation Oncology ASTRO-CSDG-24-1244514-01-CTPS grant, DOI: https://doi.org/10.53354/ACS.ASTRO-CSDG-24-1244514-01-CTPS.pc.gr.222210 (DB).
It is very interesting that the model is focusing more on the clothing and potential parental care here.
Which leads me to another question: should we train SAEs jointly on vision and text, or separately, for multimodal systems?
We actually did ImageNet-1k too, and it is still running due to its size. We are seeing 0.49 Macro F1 for Layer 6...