
This plot illustrates how the choice of training and evaluation datasets affects reconstruction quality. Specifically, it shows, for each pair of training and evaluation datasets: 1) the explained variance of the hidden states, 2) the L2 reconstruction loss, and 3) the downstream CE difference in the language model.

The results indicate that SAEs generalize reasonably well across datasets, with a few notable points:

  1. SAEs trained on TinyStories struggle to reconstruct other datasets, likely due to its synthetic nature.
  2. Web-based datasets (the top-left 3x3 subset) generalize well to one another, although the CE difference and L2 loss are still 2–3 times higher than when evaluating on the training dataset itself. This matches expectations, but it suggests there may be ways to improve generalizability beyond training a separate SAE on each dataset. This is particularly interesting given that my team is currently exploring dataset-related effects in SAE training.

Overall, the fact that the explained variance approaches 1 indicates that, even without direct feature matching, the composition of learned features remains consistent across datasets, as hypothesized.
(The code is available in the same repository; results were evaluated on 10k sequences per dataset.)
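
For reference, the per-pair metrics can be computed roughly as below (a simplified sketch, not the repository's actual code; the plumbing that runs the SAE and patches its reconstruction back into the model is assumed and omitted):

```python
import torch

def reconstruction_metrics(h: torch.Tensor, h_hat: torch.Tensor) -> dict:
    # h, h_hat: [num_tokens, d_model] hidden states and their SAE reconstructions.
    sq_err = (h - h_hat).pow(2).sum(-1)                    # per-token squared reconstruction error
    total_var = (h - h.mean(0)).pow(2).sum(-1).mean()      # variance of the hidden states
    return {
        "l2_loss": sq_err.mean().item(),
        "explained_variance": (1.0 - sq_err.mean() / total_var).item(),
    }

def ce_difference(logits_clean: torch.Tensor, logits_patched: torch.Tensor, tokens: torch.Tensor) -> float:
    # Downstream CE difference: LM loss with the SAE reconstruction spliced in,
    # minus the loss of the unmodified LM. logits: [batch, seq, vocab]; tokens: [batch, seq].
    ce = torch.nn.functional.cross_entropy
    ce_clean = ce(logits_clean[:, :-1].reshape(-1, logits_clean.shape[-1]), tokens[:, 1:].reshape(-1))
    ce_patched = ce(logits_patched[:, :-1].reshape(-1, logits_patched.shape[-1]), tokens[:, 1:].reshape(-1))
    return (ce_patched - ce_clean).item()

# For the heatmap, these are computed for every (training dataset, evaluation dataset) pair,
# e.g. metrics[train_ds][eval_ds] = reconstruction_metrics(h_eval, sae_train(h_eval)).
```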

Thank you for your answer. I now understand both the extreme case and the illustrative example: the 1999-bit value comes from the binary decisions for each token, as you mentioned, even though it exceeds the typical DL.

More importantly, optimizing the tradeoff between sparsity and description length looks like a convex optimization problem. It would be great to formalize this relationship and plot the trend of DL (y-axis) against sparsity (x-axis), although I have no specific approach in mind. My intuition is that the MDL would serve as a lower bound, with the overall curve approximated in each regime by the information of whichever factor dominates.
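
To make the plot I have in mind concrete, here is a toy sketch. The DL accounting below (L0 × (index bits + value bits) per token) is my own simplification rather than the post's exact formula, and in practice each sparsity level would be paired with the dictionary size needed to reach a fixed reconstruction loss:

```python
import math

def description_length_bits(l0: float, dict_size: int, bits_per_value: int = 8) -> float:
    # Toy DL model (an assumption, not the post's exact accounting): each token transmits
    # which l0 latents are active (log2(dict_size) bits per index) plus a bits_per_value-bit
    # magnitude for each active latent.
    return l0 * (math.log2(dict_size) + bits_per_value)

# Sweep sparsity for a few dictionary sizes to eyeball the trend of DL (y) against sparsity (x).
for dict_size in (2**12, 2**15, 2**18):
    curve = [(l0, round(description_length_bits(l0, dict_size), 1)) for l0 in (1, 4, 16, 64, 256)]
    print(dict_size, curve)
```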

Interesting! This way of finding a desirable dictionary size and sparsity is appealing. It's also intriguing that the MDL incentivizes SAEs to produce hierarchical features rather than to split features.

I have some questions regarding the upper-bound DL computation:

> One-hot encodings: At the sparse extreme, our dictionary could have a row for each neural activation in the dataset, so L0 = 1 and the dictionary size equals the number of distinct activations. GPT-2 has a vocab size of 50,257 and the SAEs are trained on 128-token sequences. All together this gives DL = 13,993 bits per token.

I can easily compute the above two values following your instructions; however, I'm having trouble reproducing the 13,993-bit value, or perhaps I've missed something. My calculation, log2(50257^128) = 128 × log2(50257) ≈ 1998.98 bits, gives a much smaller number. Could you please clarify which part of my calculation is incorrect?
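
For concreteness, the calculation I mean is:

```python
import math

vocab_size = 50257   # GPT-2 vocabulary
seq_len = 128        # training sequence length

# One dictionary row per possible 128-token context => log2(vocab^seq_len) bits to index it.
print(seq_len * math.log2(vocab_size))   # ≈ 1998.98 bits, rather than 13,993
```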
Another question is about why the sequence length enters the extreme-sparsity example: it seems to count all possible token sequences. Is this intended for a fair comparison, since the two examples above encode the sequence context within relatively dense vectors?

Thank you for your comment!

Regarding the cross-dataset metric, it is interesting to test how an SAE trained on one dataset transfers to different datasets, and I'll share the comparison in the comments once I've measured it. If the combination of features stays similar across datasets, contrary to my subset hypothesis above, it might be because there are many different combinations of feature sets (i.e., bases of the feature space) that work equally well, which could also be why feature matching is generally low (ideally it would be one).

I also observed feature changes over training steps, with a matching ratio of only about 0.7 between the 1e8-token and 4e8-token checkpoints (even though the loss barely changed over that interval), indicating that training duration has a considerable impact on the learned features. However, because I didn't have enough compute budget to reach convergence in the various scenarios, I was unable to include this test in my research. One open question is whether the model converges to a specific feature set or keeps oscillating because of the continuous data stream. This certainly seems like an interesting direction for further research.
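
As a sketch, one way to define such a matching ratio between two SAE checkpoints is greedy max-cosine matching between unit-normalized decoder directions (the 0.9 threshold below is an arbitrary choice, not necessarily the criterion behind the 0.7 figure above):

```python
import torch

def feature_matching_ratio(w_dec_a: torch.Tensor, w_dec_b: torch.Tensor, threshold: float = 0.9) -> float:
    # w_dec_*: [num_features, d_model] decoder weights of two SAEs (e.g. two checkpoints,
    # or SAEs trained on two datasets). A feature of A counts as matched if some feature
    # of B lies within `threshold` cosine similarity of it.
    a = torch.nn.functional.normalize(w_dec_a, dim=-1)
    b = torch.nn.functional.normalize(w_dec_b, dim=-1)
    cos = a @ b.T                        # [num_features_a, num_features_b] cosine similarities
    best = cos.max(dim=-1).values        # best match in B for each feature of A
    return (best > threshold).float().mean().item()
```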

What happens when we learn the Meta-SAE's decoder weights again? A Meta-Meta-SAE? 🤔

I can only expect an even lossier decomposition.
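
Mechanically, it would just be training another small SAE whose data points are the meta-SAE's decoder rows; a toy sketch (the architecture, hyperparameters, and `meta_sae.w_dec` name are placeholders, not the actual Meta-SAE setup):

```python
import torch

def train_sae_on_decoder(w_dec: torch.Tensor, num_latents: int, l1_coeff: float = 1e-3, steps: int = 1000):
    # Treat each decoder row of the previous (meta-)SAE as a data point and train another
    # vanilla ReLU SAE with an L1 sparsity penalty on those rows; applied to a meta-SAE's
    # decoder this would give a "meta-meta-SAE".
    w_dec = w_dec.detach()
    n, d = w_dec.shape
    enc, dec = torch.nn.Linear(d, num_latents), torch.nn.Linear(num_latents, d)
    opt = torch.optim.Adam(list(enc.parameters()) + list(dec.parameters()), lr=1e-3)
    for _ in range(steps):
        z = torch.relu(enc(w_dec))
        loss = (dec(z) - w_dec).pow(2).mean() + l1_coeff * z.abs().mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
    return enc, dec

# meta_meta_enc, meta_meta_dec = train_sae_on_decoder(meta_sae.w_dec, num_latents=512)
```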