Crossposted from the AI Alignment Forum. May contain more technical jargon than usual.

As more people begin work on interpretability projects which incorporate dictionary learning, it will be valuable to have high-quality dictionaries publicly available.[1] To get the ball rolling on this, my collaborator (Aaron Mueller) and I are:

  • open-sourcing a number of sparse autoencoder dictionaries trained on Pythia-70m MLPs
  • releasing our repository for training these dictionaries[2].

Let's discuss the dictionaries first, and then the repo.

The dictionaries

[EDIT 02/07/2024: Better dictionaries are now available at the repo. Also, the originally reported MSE loss numbers were wrong and have been updated in the tables below. (The correct numbers were much lower, i.e. better.)]

The dictionaries can be downloaded from here. See the sections "Downloading our open-source dictionaries" and "Using trained dictionaries" here for information about how to download and use them. If you use these dictionaries in a published paper, we ask that you mention us in the acknowledgements.

We're releasing two sets of dictionaries for EleutherAI's 6-layer pythia-70m-deduped model. The dictionaries in both sets were trained on 512-dimensional MLP output activations (not the MLP hidden layer, which is what Anthropic used), using ~800M tokens from The Pile.

  • The first set, called 0_8192, consists of dictionaries of size 8192 = 16 × 512. These were trained with an L1 penalty of 1e-3.
  • The second set, called 1_32768, consists of dictionaries of size 32768 = 64 × 512. These were trained with an L1 penalty of 3e-3.
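For readers new to the setup, here is a minimal pure-Python sketch of the kind of sparse autoencoder being described: untied encoder/decoder weights, a ReLU hidden layer, unit-norm decoder vectors, and a loss of reconstruction MSE plus an L1 penalty scaled by a coefficient such as the 1e-3 or 3e-3 above. The class and function names, the subtraction of the decoder bias before encoding, and the per-dimension MSE normalization are assumptions for illustration (loosely following the setup in Anthropic's paper), not the repo's actual code.

```python
import math
import random


def matvec(W, v):
    """Multiply a matrix stored as a list of rows by a vector."""
    return [sum(w * a for w, a in zip(row, v)) for row in W]


class SparseAutoencoder:
    """Toy SAE with untied weights (assumed architecture):
    f = ReLU(W_enc (x - b_dec) + b_enc),  x_hat = sum_i f_i * d_i + b_dec,
    where each decoder vector d_i is kept at unit L2 norm."""

    def __init__(self, d_act, d_dict, seed=0):
        rng = random.Random(seed)
        self.W_enc = [[rng.gauss(0.0, 1.0) / math.sqrt(d_act) for _ in range(d_act)]
                      for _ in range(d_dict)]
        self.b_enc = [0.0] * d_dict
        self.dec = [[rng.gauss(0.0, 1.0) for _ in range(d_act)] for _ in range(d_dict)]
        self.b_dec = [0.0] * d_act
        self.normalize_decoder()

    def normalize_decoder(self):
        # Rescale each decoder vector to unit norm; in training this would be
        # re-applied after each optimizer step.
        for i, vec in enumerate(self.dec):
            norm = math.sqrt(sum(a * a for a in vec)) or 1.0
            self.dec[i] = [a / norm for a in vec]

    def encode(self, x):
        centered = [a - b for a, b in zip(x, self.b_dec)]
        pre = [p + b for p, b in zip(matvec(self.W_enc, centered), self.b_enc)]
        return [max(0.0, p) for p in pre]  # ReLU keeps feature activations non-negative

    def decode(self, f):
        x_hat = list(self.b_dec)
        for coeff, vec in zip(f, self.dec):
            if coeff:  # skip inactive features
                for j, a in enumerate(vec):
                    x_hat[j] += coeff * a
        return x_hat


def sae_loss(sae, x, l1_coeff):
    """Return (total loss, reconstruction MSE, L1 norm of the features)."""
    f = sae.encode(x)
    x_hat = sae.decode(f)
    mse = sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)
    l1 = sum(abs(a) for a in f)
    return mse + l1_coeff * l1, mse, l1
```

Constructing SparseAutoencoder(512, 8192) and calling sae_loss with l1_coeff=1e-3 would mirror the shapes and penalty of the 0_8192 set, though pure Python is far too slow for real training; a real implementation would use a tensor library and batch over tokens.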

Here are some statistics. (See our repo's readme for more info on what these statistics mean.)

For dictionaries in the 0_8192 set:

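| Layer | MSE Loss | L1 Loss | L0      | % Alive | % Loss Recovered |
|-------|----------|---------|---------|---------|------------------|
| 0     | 0.003    | 6.132   | 9.951   | 0.998   | 0.984            |
| 1     | 0.008    | 6.677   | 44.739  | 0.887   | 0.924            |
| 2     | 0.011    | 11.44   | 62.156  | 0.587   | 0.867            |
| 3     | 0.018    | 23.773  | 175.303 | 0.588   | 0.902            |
| 4     | 0.022    | 27.084  | 174.07  | 0.806   | 0.927            |
| 5     | 0.032    | 47.126  | 235.05  | 0.672   | 0.972            |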

For dictionaries in the 1_32768 set:

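| Layer | MSE Loss | L1 Loss | L0     | % Alive | % Loss Recovered |
|-------|----------|---------|--------|---------|------------------|
| 0     | 0.0018   | 4.32    | 2.873  | 0.174   | 0.946            |
| 1     | 0.017    | 2.798   | 11.256 | 0.159   | 0.768            |
| 2     | 0.023    | 6.151   | 16.381 | 0.118   | 0.724            |
| 3     | 0.044    | 11.571  | 39.863 | 0.226   | 0.765            |
| 4     | 0.048    | 13.665  | 29.235 | 0.19    | 0.816            |
| 5     | 0.069    | 26.4    | 43.846 | 0.13    | 0.931            |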
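The statistics in these tables can be computed roughly as follows. This sketch assumes per-token feature activations and reconstructions have already been collected on some evaluation set, and counts a feature as alive if it fires at least once on that set. The function names are hypothetical, and the exact definitions (in particular, that % Loss Recovered is measured against a baseline where the MLP output is zero-ablated) are assumptions here; the repo's readme is the authoritative description.

```python
def dictionary_stats(features, xs, x_hats):
    """Summary statistics over an evaluation set of tokens.

    features: per-token lists of feature activations
    xs, x_hats: per-token activation vectors and their reconstructions
    """
    n_tokens = len(features)
    d_dict = len(features[0])
    mse = sum(
        sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x) for x, x_hat in zip(xs, x_hats)
    ) / n_tokens
    l1 = sum(sum(abs(a) for a in f) for f in features) / n_tokens       # mean L1 norm per token
    l0 = sum(sum(1 for a in f if a != 0) for f in features) / n_tokens  # mean active features per token
    n_alive = sum(1 for i in range(d_dict) if any(f[i] != 0 for f in features))
    return {"mse": mse, "l1": l1, "l0": l0, "frac_alive": n_alive / d_dict}


def frac_loss_recovered(ce_clean, ce_with_dict, ce_zero_ablated):
    """1.0: splicing in the reconstruction costs nothing; 0.0: as bad as zero-ablating the MLP output."""
    return (ce_zero_ablated - ce_with_dict) / (ce_zero_ablated - ce_clean)
```

For example, if the clean model has cross-entropy 3.0, zero-ablating the MLP output gives 5.0, and splicing in the dictionary's reconstruction gives 3.5, then frac_loss_recovered(3.0, 3.5, 5.0) is 0.75.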

And here are some histograms of feature frequencies.
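Feature-frequency histograms like these are often plotted on a log scale. Here is a small sketch of how such a histogram could be tabulated: a feature's frequency is the fraction of tokens on which it is nonzero, dead features (frequency 0) are counted separately since their log is undefined, and the remaining features are binned by log10 frequency. The bin width and function names are assumptions, not necessarily what was used for the histograms referred to here.

```python
import math
from collections import Counter


def feature_frequencies(features):
    """Fraction of tokens on which each feature is nonzero."""
    n_tokens = len(features)
    return [sum(1 for f in features if f[i] != 0) / n_tokens for i in range(len(features[0]))]


def log10_histogram(freqs, bin_width=0.5):
    """Return (number of dead features, {left edge of log10 bin: count}) for the alive features."""
    n_dead = sum(1 for p in freqs if p == 0)
    bins = Counter(math.floor(math.log10(p) / bin_width) * bin_width for p in freqs if p > 0)
    return n_dead, dict(sorted(bins.items()))
```

A cluster of features in the rightmost bins (log10 frequency close to 0) is what a spike of high-frequency features would look like in this representation.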

Overall, I'd guess that these dictionaries are decent, but not amazing.

We trained these dictionaries because we wanted to work on a downstream application of dictionary learning but lacked the dictionaries. These dictionaries are more than good enough to get us off the ground on our mainline project, but I expect that before long we'll come back to train some better dictionaries (which we'll also open source). I think the same is true for other folks: these dictionaries should be sufficient to get started on projects that require dictionaries, and when better dictionaries are available later, you can swap them in for optimal results.

Some miscellaneous notes about these dictionaries (you can find more in the repo).

  • The later layer dictionaries in 0_8192 have too-high L0s. However, looking at the feature frequency histograms, it looks like this might be because of a spike in high-frequency features. Without this spike, the L0s would be much more reasonable, and features outside of this spike look pretty decent (see here for more).
    • We speculate with very low confidence that these spikes might be an artifact of our timing for resampling dead neurons. We resample every 30000 steps, including at step 90000 out of 100000 total steps. The resampled features tend to be very high-frequency, and it might take more than 10000 steps for the peak to move to the left.
  • The L1 penalty for 1_32768 seems to have been too large; only 10-20% of the neurons are alive, and the loss recovered is much worse. That said, we'll remark that after examining features from both sets of dictionaries, the dictionaries from the 1_32768 set seem to have more interpretable features than those from the 0_8192 set (though it's hard to tell).
    • In particular, we suspect that for 0_8192, the many high-frequency features in the later layers are uninterpretable but help significantly with reconstructing activations, resulting in deceptively good-looking statistics.
  • As we progress through the layers, the dictionaries tend to get worse along most metrics (except for % loss recovered). This may have to do with the growing scale of the activations themselves as one moves through the layers of pythia models (h/t to Arthur Conmy for raising this hypothesis).
  • We note that our dictionary features are significantly higher-frequency overall than the features in Anthropic's and Neel Nanda's dictionaries. We don't know whether this difference is because we are working with a multi-layer model or because of a difference in hyperparameters. We generally suspect it would be better if we were learning lower-frequency features.
    • We'll note, however, that after layer 0, it doesn't seem like many of our features are of the form "always fire on a particular token," whereas many of Anthropic's features were. So it's possible that more interesting features also tend to be higher-frequency. See here for some flavor.
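The resampling-timing speculation above comes down to simple arithmetic. This sketch assumes resampling happens at every multiple of the interval strictly before the final step, and uses a simplified dead-feature criterion (no firing since the last check); the actual scheme from Anthropic's paper also reinitializes the resampled weights in a specific way, which is omitted here. All names are hypothetical.

```python
def resample_steps(total_steps, interval):
    """Steps at which dead features are resampled: every `interval` steps, excluding the final step."""
    return list(range(interval, total_steps, interval))


def steps_after_last_resample(total_steps, interval):
    """How many training steps remain for freshly resampled features to settle."""
    steps = resample_steps(total_steps, interval)
    return total_steps - steps[-1] if steps else total_steps


def dead_features(fire_counts):
    """Indices of features that never fired since the last check (simplified criterion)."""
    return [i for i, count in enumerate(fire_counts) if count == 0]
```

With resampling every 30000 steps over 100000 total steps, resample_steps returns [30000, 60000, 90000], so features resampled at the last step get only 10000 steps to move out of the high-frequency spike.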

The dictionary learning repository

Again, this can be found here. We followed the approach detailed in Anthropic's paper (including using untied encoder/decoder weights, constraining the decoder vectors to have unit norm, and resampling dead neurons according to their wacky scheme), except for the following:

  • We didn't have the space to store activations for our entire dataset, so – following Neel Nanda's replication – we maintain a buffer of tokens from a few thousand contexts and randomly sample from this buffer until it's half-empty (at which point we refresh it with tokens from new contexts).
  • We used a brief linear learning rate warm-up to fix a problem where Adam would kill too many of our neurons in the first few training steps, before its running moment estimates had a chance to calibrate.
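The buffering strategy in the first bullet can be sketched as follows: fill the buffer with activations from a batch of contexts, serve random batches without replacement, and refill from new contexts once it drops to half capacity. The class name, the get_activations callback (standing in for running the model and reading off MLP outputs), and measuring capacity in activations rather than contexts are hypothetical choices for illustration.

```python
import random


class ActivationBuffer:
    """Serve random batches of activations, refilling from new contexts once the buffer is half-empty."""

    def __init__(self, contexts, get_activations, capacity, seed=0):
        self.contexts = iter(contexts)          # e.g. token sequences from the training corpus
        self.get_activations = get_activations  # hypothetical: context -> list of per-token activations
        self.capacity = capacity                # target number of stored activations
        self.rng = random.Random(seed)
        self.acts = []
        self.refresh()

    def refresh(self):
        # Top the buffer up with activations from fresh contexts, then shuffle so that
        # taking items from the end amounts to random sampling without replacement.
        while len(self.acts) < self.capacity:
            ctx = next(self.contexts, None)
            if ctx is None:
                break  # data exhausted; serve whatever is left
            self.acts.extend(self.get_activations(ctx))
        self.rng.shuffle(self.acts)

    def next_batch(self, batch_size):
        if len(self.acts) <= self.capacity // 2:
            self.refresh()
        if not self.acts:
            raise RuntimeError("no activations left")
        k = min(batch_size, len(self.acts))
        cut = len(self.acts) - k
        batch = self.acts[cut:]
        del self.acts[cut:]
        return batch
```

Refilling at half capacity rather than when empty means each batch mixes activations from many contexts, so consecutive batches are not dominated by tokens from the same few documents.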

(A brief plug: this repository is built using nnsight, a new interpretability tooling library (like transformer_lens and baukit) being developed by Jaden Fiotto-Kaufman and others in the Bau lab. nnsight is still under development, so I only recommend trying to dive into it now if you're okay with occasional bugs, memory leaks, etc. (which you can report in the feedback channel of this Discord server). But I'm overall very excited about the project – aside from providing a very clean user experience, one major design goal is that nnsight code is highly portable: you should ideally be able to prototype an experiment with Pythia-70m, switch seamlessly to running it on LLaMA-2-70B split across multiple GPUs, and then ship your code to Anthropic to be run on Claude.)

In addition to the mainline functionality, our repo also supports some experimental features, which we briefly investigated as alternative approaches to training dictionaries:

  • MLP stretchers. Based on the perspective that one may be able to identify features with "neurons in a sufficiently large model," we experimented with training "autoencoders" that take an MLP input activation $x$ and output $\mathrm{MLP}(x)$ (the MLP output). For instance, given an MLP which maps a 512-dimensional input $x$ to a 1024-dimensional hidden state $h$ and then a 512-dimensional output $y$, we train a dictionary $A$ with hidden dimension $d_{\text{dict}}$ so that $A(x)$ is close to $y$ (and, as usual, so that the hidden state of the dictionary is sparse).
    • The resulting dictionaries seemed decent, but we decided not to pursue the idea further.
    • (h/t to Max Li for this suggestion.)
  • Replacing L1 loss with entropy. Based on the ideas in this post, we experimented with using entropy instead of L1 loss to regularize a dictionary's hidden state. This seemed to make features either dead (never firing) or very high-frequency (firing on nearly every input), which was not the desired behavior. But plausibly there is a way to make this work better.
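The two experimental variants differ from the mainline objective only in the reconstruction target and the sparsity penalty. The sketch below is a hypothetical illustration: for a stretcher the target is the MLP output y rather than the input x, and for the entropy variant the penalty is taken to be the Shannon entropy of the normalized absolute feature activations. That particular entropy formulation is an assumption, since the exact regularizer is not spelled out here; all names are hypothetical.

```python
import math


def reconstruction_mse(target, pred):
    return sum((a - b) ** 2 for a, b in zip(target, pred)) / len(target)


def l1_penalty(f):
    return sum(abs(a) for a in f)


def entropy_penalty(f):
    """Shannon entropy (nats) of |f| normalized to sum to 1 (assumed formulation)."""
    total = sum(abs(a) for a in f)
    if total == 0:
        return 0.0
    probs = [abs(a) / total for a in f]
    return -sum(p * math.log(p) for p in probs if p > 0)


def dictionary_loss(x, y, f, pred, coeff, sparsity="l1", stretcher=False):
    """Mainline autoencoder: pred should match x. MLP stretcher: pred should match y = MLP(x)."""
    target = y if stretcher else x
    penalty = l1_penalty(f) if sparsity == "l1" else entropy_penalty(f)
    return reconstruction_mse(target, pred) + coeff * penalty
```

Note that this entropy penalty depends only on how activation mass is distributed across features, not on its overall magnitude, unlike the L1 penalty.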

If you want to pursue one of the ideas in the above bullet points, I ask that you get in touch with me (Sam) once you have preliminary results – I may be interested in discussing results or collaborating.

[1] This is both for the sake of reproducibility, and because each dictionary takes some effort to train.

[2] Of course, the repository from the Cunningham et al. paper is also available here.

7 comments

I've noticed that L0s above 100 (for the Pythia-70M model) are too high, resulting in mostly polysemantic features (though some single-token features were still monosemantic).

Agreed w/ Arthur on the norms of features being the cause of the higher MSE. Here are the L2 norms I got. Input is for residual stream, output is for MLP_out.

Sam Marks

I agree that the L0s for 0_8192 are too high in later layers, though I'll note that I think this is mainly due to the cluster of high-frequency features (see the spike in the histogram). Features outside of this spike look pretty decent, and without the spike our L0s would be much more reasonable.

Here are four random features from layer 3, at a range of frequencies outside of the spike.

Layer 3, 0_8192, feature 138 (frequency = 0.003) activates on the newline at the end of the "field of the invention" section in patent applications. I think it's very likely predicting that the next few tokens will be "2. Description of the Related Art" (which always comes next in patents).

Layer 3, 0_8192, feature 27 (frequency = 0.009) seems to activate on the "is" in the phrase "this is."

Layer 3, 0_8192, feature 4 (frequency = 0.026) looks messy at first, but on closer inspection seems to activate on the final token of multi-token words in informative file/variable names.

Layer 3, 0_8192, feature 56 (frequency = 0.035) looks very polysemantic: it's activating on certain terms in LaTeX expressions, words in between periods in URLs and code, and some other random-looking stuff.

If you removed the high-frequency features to achieve some L0 norm X, how much would loss recovered change?

If you instead increased the L1 penalty to achieve L0 norm X, how would loss recovered change?

Ideally, we can interpret the parts of the model that are doing things, which I'm grounding out as loss recovered in this case.

Sam Marks

Here's an experiment I'm about to do:

  • Remove high-frequency features from 0_8192 layer 3 until it has L0 < 40 (the same L0 as the 1_32768 layer 3 dictionary)
  • Recompute statistics for this modified dictionary.

I predict the resulting dictionary will be "like 1_32768 but a bit worse." Concretely, I'm guessing that means % loss recovered around 72%. 

 

Results:

I killed all features of frequency larger than 0.038. This was 2041 features, and resulted in an L0 just below 40. The stats:

MSE Loss: 0.27 (worse than 1_32768)

Percent loss recovered: 77.9% (a little bit better than 1_32768)

I was a bit surprised by this: it suggests the high-frequency features are disproportionately likely to be useful for reconstructing activations in ways that don't actually matter to the model's computation. (Though then again, maybe this is what we expect for uninterpretable features.)

It also suggests that we might be better off training dictionaries with a too-low L1 penalty and then just pruning away high-frequency features (sort of the dual operation of "train with a high L1 penalty and resample low-frequency features"). I'd be interested in someone exploring whether there's a version of this that helps.
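A sketch of the pruning step in the experiment above: compute each feature's firing frequency, kill (zero out) every feature whose frequency exceeds a threshold such as 0.038, and recompute L0 on the pruned activations. The function name is hypothetical, and this only covers the L0 part; in the actual experiment the pruned activations would then be decoded and spliced back into the model to recompute MSE and % loss recovered.

```python
def prune_high_frequency(features, max_freq):
    """Zero out every feature firing on more than max_freq of tokens.

    Returns (pruned per-token activations, sorted killed feature indices, new mean L0).
    """
    n_tokens = len(features)
    d_dict = len(features[0])
    freqs = [sum(1 for f in features if f[i] != 0) / n_tokens for i in range(d_dict)]
    killed = {i for i, p in enumerate(freqs) if p > max_freq}
    pruned = [[0.0 if i in killed else a for i, a in enumerate(f)] for f in features]
    new_l0 = sum(sum(1 for a in f if a != 0) for f in pruned) / n_tokens
    return pruned, sorted(killed), new_l0
```

The threshold would be chosen by sweeping max_freq until the new L0 reaches the target (here, just below 40).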

Do you apply LR warmup immediately after doing resampling (i.e. immediately reducing the LR, and then slowly increasing it back to the normal value)? In my GELU-1L blog post, I found this pretty helpful (in addition to doing LR warmup at the start of training).

At the time that I made this post, no, but this has been implemented in dictionary_learning since I saw your suggestion to do so in your linked post.
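A minimal sketch of the schedule discussed in this exchange: linear warmup from near zero at the start of training, restarted after each resampling step. The function name, warmup length, and base learning rate are hypothetical, and this is not necessarily how dictionary_learning implements it.

```python
def learning_rate(step, base_lr, warmup_steps, resample_steps=()):
    """Linear warmup at the start of training, restarted after every resampling step."""
    # Most recent restart point: step 0, or the latest resample at or before `step`.
    last_restart = max([0] + [s for s in resample_steps if s <= step])
    steps_since = step - last_restart
    return base_lr * min(1.0, (steps_since + 1) / warmup_steps)
```

For example, with base_lr=1e-3, warmup_steps=1000, and resampling at step 30000, the learning rate drops back to about 1e-6 at step 30000 and returns to 1e-3 by step 30999.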

Thank you!