StefanHex

Stefan Heimersheim. Research Scientist at Apollo Research, Mechanistic Interpretability. The opinions expressed here are my own and do not necessarily reflect the views of my employer.

The previous lines calculate the ratio (or 1 − ratio) stored in the "explained variance" key for every sample/batch. Then, in the line quoted below, the list is averaged, i.e. we're taking the sample average over the ratio. That's the FVU_B formula.

Let me know if this clears it up or if we’re misunderstanding each other!

I think this is the sum over the vector dimension, but not over the samples. The sum (mean) over samples is taken later, in this line, which happens after the division:

        metrics[f"{metric_name}"] = torch.cat(metric_values).mean().item()

Edit: And to clarify, my impression is that people think of these as alternative definitions of FVU that you get to pick between, rather than one being right and one being a bug.

Edit2: And I'm in touch with the SAEBench authors about making a PR to change this / add both options (and by extension probably doing the same in SAELens); though I won't mind if anyone else does it!

StefanHex*272

PSA: People use different definitions of "explained variance" / "fraction of variance unexplained" (FVU)

$\mathrm{FVU}_A = \frac{\sum_i \lVert x_i - \hat{x}_i \rVert^2}{\sum_i \lVert x_i - \bar{x} \rVert^2}$ is the formula I think is sensible; the bottom is simply the variance of the data, and the top is the variance of the residuals. The $\lVert\cdot\rVert$ indicates the $L_2$ norm over the dimension of the vector $x_i$, and $\bar{x}$ is the mean of the data. I believe it matches Wikipedia's definition of FVU and R squared.

$\mathrm{FVU}_B = \frac{1}{N}\sum_i \frac{\lVert x_i - \hat{x}_i \rVert^2}{\lVert x_i - \bar{x} \rVert^2}$ (with $N$ the number of samples) is the formula used by SAELens and SAEBench. It seems less principled; @Lucius Bushnaq and I couldn't think of a nice quantity it corresponds to. I think of it as giving more weight to samples that are close to the mean, kind of averaging the relative reduction in difference rather than the absolute one.

There is also a third version (h/t @JoshEngels) that computes the FVU for each dimension independently and then averages; that version is not used in the context we're discussing here.
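
For concreteness, here is a small PyTorch sketch of all three definitions side by side; this is my own code following the formulas above, not taken from SAELens or SAEBench.

    import torch

    def fvu_variants(x: torch.Tensor, x_hat: torch.Tensor):
        """Compute the three FVU variants discussed above for activations x and
        reconstructions x_hat, both of shape (n_samples, d_model). My own sketch
        of the definitions, not library code."""
        mu = x.mean(dim=0)                     # dataset mean, shape (d_model,)
        resid_sq = (x - x_hat).pow(2)          # squared residuals
        total_sq = (x - mu).pow(2)             # squared deviations from the mean

        # FVU_A: ratio of summed residual variance to summed data variance
        fvu_a = resid_sq.sum() / total_sq.sum()
        # FVU_B: per-sample ratio of squared norms, then averaged over samples
        fvu_b = (resid_sq.sum(dim=1) / total_sq.sum(dim=1)).mean()
        # Third variant: FVU per dimension, then averaged over dimensions
        fvu_per_dim = (resid_sq.sum(dim=0) / total_sq.sum(dim=0)).mean()
        return fvu_a.item(), fvu_b.item(), fvu_per_dim.item()

    # Toy usage
    torch.manual_seed(0)
    x = torch.randn(1000, 64)
    x_hat = x + 0.3 * torch.randn_like(x)
    print(fvu_variants(x, x_hat))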

In my recent comment I had computed my own $\mathrm{FVU}_A$ values and compared them to the FVUs from SAEBench (which uses $\mathrm{FVU}_B$), and obtained nonsense results.

Curiously, the two definitions seem to be approximately proportional (below I show the performance of a bunch of SAEs), though for different distributions (here: activations in layers 3 and 4) the ratio differs.[1] Still, this means that using $\mathrm{FVU}_B$ instead of $\mathrm{FVU}_A$ to compare e.g. different SAEs doesn't make a big difference, as long as one is consistent.

Thanks to @JoshEngels for pointing out the difference, and to @Lucius Bushnaq for helpful discussions.

  1. ^

    If a predictor doesn't perform systematically better or worse at points closer to the mean, then this makes sense: the denominator changes the relative weight of different samples, but this has no effect beyond noise and a global scale as long as there is no systematic performance difference.
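
    As a quick illustration of this point (my own toy simulation, not the plotted SAE data): if the reconstruction error is drawn independently of each sample's distance to the mean, $\mathrm{FVU}_A$ and $\mathrm{FVU}_B$ stay roughly proportional as the error scale varies.

        import torch

        # Toy check of the footnote's claim (my own simulation, not the SAE data):
        # if reconstruction error is independent of a sample's distance to the mean,
        # the ratio FVU_B / FVU_A is roughly constant as the error scale is varied.
        torch.manual_seed(0)
        x = torch.randn(5000, 64) * torch.linspace(0.5, 2.0, 64)  # anisotropic "activations"

        for noise in [0.1, 0.3, 1.0]:
            x_hat = x + noise * torch.randn_like(x)    # error independent of ||x - mean||
            resid = (x - x_hat).pow(2).sum(dim=1)
            total = (x - x.mean(dim=0)).pow(2).sum(dim=1)
            fvu_a = resid.sum() / total.sum()
            fvu_b = (resid / total).mean()
            print(f"noise={noise:.1f}  FVU_A={fvu_a:.3f}  FVU_B={fvu_b:.3f}  ratio={fvu_b / fvu_a:.2f}")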

Same plot but using SAEBench's FVU definition. Matches this Neuronpedia page.

I'm going to update the results in the top-level comment with the corrected data; I'm pasting the original figures here for posterity / understanding the past discussion. Summary of changes:

  1. [Minor] I didn't subtract the mean in the variance calculation. This barely had an effect on the results.
  2. [Major] I used a different definition of "Explained Variance", which caused a pretty large difference.

Old (no longer true) text:

It turns out that even clustering (essentially L_0=1) explains up to 90% of the variance in activations, being matched only by SAEs with L_0>100. This isn't an entirely fair comparison, since SAEs are optimised for the large-L_0 regime, while I haven't found an L_0>1 operationalisation of clustering that meaningfully improves over L_0=1. To have some comparison I'm adding a PCA + Clustering baseline where I apply a PCA before doing the clustering. It does roughly as well as expected, exceeding the SAE reconstruction for most L_0 values. The upcoming SAEBench paper also does a PCA baseline, so I won't discuss PCA in detail here.


[...]

Here's the code used to get the clustering & PCA results below; the SAE numbers are taken straight from Neuronpedia. Both my code and SAEBench/Neuronpedia use OpenWebText with a context length of 128 tokens, so I hope the numbers are comparable, but there's a risk I missed something and we're comparing apples to oranges.
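
I don't have the linked code in front of me, so as a rough illustration, here is my own reconstruction of what such a (PCA +) clustering baseline could look like, using scikit-learn; the function name and parameter choices are hypothetical, not the actual code behind the numbers below.

    import numpy as np
    from sklearn.cluster import KMeans
    from sklearn.decomposition import PCA

    def clustering_baseline_fvu(acts: np.ndarray, n_clusters: int, n_pca: int = 0) -> float:
        """Rough sketch of a (PCA +) clustering reconstruction baseline; my own
        guess at the setup, not the linked code. acts has shape (n_tokens, d_model)."""
        recons = np.zeros_like(acts)
        residual = acts
        if n_pca > 0:
            # Reconstruct part of each activation with a rank-n_pca PCA ...
            pca = PCA(n_components=n_pca).fit(acts)
            recons = pca.inverse_transform(pca.transform(acts))
            residual = acts - recons
        # ... and reconstruct the remainder with its nearest cluster centroid (L_0 = 1).
        km = KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit(residual)
        recons = recons + km.cluster_centers_[km.labels_]
        # Returns FVU_A; "variance explained" would be 1 minus this.
        return float(((acts - recons) ** 2).sum() / ((acts - acts.mean(axis=0)) ** 2).sum())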

StefanHex*40

After adding the mean subtraction, the numbers actually haven't changed too much, but let me make sure I'm using the correct calculation. I'm going to follow your and @Adam Karvonen's suggestion of using the SAEBench code and loading my clustering solution as an SAE (this code).
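
For context, "loading a clustering solution as an SAE" can be thought of roughly as follows; the wrapper below is a hypothetical illustration of the idea, not the actual SAEBench interface.

    import torch

    class KMeansAsSAE(torch.nn.Module):
        """Hypothetical wrapper exposing a k-means codebook through an SAE-like
        encode/decode interface, so it can be evaluated with the same metrics
        code as a real SAE. Not the actual SAEBench interface."""

        def __init__(self, centroids: torch.Tensor):  # centroids: (n_clusters, d_model)
            super().__init__()
            self.register_buffer("centroids", centroids)

        def encode(self, acts: torch.Tensor) -> torch.Tensor:
            # One-hot "feature activations": each sample activates exactly one
            # latent, its nearest centroid, i.e. L_0 = 1.
            dists = torch.cdist(acts, self.centroids)  # (n_samples, n_clusters)
            one_hot = torch.nn.functional.one_hot(
                dists.argmin(dim=-1), num_classes=self.centroids.shape[0]
            )
            return one_hot.to(acts.dtype)

        def decode(self, feats: torch.Tensor) -> torch.Tensor:
            # The reconstruction is simply the selected centroid.
            return feats @ self.centroids

        def forward(self, acts: torch.Tensor) -> torch.Tensor:
            return self.decode(self.encode(acts))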

These logs show numbers with the original / corrected explained variance computation; the difference is in the 3-8% range.

v3 (KMeans): Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=4096, variance explained = 0.8887 / 0.8568
v3 (KMeans): Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16384, variance explained = 0.9020 / 0.8740
v3 (KMeans): Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=4096, variance explained = 0.8044 / 0.7197
v3 (KMeans): Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16384, variance explained = 0.8261 / 0.7509
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=4095, n_pca=1, variance explained = 0.8910 / 0.8599
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16383, n_pca=1, variance explained = 0.9041 / 0.8766
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=4094, n_pca=2, variance explained = 0.8948 / 0.8647
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16382, n_pca=2, variance explained = 0.9076 / 0.8812
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=4091, n_pca=5, variance explained = 0.9044 / 0.8770
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16379, n_pca=5, variance explained = 0.9159 / 0.8919
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=4086, n_pca=10, variance explained = 0.9121 / 0.8870
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16374, n_pca=10, variance explained = 0.9232 / 0.9012
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=4076, n_pca=20, variance explained = 0.9209 / 0.8983
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16364, n_pca=20, variance explained = 0.9314 / 0.9118
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=4046, n_pca=50, variance explained = 0.9379 / 0.9202
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16334, n_pca=50, variance explained = 0.9468 / 0.9315
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=3996, n_pca=100, variance explained = 0.9539 / 0.9407
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16284, n_pca=100, variance explained = 0.9611 / 0.9499
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=3896, n_pca=200, variance explained = 0.9721 / 0.9641
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16184, n_pca=200, variance explained = 0.9768 / 0.9702
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=3596, n_pca=500, variance explained = 0.9999 / 0.9998
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=15884, n_pca=500, variance explained = 0.9999 / 0.9999
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=4095, n_pca=1, variance explained = 0.8077 / 0.7245
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16383, n_pca=1, variance explained = 0.8292 / 0.7554
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=4094, n_pca=2, variance explained = 0.8145 / 0.7342
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16382, n_pca=2, variance explained = 0.8350 / 0.7636
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=4091, n_pca=5, variance explained = 0.8244 / 0.7484
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16379, n_pca=5, variance explained = 0.8441 / 0.7767
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=4086, n_pca=10, variance explained = 0.8326 / 0.7602
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16374, n_pca=10, variance explained = 0.8516 / 0.7875
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=4076, n_pca=20, variance explained = 0.8460 / 0.7794
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16364, n_pca=20, variance explained = 0.8637 / 0.8048
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=4046, n_pca=50, variance explained = 0.8735 / 0.8188
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16334, n_pca=50, variance explained = 0.8884 / 0.8401
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=3996, n_pca=100, variance explained = 0.9021 / 0.8598
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16284, n_pca=100, variance explained = 0.9138 / 0.8765
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=3896, n_pca=200, variance explained = 0.9399 / 0.9139
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16184, n_pca=200, variance explained = 0.9473 / 0.9246
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=3596, n_pca=500, variance explained = 0.9997 / 0.9996
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=15884, n_pca=500, variance explained = 0.9998 / 0.9997
StefanHex*20

You're right, I forgot to subtract the mean. Thanks a lot!

I'm computing new numbers now, but indeed I expect this to explain my result! (Edit: Seems to not change too much)

I should really run a random Gaussian data baseline for this.
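
A minimal version of such a baseline could look like the following; this is my own sketch with arbitrary sizes, not the code behind the numbers below.

    import numpy as np
    from sklearn.cluster import MiniBatchKMeans

    # Sketch of a random-Gaussian baseline (my own code, arbitrary sizes): how much
    # variance does clustering explain on isotropic Gaussian data of similar
    # dimensionality, where there is no structure to exploit?
    rng = np.random.default_rng(0)
    x = rng.standard_normal((200_000, 512)).astype(np.float32)  # stand-in "activations"

    km = MiniBatchKMeans(n_clusters=4096, batch_size=4096, random_state=0).fit(x)
    recons = km.cluster_centers_[km.labels_]
    fvu = ((x - recons) ** 2).sum() / ((x - x.mean(axis=0)) ** 2).sum()
    print(f"variance explained = {1 - fvu:.3f}")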

Tentatively I get similar results (70-85% variance explained) for random data. I haven't checked that code at all though, so don't trust this; I will double-check it tomorrow.

(In that case the SAEs' performance would also be unsurprising, I suppose.)

[This comment is no longer endorsed by its author]