In this post you seem to imply that the slow training is due to a lack of parallelization, but don't MP-SAEs also require more total flops?
At each iteration you need to recompute the encoder dot products using a matmul with the encoder matrix (a look at your code confirms this), so I would think that the total flops would scale almost linearly as you increase the number of iterations.
Very true! Each iteration of matching pursuit uses as much compute as the entire encode() of a standard SAE, so it's not only a parallelism problem (although it doesn't help either). I'll update the wording in the post.
This work was done as part of MATS 7.1
We recently added support for training and running Matching Pursuit SAEs (MP-SAEs) to SAELens, so I figured this is a good opportunity to train and open source some MP-SAEs, and share what I've learned along the way. Matching pursuit SAEs are exciting because they use a fundamentally different method to encode activations compared with traditional SAEs, and are a direct implementation of the classic matching pursuit algorithm from dictionary learning. The matching pursuit encoder is highly nonlinear, and should thus be more expressive than a traditional SAE encoder.
In this post, we'll discuss what MP-SAEs are, and some tips for training them successfully. We train two MP-SAEs at different L0s on Gemma-2-2b, and evaluate them against BatchTopK and Matryoshka SAEs that have the same L0 as the MP-SAEs. All SAEs trained as part of this post are available at huggingface.co/chanind/gemma-2-2b-layer-12-matching-pursuit-comparison and can be loaded using SAELens.
My main takeaway is that while MP-SAEs are exciting for researchers working on improving SAEs, I would not recommend them for practical use in LLM interpretability; or at least, they shouldn't be the first thing you try. MP-SAEs outperform traditional SAEs at reconstruction, but I do not see evidence that this results in a better SAE for practical tasks, and they are slower to train and run than traditional SAEs. MP-SAEs also seem to suffer more from feature absorption than traditional SAEs, likely due to their more expressive encoder. That being said, these are just my thoughts after training a few MP-SAEs on Gemma-2-2b, and this is not a rigorous analysis.
Regardless, I think MP-SAEs are a great addition to the set of SAE training techniques, and are especially exciting as a future research direction. In general, I am very supportive of finding ways to bring more traditional dictionary learning techniques to the SAE / interpretability world.
What is a Matching Pursuit Encoder?
An MP-SAE can be thought of as a tied TopK SAE, where the K latents are selected in serial rather than in parallel, and the K is dynamic per sample. At each iteration of the algorithm, the latent with the highest dot product with the reconstruction residual is selected, and the latent is projected out of the residual. This is repeated until the reconstruction error of the SAE is below `residual_threshold`, or the SAE selects the same latent multiple times. In SAELens, we add an additional stopping condition, `max_iterations`, to cap the worst-case runtime of the matching pursuit algorithm.
Training MP-SAEs on LLMs (in a reasonable amount of time)
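Because the cost of this loop drives everything about training, it helps to see the encoder described above written out. The following is a minimal NumPy sketch for a single activation vector, not the SAELens implementation. The function name `matching_pursuit_encode` is hypothetical, the dictionary `W_dec` is assumed to have unit-norm rows and to be shared between encoding and decoding (the "tied" part), the decoder bias is omitted, and the threshold is assumed to be compared against the L2 norm of the residual.

```python
import numpy as np


def matching_pursuit_encode(x, W_dec, residual_threshold=0.0,
                            max_iterations=100, stop_on_duplicate_support=True):
    """Greedy matching pursuit for one activation vector (illustrative sketch).

    x:     (d_model,) activation to encode
    W_dec: (n_latents, d_model) dictionary, assumed to have unit-norm rows
    Returns (acts, residual) with x == acts @ W_dec + residual.
    """
    residual = x.astype(float)          # astype returns a copy
    acts = np.zeros(W_dec.shape[0])
    selected = set()
    for _ in range(max_iterations):
        # Assumption: the stopping rule uses the L2 norm of the residual.
        if np.linalg.norm(residual) <= residual_threshold:
            break
        # One full (n_latents x d_model) product per iteration: this is the
        # same work as a whole encode() call in a standard SAE.
        scores = W_dec @ residual
        idx = int(np.argmax(scores))
        if stop_on_duplicate_support and idx in selected:
            break
        selected.add(idx)
        coeff = scores[idx]
        acts[idx] += coeff
        # Project the chosen latent out of the residual.
        residual = residual - coeff * W_dec[idx]
    return acts, residual
```

Calling this with `residual_threshold=0`, `stop_on_duplicate_support=False`, and `max_iterations` set to the target L0 gives the fixed-iteration behaviour that the rest of the post calls the "static" variant.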
For the LLM experiments in this post, I trained MP-SAEs on Gemma-2-2b layer 12. Each SAE has 32k width and is trained on 300M tokens from The Pile. The key difficulty in training MP-SAEs is that training can be extremely slow. The serial nature of matching pursuit does not mesh well with training on GPUs, since GPUs are optimized for parallel, not serial, computations. Furthermore, each iteration of the matching pursuit algorithm uses as much compute as a full `sae.encode()` call in a traditional SAE. The more iterations that are required to encode a batch of activations, the slower the MP-SAE is. For instance, I found that if I do not set `max_iterations` and `residual_threshold`, MP-SAEs can easily take 100+ hours to train on an Nvidia H100 GPU (compared with ~2 hours for a comparable traditional SAE)!
I trained two MP-SAEs, a lower-L0 MP-SAE with `residual_threshold=50, max_iterations=300`, and a higher-L0 MP-SAE with `residual_threshold=30, max_iterations=400`. The lower-L0 SAE ends up with L0 ≈ 85, and the higher-L0 SAE ends up with L0 ≈ 265. SAELens also has an option, `stop_on_duplicate_support`, that can be set to `False` to turn the MP-SAE into a true "serial TopK" SAE, where the SAE will always run `max_iterations` iterations for every sample. In the rest of this post, I refer to this as a "static" MP-SAE. I also trained a static L0 variant of an MP-SAE with L0=85. Notably, the static variant is what is implemented by the excellent Overcomplete library. The MP-SAEs trained in this post have the following hyperparameters:
To compare with these SAEs, I trained BatchTopK SAEs and BatchTopK Matryoshka SAEs, at both L0=85 and L0=265. The Matryoshka SAEs have inner group sizes of 2048 and 8192. The comparison SAEs are otherwise trained identically to the MP-SAEs (same dataset, same width, same number of tokens, same H100 GPU). Training time for these SAEs is shown below.
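To get a rough feel for the per-iteration cost mentioned above, here is a back-of-the-envelope estimate of encoder FLOPs per token. It is only a sketch under stated assumptions: a matrix-vector product is counted as 2 FLOPs per multiply-add, "32k width" is taken to mean 32,768 latents, only the encoder products are counted (no decoding or backward pass), and the 300-iteration figure is the worst-case cap rather than the average number of iterations actually used. The helper name `encode_flops_per_token` is made up for illustration.

```python
def encode_flops_per_token(d_model, n_latents, iterations=1):
    """Approximate FLOPs for `iterations` dictionary-residual products."""
    # One (n_latents x d_model) matrix-vector product is about
    # 2 * d_model * n_latents FLOPs (one multiply and one add per entry).
    return 2 * d_model * n_latents * iterations


D_MODEL = 2304        # residual stream width of Gemma-2-2b
N_LATENTS = 32_768    # assumption: "32k width" means 2**15 latents

standard = encode_flops_per_token(D_MODEL, N_LATENTS)                 # one encode() call
static_mp = encode_flops_per_token(D_MODEL, N_LATENTS, iterations=85)
capped_mp = encode_flops_per_token(D_MODEL, N_LATENTS, iterations=300)

print(f"standard encode: {standard / 1e9:.2f} GFLOPs per token")
print(f"static MP, 85 iterations: {static_mp / 1e9:.2f} GFLOPs per token")
print(f"MP at the 300-iteration cap: {capped_mp / 1e9:.2f} GFLOPs per token")
```

The encoder cost grows linearly with the number of iterations, which is why capping `max_iterations` matters so much for wall-clock time, independently of how poorly the serial loop parallelizes.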
The MP-SAEs train much slower than the traditional SAEs due to the serial encoder. ~24 hrs isn't a completely unreasonable amount of time to train an SAE, but it means that it's hard to train an MP-SAE on a large number of tokens (300M tokens is not much; SAEs are often trained on 1B+ tokens). The training time scales with the `max_iterations` parameter, so the "static" variant with a fixed 85 iterations per sample trains much faster than the other variants. It's also possible that there are more performant implementations of the matching pursuit algorithm that could speed things up. If anyone reading this is a PyTorch performance expert, pull requests are welcome!
MP-SAEs have impressive reconstruction
To measure reconstruction, I calculated the variance explained for each SAE. Results are split between L0=265 SAEs and L0=85 SAEs since comparing reconstruction is only valid when SAEs have the same L0.
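For reference, one common way to compute variance explained is sketched below. The exact definition used for the numbers in this post is not spelled out, so treat this as an assumption: it is the fraction of total variance, measured around the per-dimension mean activation, that the reconstruction accounts for. The function name is illustrative.

```python
import numpy as np


def variance_explained(x, x_hat):
    """Fraction of the variance of x captured by the reconstruction x_hat.

    x, x_hat: (n_samples, d_model) original and reconstructed activations.
    Assumed definition: 1 - SS_residual / SS_total, with SS_total measured
    around the per-dimension mean of x.
    """
    residual_ss = np.sum((x - x_hat) ** 2)
    total_ss = np.sum((x - x.mean(axis=0)) ** 2)
    return float(1.0 - residual_ss / total_ss)
```

Under this definition a perfect reconstruction scores 1.0, and always predicting the mean activation scores 0.0.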
In all cases, the MP-SAEs have better reconstruction than the traditional SAEs, and Matryoshka SAEs have the worst reconstruction. Getting better reconstruction does not necessarily mean the resulting SAE is better for interpretability, however. Gradient descent can find degenerate ways to improve reconstruction at the expense of SAE quality.
Interestingly, the static MP-SAE variant seems to have slightly better reconstruction than the standard MP-SAE despite training more than 3x faster. This is a good sign that using the static variant does not harm the resulting SAE.
MP-SAEs underperform at K-Sparse Probing
K-sparse probing is a common evaluation of SAE quality. I personally like to use the k-sparse probing tasks from the paper "Are Sparse Autoencoders Useful? A Case Study in Sparse Probing", as it contains over 140 sparse probing datasets to evaluate on (implemented as a PyPI library called sae-probes). Below are k=1 and k=16 sparse probing results for all SAEs:
For both k=1 and k=16 sparse probing, all MP-SAEs score worse than the traditional SAEs by a notable margin. This implies that MP-SAEs may be improving reconstruction by finding degenerate solutions rather than by better learning the underlying features of the model.
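For readers unfamiliar with the evaluation, the idea behind k-sparse probing can be sketched as follows. This is not the sae-probes API, and its latent selection and probe fitting may differ. It is a simplified illustration that assumes binary labels, selects the k latents with the largest difference in mean activation between classes, and fits a small logistic regression on only those latents with plain gradient descent. All names here are hypothetical.

```python
import numpy as np


def k_sparse_probe(train_acts, train_labels, test_acts, test_labels,
                   k, steps=500, lr=0.1):
    """Train a probe on only k SAE latents and return (test accuracy, latent ids).

    *_acts:   (n_samples, n_latents) SAE latent activations
    *_labels: (n_samples,) arrays of 0/1 labels
    """
    # 1. Select the k latents whose mean activation differs most between classes.
    pos_mean = train_acts[train_labels == 1].mean(axis=0)
    neg_mean = train_acts[train_labels == 0].mean(axis=0)
    top_k = np.argsort(-np.abs(pos_mean - neg_mean))[:k]

    # 2. Fit logistic regression on those k latents with gradient descent.
    X = train_acts[:, top_k]
    w = np.zeros(k)
    b = 0.0
    for _ in range(steps):
        probs = 1.0 / (1.0 + np.exp(-(X @ w + b)))
        grad = probs - train_labels
        w -= lr * (X.T @ grad) / len(X)
        b -= lr * grad.mean()

    # 3. Evaluate on held-out data using the same k latents.
    preds = (test_acts[:, top_k] @ w + b) > 0
    accuracy = float((preds == (test_labels == 1)).mean())
    return accuracy, top_k
```

With k=1 the probe can only succeed if a single latent tracks the concept cleanly, which is why low-k probing is often read as a test of whether the SAE has learned the underlying feature rather than smearing it across many latents.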
MP-SAEs seem very susceptible to feature absorption
I was particularly excited to train MP-SAEs on LLMs to see how they perform on the SAEBench feature absorption metric, as the Matching Pursuit SAEs paper motivates the MP-SAE architecture as a way to handle feature hierarchy, and implies that MP-SAEs should solve feature absorption. The SAEBench feature absorption rate is shown for each SAE below:
Sadly, I do not see any evidence that MP-SAEs reduce feature absorption. On the contrary, on the SAEBench absorption metric, MP-SAEs score much worse than traditional SAEs, implying they are actually more susceptible to feature absorption than vanilla SAEs. The Matryoshka SAEs score the best on feature absorption, as is expected since Matryoshka SAEs are explicitly designed to solve absorption.
It's possible that there's something unique about MP-SAEs that makes the SAEBench absorption metric invalid, but I can't think of what it would be (if anyone finds an error, please let me know!). However, scoring poorly on feature absorption is consistent with the results above showing that MP-SAEs have better reconstruction than traditional SAEs. Feature absorption can be viewed as a degenerate strategy to improve the reconstruction of the SAE at a given L0, so if MP-SAEs are better able to engage in absorption then we should expect that to result in a higher reconstruction score, which is consistent with what we see.
Final Thoughts
Training MP-SAEs
Prefer Static MP-SAEs
I don't see any downside to using the static variant of MP-SAEs (set `residual_threshold=0`, `stop_on_duplicate_support=False`, and set `max_iterations` to the target L0 of the SAE). This dramatically speeds up the training time of the MP-SAE and does not seem to result in an obviously worse SAE. This is also the version used by the Overcomplete library.
Should latents be forced to have unit norm?
In the SAELens MP-SAE implementation, we initialize the decoder to have unit norm but do not enforce this throughout training. This is based on the MP-SAEs reference implementation, which also does not enforce unit norm latents during training.
However, it seems like for the lower-L0 MP-SAEs, the decoder norm drops below 1.0:
Does this indicate the SAE is finding a degenerate way to improve reconstruction loss by somehow intentionally using latents below unit norm? Or is this a valid way to avoid superposition noise? Should we enforce that the decoder must have unit norm throughout training?
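If you want to experiment with this, the check and the constraint are both simple to express. The sketch below assumes the decoder is stored as an `(n_latents, d_model)` matrix with one latent per row, and that renormalization would be applied after each optimizer step. Neither function is part of SAELens, and whether enforcing the constraint actually helps is exactly the open question above.

```python
import numpy as np


def decoder_norms(W_dec):
    """L2 norm of each latent's decoder direction (one row per latent)."""
    return np.linalg.norm(W_dec, axis=1)


def renormalize_decoder(W_dec, eps=1e-8):
    """Rescale every decoder row to unit norm, e.g. after an optimizer step."""
    norms = np.maximum(decoder_norms(W_dec), eps)  # avoid dividing by zero
    return W_dec / norms[:, None]
```

Tracking `decoder_norms(W_dec).mean()` during training with and without the constraint would be one way to see whether the sub-unit norms are doing useful work.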
Dead latents
I was surprised to find there were no dead latents in any of the MP-SAE runs, despite not having any auxiliary loss to avoid dead latents. I'm not sure if this would still be the case if the SAE was much wider (e.g. 100k+ latents). If you train a very wide MP-SAE and find that there are dead latents, it may be necessary to add an aux loss to training.
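A quick way to check for dead latents on your own runs is to count the latents that never fire over a sample of tokens. This is a minimal sketch with a hypothetical function name, assuming you have collected latent activations into an `(n_tokens, n_latents)` array and that "dead" means zero activation on every token in that sample.

```python
import numpy as np


def dead_latent_fraction(acts):
    """Fraction of latents that are exactly zero on every token in `acts`.

    acts: (n_tokens, n_latents) latent activations from an evaluation sample.
    """
    fired = (acts != 0).any(axis=0)
    return float(1.0 - fired.mean())
```

The result depends on the size of the sample, since rare latents can look dead on a small one.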
Why no SCR/TPP evals?
I also tried running the SAEBench SCR and TPP evals, but found they were too slow to be practical for MP-SAEs. It seems like these evals assume that the SAE encode method is very fast, so these benchmarks probably need to be optimized to run on MP-SAEs in a reasonable amount of time. I didn't dig into this, but there are likely some easy optimizations available to enable these benchmarks to run on MP-SAEs if someone wants to look into that.
What do MP-SAEs learn?
I did not try to figure out if the features learned by MP-SAEs and traditional SAEs are different, but I would expect there are meaningful differences. I would be particularly curious if MP-SAEs learn more and/or different high-frequency latents than traditional SAEs. I would also be curious whether they behave differently from traditional SAEs in the presence of feature manifolds.
Should you train an MP-SAE?
Based on this investigation, I would not recommend using MP-SAEs if your goal is to use SAEs for interpretability work, or at least it shouldn't be the first thing you try. BatchTopK/JumpReLU seems like a better choice in terms of training time and practical performance. Matryoshka BatchTopK SAEs are also a great choice although there are more hyperparameters to set.
If you are a researcher working on improving SAE architectures, then I think MP-SAEs are very exciting, as the MP-SAE encoder works in a fundamentally different way than traditional SAEs. It may be possible to create some sort of hybrid between an MP-SAE and a standard SAE that mixes the benefits of both architectures, for example, or maybe it's possible to create a Matryoshka MP-SAE to deal with feature absorption.
Just give me the SAEs
All the SAEs in this post are available at https://huggingface.co/chanind/gemma-2-2b-layer-12-matching-pursuit-comparison. These SAEs can be loaded with SAELens v6.26.0+ as follows:
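A minimal sketch of the loading call, assuming the standard `SAE.from_pretrained(release, sae_id, device)` entry point in SAELens v6, which returns the SAE object. The wrapper function `load_sae` is only for illustration, and the import is placed inside it so the snippet can be read without SAELens installed.

```python
RELEASE = "chanind/gemma-2-2b-layer-12-matching-pursuit-comparison"
SAE_ID = "matching-pursuit/l0-85"


def load_sae(sae_id=SAE_ID, device="cpu"):
    """Download one SAE from the Hugging Face repo and load it with SAELens."""
    # Requires `pip install "sae-lens>=6.26.0"` and network access.
    from sae_lens import SAE

    return SAE.from_pretrained(release=RELEASE, sae_id=sae_id, device=device)


# Example (downloads weights on first use):
# sae = load_sae()
```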
For the other SAEs, replace "matching-pursuit/l0-85" with the path to the SAE in the repo. Each SAE on Huggingface also includes the `runner_cfg.json` used to train the SAE if you want to see exactly what training settings were used.
Try training MP-SAEs!
SAELens v6.26.0 now supports training and running Matching Pursuit SAEs. Give it a try! Also check out the Matching Pursuit SAEs paper "From Flat to Hierarchical: Extracting Sparse Representations with Matching Pursuit".