Interesting idea, I had not considered this approach before!
I'm not sure this would solve feature absorption though. Thinking about the "Starts with E-" and "Elephant" example: if the "Elephant" latent absorbs the "Starts with E-" latent, the "Starts with E-" feature will develop a hole and not activate anymore on the input "elephant". After the latent is absorbed, "Starts with E-" wouldn't be in the list to calculate cumulative losses for that input anymore.
Matryoshka works because it forces the early-indexed latents to reconstruct well using only themselves, whether or not later latents activate. I think this pressure is key to stopping the later-indexed latents from stealing the job of the early-indexed ones.
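To make that pressure concrete, here is a toy, dependency-free sketch of a nested reconstruction loss. It is not the code from the post: the function name `matryoshka_loss`, the `prefix_sizes` argument, and the two-latent "Starts with E-" / "Elephant" setup are all assumptions made up for illustration, and real training would use batches, autograd, and a sparsity mechanism.

```python
def matryoshka_loss(x, acts, decoder, prefix_sizes):
    """Sum of squared reconstruction errors, one term per nested prefix.

    x            -- input activation vector (list of floats)
    acts         -- latent activations, ordered from early to late index
    decoder      -- decoder[i] is the output direction of latent i
    prefix_sizes -- increasing dictionary sizes (hypothetical choice)
    """
    recon = [0.0] * len(x)
    total = 0.0
    start = 0
    for size in prefix_sizes:
        # Add the contribution of the latents that are new in this prefix.
        for i in range(start, size):
            for j in range(len(x)):
                recon[j] += acts[i] * decoder[i][j]
        start = size
        # This term only sees the first `size` latents.
        total += sum((xj - rj) ** 2 for xj, rj in zip(x, recon))
    return total


# Toy input "elephant": one unit of "Starts with E-" plus one unit of "Elephant".
x = [1.0, 1.0]

# No absorption: latent 0 writes "Starts with E-", latent 1 writes "Elephant".
loss_clean = matryoshka_loss(x, [1.0, 1.0], [[1.0, 0.0], [0.0, 1.0]], [1, 2])

# Absorption: latent 1 writes both directions and latent 0 stays silent.
loss_absorbed = matryoshka_loss(x, [0.0, 1.0], [[1.0, 0.0], [1.0, 1.0]], [1, 2])

print(loss_clean, loss_absorbed)
```

In this toy setup both solutions reconstruct the input perfectly with the full dictionary, so an ordinary reconstruction loss cannot tell them apart and a sparsity term would favor the absorbed one. The prefix term is what makes the absorbed solution more expensive, because the early latent alone reconstructs nothing.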
Although the code has the option to add an L1 penalty, in practice I set l1_coeff to 0 in all my experiments (see main.py for all hyperparameters).
I haven't actually tried this, but I recently heard about focusbuddy.ai, which might be a useful AI assistant in this space.
Great work! I have been working on something very similar and will publish my results here sometime next week, but can already give a sneak peek:
The SAEs here were only trained for 100M tokens (a third of the TinyStories dataset). The language model was trained for 3 epochs on the 300M token TinyStories dataset. It would be good to validate these results with more 'real' language models and train SAEs with much more data.
I can confirm that on Gemma-2-2B Matryoshka SAEs dramatically improve the absorption score on the first-letter task from Chanin et al. as implemented in SAEBench!
Is there a nice way to extend the Matryoshka method to top-k SAEs?
Yes! My experiments with Matryoshka SAEs are using BatchTopK.
Are you planning to continue this line of research? If so, I would be interested to collaborate (or otherwise at least coordinate on not doing duplicate work).
Sing along! https://suno.com/song/35d62e76-eac7-4733-864d-d62104f4bfd0
You might enjoy this classic: https://www.lesswrong.com/posts/9HSwh2mE3tX6xvZ2W/the-pyramid-and-the-garden
Rather than doubling down on a single single-layered decomposition for all activations, why not go with a multi-layered decomposition (i.e., some combination of SAE and metaSAE, preferably as unsupervised as possible)? Or alternatively, maybe the decomposition that is most useful changes from case to case, and what we really need is lots of different (somewhat) interpretable decompositions and an ability to quickly work out which is useful in context.
There definitely seem to be multiple ways to interpret this work, as also described in "SAE feature geometry is outside the superposition hypothesis". Either we need to find other methods and theory that somehow find more atomic features, or we need to get a more complete picture of what the SAEs are learning at different levels of abstraction and composition.
Both seem important and interesting lines of work to me!
When working with SAE features, I've usually relied on a linear intuition: a feature firing with twice the strength has about twice the "impact" on the model. But while playing with an SAE trained on the final layer, I was reminded that the actual direct impact on the relative token probabilities grows exponentially with activation strength. While a feature's additive contribution to the logits is indeed linear in its activation strength, the ratio of the probabilities of two competing tokens is the exponential of their logit difference: P(A)/P(B) = exp(logit(A) − logit(B)).
If we have a feature that boosts logit(A) and not logit(B) and we multiply its activation strength by a factor of 5.0, this doesn't 5x its effect on P(A)/P(B), but rather raises its effect to the 5th power. If this feature caused token A to be three times as likely as token B before, it now makes this token 3^5 = 243 times as likely! This might partly explain why the lower activations for a feature are often less interpretable than the top activations. Their direct impact on the relative token probabilities is exponentially smaller.
Note that this only holds for the direct 'logit lens'-like effect of a feature. This makes this intuition mostly applicable to features in the final layers of a model, as the impact of earlier features is probably mostly modulated by their effect on later layers.
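The 3x to 243x arithmetic above can be checked with a small softmax calculation. This is a minimal sketch under made-up assumptions: a three-token vocabulary, zero baseline logits, and a hypothetical feature whose direct effect lands only on the logit of token A. The names `softmax` and `odds_ratio` are illustrative and not from any library.

```python
import math


def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    return [e / z for e in exps]


def odds_ratio(activation, effect_on_a, base_logits=(0.0, 0.0, 0.0)):
    """Return P(A)/P(B) when the feature adds activation * effect_on_a to logit(A)."""
    logits = list(base_logits)
    logits[0] += activation * effect_on_a  # linear contribution to the logit
    probs = softmax(logits)
    return probs[0] / probs[1]


# Choose the per-unit effect so that activation 1.0 makes A three times as likely as B.
effect = math.log(3.0)

ratio_1x = odds_ratio(1.0, effect)
ratio_5x = odds_ratio(5.0, effect)

print(ratio_1x, ratio_5x)
```

The logit contribution scales by 5, while the probability ratio goes from roughly 3 to roughly 3^5 = 243, up to floating-point error.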