Awesome work! I'd be quite interested to know whether the benefits from this technique are equivalently significant with a larger SAE and also what the original perplexity was (when looking at the summary statistics table). I'll probably reimplement at some point.
Also, kudos on the visualizations. Really love the color scales!
The original perplexity of the LLM was ~38 on the OpenWebText slice I used. Thanks for the compliments!
Thanks for your amazing work! Theoretically, I think that layers with higher input norms should have lower SAE L2 ratios, as they correspond to higher feature activations that are penalized more heavily. I wonder if your data confirms this hypothesis.
Produced as part of the ML Alignment Theory Scholars Program - Winter 2023-24 Cohort as part of Lee Sharkey's stream.
TL;DR
Sparse autoencoders are a method of resolving superposition by recovering linearly encoded “features” inside activations. Unfortunately, despite the great recent success of SAEs at extracting human-interpretable features, they fail to perfectly reconstruct the activations. For instance, Cunningham et al. (2023) note that replacing the residual stream of layer 2 of Pythia-70m with the reconstructed output of an SAE increased the perplexity of the model on the Pile from 25 to 40. For interpretability, it is important that the features we extract accurately represent what the model is doing.
In this post, I show how and why SAEs have a reconstruction gap due to ‘feature suppression’. Then, I look at a few ways to fix this while maintaining SAEs' interpretability. By modifying and fine-tuning a pre-trained SAE, we achieve a 9% decrease in mean squared error and a 24% reduction in the perplexity increase upon patching activations into the LLM.
Finally, I compare a theoretical example to the observed amounts of feature suppression in Pythia-70m, showing that features are suppressed based on both the strength of their activations and their frequency of activation.
Feature Suppression
The architecture of an SAE is:
$f(x) = \mathrm{ReLU}(W_e x + b_e)$
$y = W_d f(x) + b_d$
The loss function usually combines an MSE reconstruction loss with a sparsity term, like $\mathcal{L}(x, f(x), y) = \|y - x\|_2^2 / d + c\,\|f(x)\|_1$, where $d$ is the dimension of $x$ and $c$ is the sparsity coefficient. When training the SAE on this loss, each column of the decoder's weight matrix (one column per feature) is constrained to have unit norm.
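A minimal NumPy sketch of the architecture and loss above may help make the shapes concrete. It assumes a batch of row vectors `x` of shape `(batch, d)`, an encoder `W_e` of shape `(n_features, d)` and a decoder `W_d` of shape `(d, n_features)`; the function names, sizes, and random initialization are hypothetical, not the authors' code.

```python
import numpy as np


def normalize_decoder(W_d: np.ndarray) -> np.ndarray:
    """Rescale each decoder column (one per feature) to unit L2 norm."""
    return W_d / np.linalg.norm(W_d, axis=0, keepdims=True)


def sae_forward(x, W_e, b_e, W_d, b_d):
    """x: (batch, d); W_e: (n_features, d); W_d: (d, n_features)."""
    f = np.maximum(x @ W_e.T + b_e, 0.0)  # f(x) = ReLU(W_e x + b_e)
    y = f @ W_d.T + b_d                   # y = W_d f(x) + b_d
    return f, y


def sae_loss(x, f, y, c):
    d = x.shape[-1]
    mse = np.sum((y - x) ** 2, axis=-1) / d  # ||y - x||^2 / d per token
    l1 = c * np.sum(np.abs(f), axis=-1)      # c * ||f(x)||_1 per token
    return float(np.mean(mse + l1))          # averaged over the batch


rng = np.random.default_rng(0)
d, n_features = 16, 64  # the post uses d = 512 and 2048 features; small here for speed
W_e = rng.normal(size=(n_features, d)) / np.sqrt(d)
b_e = np.zeros(n_features)
W_d = normalize_decoder(rng.normal(size=(d, n_features)))
b_d = np.zeros(d)
x = rng.normal(size=(8, d))
f, y = sae_forward(x, W_e, b_e, W_d, b_d)
print(f.shape, y.shape, sae_loss(x, f, y, c=2e-3))
```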
The reason for feature suppression is simple: The training loss has two terms, only one of which is reconstruction. Therefore, reconstruction isn’t perfect. In particular, the loss function pushes for smaller f(x) values, leading to suppressed features and worse reconstruction.
An illustrative example of feature suppression
As an example, consider the trivial case where there is only one binary feature in one dimension. That is, x=1 with probability p and x=0 otherwise. Then, ideally the optimal SAE would extract feature activations of f(x)∈{0,1} and have a decoder with Wd=1.
However, if we were to train an SAE optimizing the loss function $\mathcal{L}(x, f(x), y) = \|y - x\|_2^2 + c\,\|f(x)\|_1$, we get a different result. If we ignore bias terms for simplicity of argument, and say that the encoder outputs feature activation $a$ if $x = 1$ and $0$ otherwise, then the optimization problem becomes:
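$a = \arg\min_a \left[ p \cdot \mathcal{L}(1, a, a) + (1 - p) \cdot \mathcal{L}(0, 0, 0) \right] = \arg\min_a \left[ (a - 1)^2 + c\,|a| \right] = \arg\min_a \left[ a^2 + (c - 2)\,a + 1 \right]$
(the positive factor $p$ does not change the minimizer, $\mathcal{L}(0, 0, 0) = 0$, and $|a| = a$ because ReLU outputs are non-negative). Setting the derivative $2a + c - 2$ to zero gives:
$\implies a = 1 - \frac{c}{2}$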
Therefore the feature is scaled by a factor of 1−c/2 compared to optimal. This is an example of feature suppression.
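The closed-form answer can be checked numerically. This sketch brute-forces the expected loss over a grid of candidate activations `a` for a few hypothetical values of `p` and `c`; the minimizer lands on $1 - c/2$ regardless of `p`.

```python
def expected_loss(a: float, p: float, c: float) -> float:
    # x = 1 with probability p: reconstruction y = a (decoder weight 1), loss (a - 1)^2 + c*|a|.
    # x = 0 with probability 1 - p: f(x) = 0 and y = 0, so the loss is zero.
    return p * ((a - 1.0) ** 2 + c * abs(a)) + (1.0 - p) * 0.0


def best_activation(p: float, c: float, steps: int = 20_001) -> float:
    # Grid search over a in [0, 2] with spacing 1e-4.
    grid = (2.0 * i / (steps - 1) for i in range(steps))
    return min(grid, key=lambda a: expected_loss(a, p, c))


for p in (0.01, 0.5, 0.9):
    for c in (0.1, 0.5):
        print(f"p={p}, c={c}: grid argmin={best_activation(p, c):.4f}, 1 - c/2={1 - c / 2:.4f}")
```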
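If we instead let the ground truth feature have activation strength $g$ when it is active, and let $x$ have dimension $d$ (so that the MSE term is divided by $d$, as in the loss above), this factor becomes:
$\frac{a}{g} = 1 - \frac{cd}{2g}$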
In other words, instead of having the ground truth activation $g$, the SAE learns an activation of $g - \frac{cd}{2}$, a constant amount less. Features with activation strengths below $\frac{cd}{2}$ would be completely killed off by the SAE.
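The general case can also be checked numerically. Minimizing $(a - g)^2/d + c\,a$ over $a \ge 0$ (setting the derivative $2(a - g)/d + c$ to zero) gives $a = g - \frac{cd}{2}$, clipped at zero by the ReLU. The sketch below compares that closed form with a brute-force search. Plugging in the post's $c = 2\times10^{-3}$ and $d = 512$ is purely illustrative: it assumes the toy loss above applies literally, in which case $\frac{cd}{2} = 0.512$.

```python
def suppressed_activation(g: float, c: float, d: int) -> float:
    """Closed-form minimizer over a >= 0 of (a - g)^2 / d + c * a."""
    return max(0.0, g - c * d / 2.0)


def brute_force_activation(g: float, c: float, d: int, hi: float = 4.0, steps: int = 40_001) -> float:
    # Grid search over a in [0, hi] with spacing 1e-4.
    grid = (hi * i / (steps - 1) for i in range(steps))
    return min(grid, key=lambda a: (a - g) ** 2 / d + c * a)


c, d = 2e-3, 512  # values borrowed from the post for illustration only
for g in (0.3, 0.6, 1.0, 2.0):
    print(f"g={g}: closed form={suppressed_activation(g, c, d):.4f}, "
          f"brute force={brute_force_activation(g, c, d):.4f}")
```

Under these toy assumptions, the feature with strength 0.3 (below 0.512) is mapped to zero and is killed off entirely, while the others all lose the same 0.512.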
Feature suppression is a significant problem in current SAEs
To experimentally verify that feature suppression affects SAEs, we first trained SAEs on the residual stream output of each layer of Pythia-70m with an L1 sparsity penalty (coefficient 2e-3) on 6 epochs of 100 million tokens of OpenWebText, with batch size 64 and learning rate 1e-3, resulting in roughly 13-80 feature activations per token. The residual stream of Pythia-70m has dimension 512, and we used a dictionary size of 2048, a 4× expansion.
If feature suppression had a noticeable effect, the SAE reconstructions would have noticeably smaller L2 norms than the inputs. Therefore, the expected ratio of the norms (the “L2 Ratio”) can serve as a proxy for measuring feature suppression. Indeed, Figure 1 shows feature suppression to be significant: in all but the last layer, the reconstructions are on average less than 90% of the norm of the activations.
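One way to compute the L2 Ratio, assuming it is the per-token ratio $\|y\|_2 / \|x\|_2$ averaged over tokens (as the Table A1 caption describes). The `eps` guard against zero-norm inputs is an addition of this sketch, and the data is synthetic.

```python
import numpy as np


def l2_ratio(x: np.ndarray, y: np.ndarray, eps: float = 1e-8) -> float:
    """Mean over tokens of ||y|| / ||x||, for inputs x and SAE reconstructions y of shape (n, d)."""
    ratios = np.linalg.norm(y, axis=-1) / (np.linalg.norm(x, axis=-1) + eps)
    return float(ratios.mean())


rng = np.random.default_rng(0)
x = rng.normal(size=(1000, 512))
print(l2_ratio(x, 0.85 * x))  # a uniformly shrunk "reconstruction": approximately 0.85
```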
How can we fix feature suppression in trained SAEs?
Here we evaluate a few different methods to fix the problem of feature suppression. The general approach will be to take a normally-trained SAE and then fine-tune a subset of its parameters without the presence of a sparsity penalty. Freezing parts of the SAE is imperative to prevent degradation of the sparse structure of learned features.
We first slightly modify the SAE by introducing an element-wise scaling factor for each feature between the encoder and decoder, represented as the vector $s$:
$f(x) = \mathrm{ReLU}(W_e x + b_e)$
$f_s(x) = s \odot f(x)$
$y = W_d f_s(x) + b_d$
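A sketch of the modified forward pass, using hypothetical batch-first NumPy shapes. Because `s` only rescales each feature, any strictly positive `s` leaves the set of active features (the sparsity pattern) unchanged, and initializing `s` to ones recovers the original SAE exactly.

```python
import numpy as np


def scaled_sae_forward(x, W_e, b_e, s, W_d, b_d):
    """x: (batch, d); W_e: (n_features, d); s: (n_features,); W_d: (d, n_features)."""
    f = np.maximum(x @ W_e.T + b_e, 0.0)  # f(x) = ReLU(W_e x + b_e), encoder untouched
    f_s = s * f                           # f_s(x) = s ⊙ f(x), broadcast over the batch
    y = f_s @ W_d.T + b_d                 # y = W_d f_s(x) + b_d
    return f, f_s, y


rng = np.random.default_rng(0)
d, n_features = 16, 64
W_e = rng.normal(size=(n_features, d)) / np.sqrt(d)
b_e = np.zeros(n_features)
W_d = rng.normal(size=(d, n_features))
W_d /= np.linalg.norm(W_d, axis=0, keepdims=True)  # unit-norm decoder columns
b_d = np.zeros(d)
x = rng.normal(size=(8, d))

# s = 1 reproduces the original SAE; a random positive s keeps the same active features.
_, _, y_original = scaled_sae_forward(x, W_e, b_e, np.ones(n_features), W_d, b_d)
f, f_s, y_scaled = scaled_sae_forward(x, W_e, b_e, rng.uniform(0.5, 2.0, n_features), W_d, b_d)
```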
In our experiments, we permit different sets of parameters to train, labeled by a method name as laid out in Table 1.
Table 1: The trainable parameters of each fine-tuning method.
We also include a “Baseline” method which is simply training the original SAE without any changes to architecture or loss function on the same number of tokens. We expect there to be no change in the “baseline” method if the SAE is already trained to convergence.
We fine-tuned the SAEs using each of these methods on the same 100 million token subset of OpenWebText and made no change to other hyperparameters. Most of these methods (except for the “Encoder” method) converged by 20-30 million tokens, which is much faster than the SAE pretraining. As a caveat, it may be worth training for longer than is strictly necessary for convergence if you wish to analyze how extremely rare features are modified by fine-tuning.
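A toy sketch of the “Scale” method: the encoder and decoder stay frozen and only `s` is trained, on the plain MSE with no sparsity penalty. Everything here is hypothetical: each feature is one coordinate, and the "pre-trained" encoder under-reports every activation by 20% as a stand-in for feature suppression. Full-batch gradient descent with a hand-derived gradient is used for self-containment, not necessarily the authors' optimizer. In this setup the learned scales should approach 1/0.8 = 1.25, undoing the suppression exactly.

```python
import numpy as np

rng = np.random.default_rng(0)
d, batch = 4, 256
# Hypothetical data: each coordinate is an independent feature, active with
# probability 0.5 and with strength drawn uniformly from [0.5, 2].
x = (rng.random((batch, d)) < 0.5) * rng.uniform(0.5, 2.0, size=(batch, d))

# Hypothetical frozen SAE whose encoder shrinks every feature by 20%.
W_e, b_e = 0.8 * np.eye(d), np.zeros(d)
W_d, b_d = np.eye(d), np.zeros(d)
s = np.ones(d)  # the only trainable parameters


def evaluate(s):
    f = np.maximum(x @ W_e.T + b_e, 0.0)
    y = (s * f) @ W_d.T + b_d
    loss = float(np.mean(np.sum((y - x) ** 2, axis=-1) / d))  # MSE only, no L1 term
    return loss, f, y


initial_loss, _, _ = evaluate(s)
lr = 1.0
for _ in range(500):
    _, f, y = evaluate(s)
    grad_y = 2.0 * (y - x) / (d * batch)         # dL/dy
    grad_s = np.sum(f * (grad_y @ W_d), axis=0)  # chain rule through y = W_d (s ⊙ f) + b_d
    s -= lr * grad_s
final_loss, _, _ = evaluate(s)
print(f"MSE {initial_loss:.4f} -> {final_loss:.2e}, learned scales {np.round(s, 3)}")
```

A “Decoder”-style variant would presumably also update the decoder parameters in the same loop while keeping the encoder frozen.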
Fine-tuning Reduces Feature Suppression
The first thing that stands out about these fine-tuning approaches is that any method that fine-tuned the encoder led to a drastic reduction in sparsity (Figure 2). This is clearly undesirable, so in later plots we omit methods that fine-tune the encoder (full versions of the plots are available in the appendix).
We next show that the “Scale”, “Decoder”, and “Unrotated Decoder” methods address the feature suppression issue in Figure 3, which expands Figure 1 to plot the L2 ratio for our other methods. Each fine-tuning method greatly improves the L2 ratio to be closer to 1, as desired, although there is still some gap left in layers 1 through 5. This may be due to missing features in the encoder and other errors that cannot be recovered in the decoder.
Next, we examine how well our fine-tuning methods improve upon the reconstruction loss and the perplexity of the LLM after patching in the activation reconstruction. Figures 4 and 5 use a fractional scale, where a 1 means the fine-tuned SAE has completely recovered the original LLM’s performance, and a 0 means the fine-tuned SAE has done just as poorly as the original SAE.
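The fractional scale can be written as a one-line function. This sketch assumes the natural definition implied by the text (1 at the original LLM's value, 0 at the base SAE's value); all numbers in the usage lines are hypothetical. For MSE, the original LLM's value is 0, so the fraction reduces to the relative MSE reduction.

```python
def fraction_recovered(base_sae: float, finetuned_sae: float, original: float) -> float:
    """1.0 = fine-tuned SAE matches the original LLM; 0.0 = no better than the base SAE."""
    return (base_sae - finetuned_sae) / (base_sae - original)


# Hypothetical perplexities: base SAE patched in, fine-tuned SAE patched in, original LLM.
print(fraction_recovered(base_sae=42.0, finetuned_sae=39.0, original=30.0))  # 0.25
# Hypothetical MSEs: the original LLM has zero reconstruction error.
print(fraction_recovered(base_sae=0.10, finetuned_sae=0.09, original=0.0))
```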
The “Decoder” method reduces the reconstruction MSE by 8.76% (Figure 4) and the perplexity gap by 24.11% (Figure 5). MSE, perplexity, and other metrics for each individual layer are included in the appendix.
Activation Strength Causes Feature Suppression
In this section we analyze how our fine-tuning methods address feature suppression, through the lens of feature activation frequency and activation strength.
Taking a look at a graph of empirical feature frequencies in Figure 6, we immediately see an almost discontinuous jump in activation frequencies in the middle for all but the last layer, which has a slightly more gradual change. We found that the features on the infrequent side of the jump were essentially dead features with no significant change upon fine-tuning, and therefore we exclude them from future plots. We also note that the graph looks mostly similar if instead of plotting feature frequencies we plotted the sum of activation values across the dataset.
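A sketch of how per-feature frequencies (and the alternative sum-of-activations statistic) could be computed from a matrix of SAE activations, and how the near-dead cluster could be separated. The cutoff is hypothetical: the post does not state one, and in practice it would be read off the gap in the histogram. The activation matrix is synthetic.

```python
import numpy as np


def feature_frequencies(f: np.ndarray) -> np.ndarray:
    """Fraction of tokens on which each feature fires. f: (n_tokens, n_features)."""
    return (f > 0).mean(axis=0)


def activation_sums(f: np.ndarray) -> np.ndarray:
    """Total activation of each feature across the dataset."""
    return f.sum(axis=0)


def alive_mask(freqs: np.ndarray, log10_cutoff: float) -> np.ndarray:
    """True for features above a (hypothetical) log10-frequency cutoff."""
    return np.log10(np.maximum(freqs, 1e-12)) > log10_cutoff


rng = np.random.default_rng(0)
n_tokens = 100_000
# Hypothetical activations: 3 common features and 2 that (almost) never fire.
rates = np.array([0.05, 0.02, 0.01, 1e-5, 0.0])
fires = rng.random((n_tokens, rates.size)) < rates
f = fires * rng.exponential(1.0, size=(n_tokens, rates.size))
freqs = feature_frequencies(f)
print(freqs, alive_mask(freqs, log10_cutoff=-4.0))
```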
We focus the rest of our analysis on layer 5, but the results are similar for other layers.
A theoretical example predicts frequency isn't a factor
Our illustrative example predicted that the frequency of a feature’s activation, $p$, would be independent of the amount that feature gets suppressed in a trained SAE. This is because, in the single binary feature scenario, an SAE would learn activation values according to:
$\text{activation} = \left(1 - \frac{cd}{2 \cdot \text{ground truth strength}}\right) \times \text{ground truth strength} = \text{ground truth strength} - \frac{cd}{2}$
This formula does not include p. Therefore, when adding new features to the theoretical example, as long as each feature has sufficiently small interference with every other feature, an SAE will treat each feature independently and similarly to the single feature scenario.
Moreover, activation strength is a factor when considering the multiplicative scaling factor, $1 - \frac{cd}{2 \cdot \text{ground truth strength}}$, but not when taking the difference between SAE activation strength and ground truth strength, $-\frac{cd}{2}$.
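To see the distinction concretely, this sketch tabulates both quantities for a few ground-truth strengths under toy values of $c$ and $d$ chosen so that $cd/2 = 0.5$. The multiplicative factor climbs toward 1 as the strength grows, while the difference stays fixed at $-cd/2$.

```python
def factor_and_difference(g: float, c: float, d: int) -> tuple[float, float]:
    """Return (learned / ground truth, learned - ground truth) for the toy single-feature SAE."""
    learned = max(0.0, g - c * d / 2.0)
    return learned / g, learned - g


c, d = 0.01, 100  # toy values: cd/2 = 0.5
for g in (1.0, 2.0, 4.0, 8.0):
    factor, diff = factor_and_difference(g, c, d)
    print(f"strength {g}: factor {factor:.3f}, difference {diff:+.3f}")
```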
Fine-tuning does not fix regression dilution
While our illustrative example didn’t show any relationship with frequency, there is a separate way in which frequency can cause features to be poorly represented by an SAE: regression dilution, the well-known effect where noise in a regression’s predictor variable biases the fitted slope toward zero.
We believe the frequency with which a feature activates may be related to regression dilution because of interference with other features in the encoder. The rarer a feature is, the more likely it is that any given activation of that feature is due to accidental interference from nearby features. The learned SAE feature is therefore noisy, and regression dilution results in a lower feature activation.
Because regression dilution is an artifact of the encoder, it is not something that our decoder-based fine-tuning methods can fix. Therefore, we still predict no relationship between frequency and empirically observed feature scaling upon fine-tuning.
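Regression dilution itself is easy to reproduce in a generic setting (this is not the SAE experiment): an ordinary least-squares slope fit on a predictor contaminated with noise is shrunk by the factor $\operatorname{Var}(x) / (\operatorname{Var}(x) + \operatorname{Var}(\text{noise}))$. With equal signal and noise variance, a true slope of 2 is estimated as roughly 1. The data here is synthetic.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000
true_x = rng.normal(size=n)
target = 2.0 * true_x                              # true slope is 2
noisy_x = true_x + rng.normal(scale=1.0, size=n)   # predictor observed with noise

clean_slope = np.polyfit(true_x, target, 1)[0]     # fit on the clean predictor
slope = np.polyfit(noisy_x, target, 1)[0]          # fit on the noisy predictor
attenuation = 1.0 / (1.0 + 1.0)                    # Var(x) / (Var(x) + Var(noise))
print(clean_slope, slope, 2.0 * attenuation)       # roughly 2, 1, and 1
```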
Experimental measurements agree for activation strength, disagree for frequency
Our illustrative example from earlier predicted that scaling after fine-tuning anti-correlates with activation strength and is independent of frequency; while Figure 7 supports the trend for activation strength, it also shows that frequency is not independent. Contrary to predictions, the correlation between log(activation strength) and scaling was −0.710, while the correlation between log(frequency) and scaling was 0.786.
Even when isolating the impact of frequency or activation strength, the other variable still had a large magnitude of correlation with scaling. The partial correlations were −0.501 and 0.652 for log(activation strength) and log(frequency), respectively.
Meanwhile, the illustrative example also predicted that the impact of average activation strength vanishes when we consider the average change in activation strength instead of the learned scaling factor, and the experimental results mostly agree, as plotted in Figure 8. Here, the correlations of log(activation strength) and log(frequency) with the change in activation strength were −0.078 and 0.569, respectively, showing that activation strength has become much less important. The corresponding partial correlations were 0.378 and 0.646.
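For readers reproducing these statistics, here is a sketch of the standard linear (Pearson-based) partial correlation of two variables controlling for a third, e.g. log(activation strength) and scaling controlling for log(frequency). The post does not say exactly how its partial correlations were computed, so this formula is an assumption. The synthetic check builds two variables that are correlated only through a shared third variable, so their partial correlation is near zero.

```python
import numpy as np


def partial_corr(x: np.ndarray, y: np.ndarray, z: np.ndarray) -> float:
    """Partial correlation of x and y controlling for z, from pairwise Pearson correlations."""
    r = np.corrcoef(np.vstack([x, y, z]))
    r_xy, r_xz, r_yz = r[0, 1], r[0, 2], r[1, 2]
    return float((r_xy - r_xz * r_yz) / np.sqrt((1 - r_xz**2) * (1 - r_yz**2)))


rng = np.random.default_rng(0)
z = rng.normal(size=10_000)
x = z + rng.normal(size=10_000)  # x and y share z but are otherwise independent
y = z + rng.normal(size=10_000)
print(np.corrcoef(x, y)[0, 1], partial_corr(x, y, z))  # roughly 0.5 and roughly 0
```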
Conclusion
In conclusion, training with a sparsity penalty leads to suppression of features, predominantly features with frequent but weak activations. By inserting a scaling factor and fine-tuning the decoder without the sparsity penalty, we can achieve a 9% improvement in reconstruction loss and a 24% recovery of perplexity without changing when features activate. Even with nothing but a scaling factor, which in theory means interpretability should be unaffected, fine-tuning can achieve a 6% improvement in reconstruction loss and a 16% recovery of perplexity.
This work suggests that if we want to insert our SAEs into our models and use their outputs in the forward pass of the model (which we might want to do to study the effects of different features on model internals and outputs), then we should consider fine-tuning the scale of the learned features post-SAE-training, or we should identify some other way to reduce feature suppression.
Appendix
Related Work
Feature suppression is similar to what happens in the Lasso method of linear regression, where there is an L1 penalty on the weights in addition to the squared-error reconstruction loss. Lasso stands for “Least Absolute Shrinkage and Selection Operator”; as its name says, Lasso regression tends both to select a sparse set of input dimensions to use and to shrink the usage of those inputs. This effect is well known, and besides choosing different sparsity penalties, one proposal is to refit the linear regression without the Lasso penalty after Lasso selects a sparse set of parameters (Belloni and Chernozhukov, 2009). This mirrors our approach to mitigating feature suppression in SAEs.
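The parallel can be made concrete with a small least-squares sketch. With an orthonormal design matrix, the Lasso solution for $\frac{1}{2}\|y - X\beta\|_2^2 + \lambda\|\beta\|_1$ is the ordinary least-squares solution soft-thresholded by $\lambda$, so every selected coefficient shrinks by exactly $\lambda$, much like suppressed SAE features. Refitting ordinary least squares on the selected support removes the shrinkage. The data and $\lambda$ are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 200, 10
X, _ = np.linalg.qr(rng.normal(size=(n, p)))  # orthonormal columns: X.T @ X = I
beta_true = np.array([3.0, -2.0, 1.5] + [0.0] * (p - 3))
y = X @ beta_true + 0.1 * rng.normal(size=n)

ols = X.T @ y  # OLS solution when X.T @ X = I
lam = 0.5
# Lasso with an orthonormal design = soft-thresholding of the OLS coefficients.
lasso = np.sign(ols) * np.maximum(np.abs(ols) - lam, 0.0)

# Post-Lasso: keep Lasso's support, refit without the penalty.
support = lasso != 0
post_lasso = np.zeros(p)
post_lasso[support] = np.linalg.lstsq(X[:, support], y, rcond=None)[0]
print(np.round(lasso[:3], 3), np.round(post_lasso[:3], 3))
```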
Lee et al. (2019) also point out that freezing parts of a language model during fine-tuning can be a good idea, which is similar to our motivation for keeping the encoder frozen while fine-tuning SAEs.
Extended Data
We show the full versions of Figures 3, 4, and 5, including the methods that allowed parts of the encoder to train.
Summary Statistics for SAEs
Table A1: Metrics for the base SAEs and the fine-tuned SAEs under the “Decoder” method. “L0” measures the average number of feature activations per token, and “L2 Ratio” measures the average ratio of the norm of the output to the norm of the input. “Dead Features” shows how many of the 2048 features never activate on the dataset.
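The “L0” and “Dead Features” columns can be computed from a matrix of feature activations over the evaluation set roughly as follows (a sketch; the function name and toy matrix are hypothetical).

```python
import numpy as np


def summary_metrics(f: np.ndarray) -> dict:
    """f: (n_tokens, n_features) SAE feature activations over the evaluation set."""
    active = f > 0
    return {
        "L0": float(active.sum(axis=1).mean()),             # mean active features per token
        "dead_features": int((~active.any(axis=0)).sum()),  # features that never fire
    }


toy = np.array([[1.0, 0.0, 0.0],
                [2.0, 3.0, 0.0]])
print(summary_metrics(toy))  # {'L0': 1.5, 'dead_features': 1}
```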
Low strength/high frequency features also rotate more
We measure the cosine similarity between the decoder dictionary vector before and after fine-tuning. Future work could examine what causes rotations in more detail.
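A sketch of this per-feature measurement, assuming decoder matrices of shape `(d, n_features)` with one dictionary vector per column, as in the architecture above. Cosine similarity ignores column norms, so it isolates rotation from rescaling. The toy matrices are hypothetical.

```python
import numpy as np


def decoder_cosine_similarity(W_before: np.ndarray, W_after: np.ndarray) -> np.ndarray:
    """Per-feature cosine similarity between decoder columns (shape: d x n_features)."""
    num = np.sum(W_before * W_after, axis=0)
    den = np.linalg.norm(W_before, axis=0) * np.linalg.norm(W_after, axis=0)
    return num / den


W_before = np.eye(3)
# Column 0 is only rescaled, column 1 is unchanged, column 2 is rotated to be orthogonal.
W_after = np.array([[2.0, 0.0, 1.0],
                    [0.0, 1.0, 0.0],
                    [0.0, 0.0, 0.0]])
print(decoder_cosine_similarity(W_before, W_after))
```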