This is a very good explanation of why SAEs incentivize feature combinatorics. Nice! I hadn't thought about the tradeoff between the MSE reduction from learning a rare feature & the L1 reduction from learning a common feature combination.
Freezing already-learned features to iteratively learn more and more features could work. Concretely, I think you would:
1. Learn an initial SAE w/ a much lower L0 (higher l1-alpha) than normally desired.
2. Learn a new SAE to predict the residual of (1), so the MSE would only be on what (1) messed up predicting. The l1 would also only be on this new SAE (since the other is frozen). You would still learn a new decoder-bias, which should just be added on to the old one.
3. Combine & repeat until the desired losses are obtained.
There are at least 3 hyperparameters to tune here: L1-alpha (and do you keep it the same, or try to have a smaller number of features per iteration?), how many tokens to train on each iteration (& I guess whether you should repeat data?), & how many new features to add each iteration.
I believe the above should avoid problems. For example, suppose your first iteration perfectly reconstructs a datapoint; then the new SAE is incentivized to keep its L1 low by not activating at all for that datapoint.
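The loss in step (2) can be sketched as follows. This is only a minimal illustration, not a tested training setup: `TinySAE` and `residual_loss` are hypothetical names, and I'm assuming the new encoder still reads the original input while only its decoder output is compared against the residual.

```python
from dataclasses import dataclass
from typing import List, Tuple

Vector = List[float]


@dataclass
class TinySAE:
    """Hypothetical minimal SAE: acts = ReLU(W_enc x + b_enc), recon = acts @ W_dec + b_dec."""
    w_enc: List[Vector]  # shape [n_features][d_model]
    b_enc: Vector        # shape [n_features]
    w_dec: List[Vector]  # shape [n_features][d_model]
    b_dec: Vector        # shape [d_model]

    def encode(self, x: Vector) -> Vector:
        return [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b)
                for row, b in zip(self.w_enc, self.b_enc)]

    def decode(self, acts: Vector) -> Vector:
        out = list(self.b_dec)
        for a, row in zip(acts, self.w_dec):
            for j, w in enumerate(row):
                out[j] += a * w
        return out


def residual_loss(frozen: TinySAE, new: TinySAE, x: Vector,
                  l1_alpha: float) -> Tuple[float, Vector]:
    """Loss for the new SAE only; the frozen SAE adds no L1 term."""
    frozen_recon = frozen.decode(frozen.encode(x))
    # MSE target: whatever the frozen SAE failed to reconstruct.
    residual = [xi - ri for xi, ri in zip(x, frozen_recon)]
    new_acts = new.encode(x)  # assumption: new encoder reads the original input
    new_recon = new.decode(new_acts)
    mse = sum((r - n) ** 2 for r, n in zip(residual, new_recon)) / len(x)
    l1 = sum(abs(a) for a in new_acts)
    # Combined output: both decoder biases are simply summed.
    combined = [f + n for f, n in zip(frozen_recon, new_recon)]
    return mse + l1_alpha * l1, combined
```

If the frozen SAE already reconstructs `x` perfectly, the residual is zero, so any activation in the new SAE only adds MSE and L1; staying silent gives zero loss.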
The SAE could learn to represent the true features, A & B, as in the left graph, so the orthogonal regularizer would help. When you say the SAE would learn inhibitory weights*, I'm imagining the graph on the right; however, these features are mostly orthogonal to each other, meaning the proposed solution won't work AFAIK.
(Also, would the regularizer be abs(cos_sim(x,x'))?)
*In this example this is because the encoder would need inhibitory weights to e.g. prevent neuron 1 from activating when both neurons 1 & 2 are present as we will discuss shortly.
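For concreteness, here is one way the abs(cos_sim(x,x')) regularizer asked about above could be written. This is a sketch under my own assumptions: I'm taking x and x' to be pairs of feature directions (e.g. decoder rows), averaging over distinct pairs, and assuming every direction has nonzero norm. The function names are made up for illustration.

```python
import math
from typing import List, Sequence


def cos_sim(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two nonzero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def abs_cos_sim_penalty(directions: List[Sequence[float]]) -> float:
    """Mean |cos_sim| over all distinct pairs of feature directions."""
    n = len(directions)
    if n < 2:
        return 0.0
    total = 0.0
    pairs = 0
    for i in range(n):
        for j in range(i + 1, n):
            total += abs(cos_sim(directions[i], directions[j]))
            pairs += 1
    return total / pairs
```

Note the penalty is already ~0 when the learned features are mostly orthogonal, which is why it wouldn't push the right-hand solution anywhere.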
One experiment here is to see if specific datapoints that have worse CE-diff correlate across layers. Last time I did a similar experiment, I saw a very long tail of datapoints that were worse off (for just one layer of gpt2-small), but the majority of datapoints had similar CE. Joseph has suggested before to UMAP these datapoints & color them by their CE-diff (or use other methods to see if you could separate out these datapoints).
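A first pass at the cross-layer check could be a plain correlation of per-datapoint CE-diffs between two layers. The sketch below uses Pearson correlation, which is my choice rather than anything established here; with a long tail, a rank-based correlation may be more appropriate. `pearson` and its inputs are hypothetical.

```python
import math
from typing import Sequence


def pearson(xs: Sequence[float], ys: Sequence[float]) -> float:
    """Pearson correlation between two equal-length sequences.

    xs[i] and ys[i] are assumed to be the CE-diff of the same datapoint i
    when splicing in the SAE at two different layers.
    """
    if len(xs) != len(ys) or len(xs) < 2:
        raise ValueError("need two equal-length sequences with at least 2 points")
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    if var_x == 0.0 or var_y == 0.0:
        raise ValueError("correlation is undefined for a constant sequence")
    return cov / math.sqrt(var_x * var_y)
```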
If someone were to run this experiment, I'd also be interested in what happens if you remove the k lowest-activating features per datapoint and check the new CE & MSE. In SAE work, the lowest-activating features usually don't make sense for the datapoint. This is to test the hypothesis:
For example, if you saw better CE-diff when removing low-activating features, up to a specific k, then SAEs are looking good!
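The ablation itself is simple; here's a minimal sketch for one datapoint's feature activations. `drop_k_lowest` is a hypothetical helper, and I'm assuming "lowest" means the k smallest among features that are actually active (> 0). You would then decode the edited activations and re-measure CE & MSE.

```python
from typing import List


def drop_k_lowest(acts: List[float], k: int) -> List[float]:
    """Zero out the k smallest strictly-positive activations for one datapoint."""
    if k < 0:
        raise ValueError("k must be non-negative")
    # Indices of active features, sorted from weakest to strongest.
    active = sorted((i for i, a in enumerate(acts) if a > 0.0),
                    key=lambda i: acts[i])
    to_drop = set(active[:k])
    return [0.0 if i in to_drop else a for i, a in enumerate(acts)]
```

Sweeping k from 0 upward and plotting CE-diff against k would show whether there is a specific k up to which removal helps.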
There are a few things to note. Later layers have:
For (3), we might expect under-normed reconstructions because there's a trade-off between L1 & MSE. After training, however, we can freeze the encoder, locking in the L0, and train on the decoder or scalar multiples of the hidden layer (h/t to Ben Wright for first figuring this out).
(4) seems like a pretty easy experiment to try: just vary the number of features to see if this explains part of the gap.
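For the under-normed reconstructions in (3), the simplest version of "scalar multiples of the hidden layer" is a single global scale, which has a closed-form least-squares solution. This is a simplification on my part (per-feature scales or retraining the decoder would need actual optimization), and `best_global_scale` is a hypothetical name. It assumes the decoder bias has already been subtracted from both the targets and the reconstructions.

```python
from typing import List, Sequence


def best_global_scale(targets: List[Sequence[float]],
                      recons: List[Sequence[float]]) -> float:
    """Scalar s minimizing sum ||x - s * r||^2 over all (x, r) pairs.

    Setting the derivative to zero gives s = sum <x, r> / sum <r, r>.
    Assumes any decoder bias was subtracted from both x and r beforehand.
    """
    num = 0.0
    den = 0.0
    for x, r in zip(targets, recons):
        num += sum(xi * ri for xi, ri in zip(x, r))
        den += sum(ri * ri for ri in r)
    if den == 0.0:
        raise ValueError("reconstructions are all zero; scale is undefined")
    return num / den
```

A scale noticeably above 1 would be consistent with the L1 penalty shrinking activations; since the encoder is frozen, rescaling leaves the L0 untouched.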
Correct. So they’re connecting a feature in F2 to a feature in F1.
If you removed the high-frequency features to achieve some L0 norm, X, how much does loss recovered change?
If you increased the l1 penalty to achieve L0 norm X, how does the loss recovered change as well?
Ideally, we can interpret the parts of the model that are doing things, which I'm grounding out as loss recovered in this case.
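The first question above (dropping high-frequency features until the mean L0 hits X) could be set up roughly like this. It's a sketch with hypothetical names, assuming you have a datapoints-by-features activation matrix and that features are removed greedily from most to least frequent. It uses the fact that the mean L0 equals the sum of per-feature firing frequencies.

```python
from typing import List


def features_to_remove(acts: List[List[float]], target_l0: float) -> List[int]:
    """Indices of the highest-frequency features to drop so mean L0 <= target_l0.

    acts[i][f] is the activation of feature f on datapoint i.
    """
    if not acts:
        return []
    n_points = len(acts)
    n_features = len(acts[0])
    # Fraction of datapoints on which each feature fires.
    freq = [sum(1 for row in acts if row[f] > 0.0) / n_points
            for f in range(n_features)]
    mean_l0 = sum(freq)  # mean L0 is the sum of firing frequencies
    removed: List[int] = []
    for f in sorted(range(n_features), key=lambda f: freq[f], reverse=True):
        if mean_l0 <= target_l0:
            break
        mean_l0 -= freq[f]
        removed.append(f)
    return removed
```

You'd then zero those features, measure loss recovered, and compare against an SAE trained with a higher l1 penalty that lands at the same L0.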
I've noticed that L0s above 100 (for the Pythia-70M model) are too high, resulting in mostly polysemantic features (though some single-token features were still monosemantic).
Agreed w/ Arthur on the norms of features being the cause of the higher MSE. Here are the L2 norms I got. Input is for residual stream, output is for MLP_out.
I really like this post, but more for:
than updating my own credences on specifics.
I actually do have some publicly hosted dictionaries, trained only on the residual stream, and some simple training code.
I'm wanting to integrate some basic visualizations (and include Anthropic's tricks) before making a public post on it, but currently:
Dict on pythia-70m-deduped
Dict on Pythia-410m-deduped
Which can be downloaded & interpreted with this notebook
With easy training code for bespoke models here.
This doesn't engage w/ (2) - doing awesome work to attract more researchers to this agenda is counterfactually more useful than directly working on lowering the compute cost now (since others, or yourself, can work on that compute bottleneck later).
Though honestly, if the results ended up being a ~2x speedup, that'd be quite useful for faster feedback loops for myself.