Stefan Heimersheim. Research Scientist at Apollo Research, Mechanistic Interpretability. The opinions expressed here are my own and do not necessarily reflect the views of my employer.
If we imagine that the meaning is given not by the dimensions of the space but rather by regions/points/volumes of the space
I think this is what I care about finding out. If you're right, this is indeed neither surprising nor an issue, but you being right would be a major departure from the current mainstream interpretability paradigm(?).
The question of regions vs compositionality is what I've been investigating with my mentees recently, and am pretty keen on. I want to write up my current thoughts on this topic sometime soon.
What do you mean you’re encoding/decoding like normal but using the k means vectors?
So I do something like
latents_tmp = torch.einsum("bd,nd->bn", data, centroids)  # dot product with every centroid (the "encoder" step)
max_latent = latents_tmp.argmax(dim=-1)  # shape: [batch]
latents = torch.nn.functional.one_hot(max_latent, num_classes=centroids.shape[0]).to(data.dtype)
where the first line is essentially an SAE encoder (with the centroids as the features), and the second/third lines are the top-k step (with k=1). And for the reconstruction I do something like
recon = latents @ centroids  # [batch, n_latents] @ [n_latents, d_model] -> [batch, d_model]
which should also be equivalent to an SAE decoder, with the centroids as the decoder directions.
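For completeness, a self-contained version one could run to sanity-check the shapes (the sizes and random tensors here are just stand-ins for the real activations and centroids):

import torch
import torch.nn.functional as F

batch, d_model, n_clusters = 8, 512, 16384    # made-up sizes
data = torch.randn(batch, d_model)            # stand-in for activations
centroids = torch.randn(n_clusters, d_model)  # stand-in for the K-means centroids

latents_tmp = torch.einsum("bd,nd->bn", data, centroids)
max_latent = latents_tmp.argmax(dim=-1)                     # [batch]
latents = F.one_hot(max_latent, n_clusters).to(data.dtype)  # [batch, n_clusters]
recon = latents @ centroids                                 # [batch, d_model]

# The one-hot matmul is just a lookup of the selected centroid:
assert torch.allclose(recon, centroids[max_latent])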
Shouldn’t the SAE training process for a top k SAE with k = 1 find these vectors then?
Yes, I would expect an optimal k=1 top-k SAE to find exactly that solution. I'm confused why k=20 top-k SAEs do so badly then.
If this is a crux, then a quick way to test it would be for me to write down the encoder/decoder weights and throw them into a standard SAE codebase. I haven't done this yet.
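Concretely, the weights I'd write down look something like this (plain tensors with made-up sizes, not any particular SAE library; note that a real top-k SAE also scales each decoder row by the latent activation, so reproducing the clustering reconstruction exactly takes a bit more care with norms and biases than shown here):

import torch

d_model, n_latents = 512, 16384
centroids = torch.randn(n_latents, d_model)  # stand-in for the K-means centroids

W_enc = centroids.clone()   # encoder rows: dot product with each centroid
W_dec = centroids.clone()   # decoder rows: the centroid directions themselves
b_enc = torch.zeros(n_latents)
b_dec = torch.zeros(d_model)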
I'm not sure what you mean by "K-means clustering baseline (with K=1)". I would think the K in K-means stands for the number of means you use, so with K=1, you're just taking the mean direction of the weights. I would expect this to explain maybe 50% of the variance (or less), not 90% of the variance.
Thanks for pointing this out! I confused nomenclature, will fix!
Edit: Fixed now. I confused the K in K-means (the number of centroids) with the k in top-k SAEs (the number of active latents); I meant the latter, i.e. each activation is assigned to exactly one of the many centroids.
this seems concerning.
I feel like my post comes across as more dramatic than intended; I'm not very surprised and don't consider this the strongest evidence against SAEs. It's an experiment I ran a while ago, and it hasn't changed my (somewhat SAE-sceptic) stance much.
But that's partly because I've already seen a bunch of other weird SAE behaviours (pre-activation distributions don't look the way you'd expect under the superposition hypothesis (h/t @jake_mendel), the SAE goes nuts if you feed SAE-reconstructed activations back into the encoder, stuff mentioned in recent Apollo papers, ...).
Reasons this could be less concerning than it looks
TL;DR: K-means explains about as much (or more) variance in the activations as SAEs do.
Edit: Epistemic status: This is a weekend experiment I ran a while ago, and I figured I should write it up to share. I have taken decent care to check my code for silly mistakes and "shooting myself in the foot", but these results are not vetted to the standard of a top-level post / paper.
SAEs explain most of the variance in activations. Is this alone a sign that activations are structured in an SAE-friendly way, i.e. that activations are indeed a composition of sparse features like the superposition hypothesis suggests?
I'm asking myself this question since I initially considered it pretty solid evidence: SAEs do a pretty impressive job compressing 512 dimensions into ~100 latents, this ought to mean something, right?
But maybe all SAEs are doing is "dataset clustering" (the data is cluster-y and SAEs exploit this). In that case a different sensible clustering method should be able to perform similarly well!
I took this SAE graph from Neuronpedia, and added a K-means clustering baseline. Think of this as pretty equivalent to a top-k SAE (with k=1). In fact I use the K-means algorithm to find "feature vectors" and then encode / decode the activations just like I would in an SAE (I'm not using the (non-linear) "predict" method of K-means).
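Roughly, the baseline looks like this (a simplified sketch rather than the exact code I used, which is shared below; `acts` is assumed to be a [n_tokens, d_model] tensor of activations, and sklearn's MiniBatchKMeans stands in for whatever K-means implementation one prefers):

import torch
from sklearn.cluster import MiniBatchKMeans

n_clusters = 16384  # matched to the SAE's number of latents
kmeans = MiniBatchKMeans(n_clusters=n_clusters, batch_size=4096).fit(acts.numpy())
centroids = torch.tensor(kmeans.cluster_centers_, dtype=acts.dtype)  # [n_clusters, d_model]

# Encode / decode as described above: pick the centroid with the largest dot product (L_0 = 1)
idx = torch.einsum("bd,nd->bn", acts, centroids).argmax(dim=-1)
recon = centroids[idx]

# Variance explained = 1 - FVU
fvu = (acts - recon).pow(2).sum() / (acts - acts.mean(dim=0)).pow(2).sum()
print("variance explained:", 1 - fvu.item())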
It turns out that even clustering (essentially L_0=1) explains up to 90% of the variance in activations, matched only by SAEs with L_0>100. This isn't an entirely fair comparison, since SAEs are optimised for the large-L_0 regime, while I haven't found an L_0>1 operationalisation of clustering that meaningfully improves over L_0=1. To have some comparison I'm adding a PCA + clustering baseline where I apply a PCA before doing the clustering. It does roughly as well as expected, exceeding the SAE reconstruction for most L_0 values. The upcoming SAEBench paper also does a PCA baseline, so I won't discuss PCA in detail here.
Here's the result for layer 4, 16k latents. See below for plots with layer 3, and/or with 4k latents. (These were the 4 SAEBench suites available on Neuronpedia.)
What about interpretability? Clusters seem "monosemantic" on a skim. In an informal investigation I looked at max-activating dataset examples, and they seem to correspond to related contexts / words like monosemantic SAE features tend to do. I haven't spent much time looking into this though.
Here's the code used to get the clustering & PCA below; the SAE numbers are taken straight from Neuronpedia. Both my code and SAEBench/Neuronpedia use OpenWebText with a context length of 128 tokens, so I hope the numbers are comparable, but there's a risk I missed something and we're comparing apples to oranges.
A final caveat I want to mention is that the SAEs I'm comparing against here (the SAEBench suite for Pythia-70M) are maybe on the weak side. They use only 4k and 16k latents for 512 embedding dimensions, i.e. expansion ratios of 8 and 32, respectively (the best SAEs I could find for a ~100M model). But I also limit the number of clusters to the same numbers, so I don't necessarily expect the balance to change qualitatively at higher expansion ratios.
I want to thank @Adam Karvonen, @Lucius Bushnaq, @jake_mendel, and @Patrick Leask for feedback on early results, and @Johnny Lin for implementing an export feature on Neuronpedia for me! I also learned that @scasper proposed something similar here (though I didn't know about it). I'm excited for follow-ups implementing some of Stephen's more advanced ideas (HAC, a probabilistic algorithm, ...).
I’ve just read the article and indeed found it very thought-provoking; I will be thinking more about it in the days to come.
One thing I kept thinking though: Why doesn’t the article mention AI Safety research much?
In the passage
The only policy that AI Doomers mostly agree on is that AI development should be slowed down somehow, in order to “buy time.”
I was thinking: surely most people would agree on policies like “Do more research into AI alignment” / “Spend more money on AI Notkilleveryoneism research”?
In general, the article frames the policy to “buy time” as waiting for more competent governments or humans, while I find it plausible that progress in AI alignment research could outweigh that effect.
—
I suppose the article is primarily concerned with AGI and ASI, and there I see much less research progress than in more prosaic areas.
That being said, I believe that research into questions like “When do chatbots scheme?”, “Do models have internal goals?”, and “How can we understand the computation inside a neural network?” will make us less likely to die in the coming decades.
Then, current rationalist / EA policy goals (including but not limited to pauses and slowdowns of capabilities research) could have a positive impact via the “do more (selective) research” path as well.
Thanks for writing these up! I liked that you showed equivalent examples in different libraries, and included the plain “from scratch” version.
Hmm, I think I don't fully understand your post. Let me summarize what I get, and what is confusing me:
I'm confused whether your post tries to tell us (how to determine) what loss our interpretation should recover, or whether you're describing how to measure whether our interpretation recovers that loss (via constructing the M_c models).
You first introduce the SLT argument that tells us which loss scale to choose (the "Watanabe scale", derived from the Watanabe critical temperature).
And then there's a second (?) scale, the "natural" scale. That loss scale is the difference between the given model (Claude 2) and a hypothetical near-perfect model (Claude 5).
Then there's the second part, where you discuss how to obtain a model M_c* corresponding to a desired loss L_c*. There are many ways to do this (trivially: just walk a straight line in parameter space until the loss reaches the desired level), but you suggest a specific one (Langevin SGD). You suggest that one because it produces a model implementing a "maximally general algorithm" [1] (with the desired loss, and in the same basin). This would make sense if I were trying to interpret / reverse engineer / decompose M_c*, but I'm running my interpretability technique on M_c, right? I believe I have missed why we bother with creating the intermediate M_c* model. (I assume it's not merely to find the equivalent parameter count / Claude generation.)
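(For concreteness, the kind of update I picture when you say Langevin SGD is plain SGLD, something like the sketch below; the noise scale / temperature and the stopping rule are my guesses, not taken from your post.)

import torch

def sgld_step(params, loss_fn, lr=1e-4, beta=1e4):
    # One Langevin / SGLD update: a gradient step plus Gaussian noise with
    # variance 2 * lr / beta; repeat until the loss has risen to the target L_c*.
    loss = loss_fn(params)
    (grad,) = torch.autograd.grad(loss, params)
    noise = torch.randn_like(params) * (2 * lr / beta) ** 0.5
    return (params - lr * grad + noise).detach().requires_grad_()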
[1] Regarding the "maximally general" claim: You have made a good argument that there is a spectrum from generalization to memorization (e.g. knowing which city is where on the globe, or memorizing grammar rules, both seem kinda ambiguous). So "maximally general" seems not uniquely defined (e.g. a model that has some really general and some really memorized circuits, vs a model that has lots of middle-of-the-spectrum circuits).
Great read! I think you did a good job explaining the intuition for why logits / logprobs are so natural (I haven't managed to do this well in a past attempt). I like the suggestion that (a) NNs consist of parallel mechanisms to get the answer, and (b) the best way to combine multiple predictions is by adding logprobs.
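(A toy illustration of (b) as I understand it, not code from your post: if two mechanisms each produce a prediction, adding their logprobs and renormalizing corresponds to multiplying the probabilities, i.e. a product of experts.)

import torch

logits_a = torch.tensor([2.0, 0.0, -1.0])  # mechanism A's prediction over 3 tokens
logits_b = torch.tensor([0.5, 1.5, -2.0])  # mechanism B's prediction

logp_a = torch.log_softmax(logits_a, dim=-1)
logp_b = torch.log_softmax(logits_b, dim=-1)

combined = torch.softmax(logp_a + logp_b, dim=-1)  # product of experts
print(combined)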
I haven't grokked your loss scales explanation (the "interpretability insights" section) without reading your other post though.
Tentatively, I get similar results (70-85% variance explained) for random data -- I haven't checked that code at all though, so don't trust this; I will double-check it tomorrow. (In that case the SAEs' performance would also be unsurprising, I suppose.)