Eric J. Michaud

I am a PhD student in the Department of Physics at MIT. I did my undergrad in math at UC Berkeley, and interned at CHAI in 2020. My current research is on the science/theory of deep learning.

Comments

Nice post! I'm one of the authors of the Engels et al. paper on circular features in LLMs so just thought I'd share some additional details about our experiments that are relevant to this discussion.

To review, our paper finds circular representations of days of the week and months of the year within LLMs.  These appear to reflect the cyclical structure of our calendar! These plots are for gpt-2-small:

We found these in three steps. (1) We performed graph-based clustering on SAE features, using the cosine similarity between decoder vectors as the similarity measure. (2) Given a cluster of SAE features, we identified tokens which activate any feature in the cluster and reconstructed the LLM's activation vector on those tokens with the SAE, allowing only the SAE features in the cluster to participate in the reconstruction. (3) We visualized the reconstructed points along the top PCA components. Our idea here was that if there were some true higher-dimensional feature in the LLM, multiple SAE features would need to participate together in reconstructing it, and we wanted to visualize this feature while removing all others from the activation vector. To find such groups of SAE features, decoder-vector cosine similarity was what worked in practice. We also tried Jaccard similarity (capturing how frequently SAE features fire together), but it didn't yield interesting clusters like cosine similarity did in the experiments we ran.
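For concreteness, steps (2) and (3) might look roughly like the sketch below. It is not the code from the paper: the array names (`feature_acts`, `decoder`), the random toy data, the "active means activation > 0" rule, and the omission of any SAE decoder bias are all assumptions made for illustration.

```python
import numpy as np

def cluster_reconstruction(feature_acts, decoder, cluster_idx):
    """Reconstruct activations using only the SAE features in one cluster.

    feature_acts: (n_tokens, n_features) SAE feature activations (hypothetical layout)
    decoder:      (n_features, d_model) SAE decoder vectors (hypothetical layout)
    cluster_idx:  indices of the SAE features belonging to the cluster
    """
    cluster_idx = np.asarray(cluster_idx)
    acts = feature_acts[:, cluster_idx]
    # Keep only tokens on which at least one feature in the cluster is active.
    active = (acts > 0).any(axis=1)
    # Decoder bias is deliberately left out (a simplifying assumption).
    recon = acts[active] @ decoder[cluster_idx]
    return recon, active

def top_pca_components(points, k=2):
    """Project points onto their top-k principal components."""
    centered = points - points.mean(axis=0, keepdims=True)
    # Rows of vt are principal directions, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:k].T

# Toy data standing in for real SAE activations and decoder weights.
rng = np.random.default_rng(0)
n_tokens, n_features, d_model = 200, 32, 16
decoder = rng.normal(size=(n_features, d_model))
feature_acts = np.maximum(rng.normal(size=(n_tokens, n_features)) - 1.0, 0.0)

cluster = [3, 7, 11]  # hypothetical cluster of SAE feature indices
recon, active = cluster_reconstruction(feature_acts, decoder, cluster)
coords = top_pca_components(recon, k=2)  # 2D points one could scatter-plot
```

Because only the cluster's features contribute, the reconstructed points lie in a subspace of dimension at most the cluster size, which is what makes a low-dimensional PCA view reasonable here.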

In practice, this required looking through thousands of panels of interactive PCA plots like this:

Here's a Dropbox link with all 500 of the gpt-2-small interactive plots like these that we looked at: https://www.dropbox.com/scl/fo/usyuem3x4o4l89zbtooqx/ALw2-ZWkRx_I9thXjduZdxE?rlkey=21xkkd6n8ez1n51sf0d773w9t&st=qpz5395r&dl=0 (note that I used n_clusters=1000 with spectral clustering but only made plots for the top 500, ranked by mean pairwise cosine similarity of SAE features within the cluster).
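The ranking criterion mentioned in the parenthetical (mean pairwise cosine similarity of SAE decoder vectors within a cluster) could be computed along these lines. This is a sketch under assumptions: the cluster labels are taken as given (for example from a spectral clustering run), the function names are made up for illustration, and scoring singleton clusters as 0 is an arbitrary convention.

```python
import numpy as np

def mean_pairwise_cosine(decoder, idx):
    """Mean cosine similarity over distinct pairs of decoder vectors in idx."""
    idx = np.asarray(idx)
    n = len(idx)
    if n < 2:
        return 0.0  # convention for singleton clusters (an assumption)
    vecs = decoder[idx]
    unit = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)
    sims = unit @ unit.T
    # Exclude the diagonal (each vector's similarity with itself) from the mean.
    return float((sims.sum() - np.trace(sims)) / (n * (n - 1)))

def rank_clusters(decoder, labels, top_k):
    """Return cluster ids sorted by decreasing mean pairwise cosine similarity."""
    scores = {
        int(c): mean_pairwise_cosine(decoder, np.flatnonzero(labels == c))
        for c in np.unique(labels)
    }
    return sorted(scores, key=scores.get, reverse=True)[:top_k]

# Toy decoder with three clusters: tightly aligned, orthogonal, and opposed.
decoder = np.array([
    [1.0, 0.0], [0.9, 0.1],   # cluster 0: nearly parallel
    [0.0, 1.0], [1.0, 0.0],   # cluster 1: orthogonal
    [-1.0, 0.0], [1.0, 0.0],  # cluster 2: antiparallel
])
labels = np.array([0, 0, 1, 1, 2, 2])
order = rank_clusters(decoder, labels, top_k=2)
```

With the real data, `top_k` would play the role of the 500 plotted clusters out of the 1000 produced by clustering.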

Here are the clusters that I thought might have interesting structure:

  • cluster67: numbers, PCA dim 2 is their value
  • cluster109: money amounts, PCA dim 1 might be related to cents and PCA dim 2 might be related to the dollar amount
  • cluster134: different number-related tokens like "000", "million" vs. "billion", etc.
  • cluster138: days of the week circle!!!
  • cluster157: years
  • cluster71: possible "left" vs. "right" direction
  • cluster180: "long" vs. "short"
  • cluster212: years, possible circular representation of year within century in PCA dims 2-3
  • cluster213: the "-" in between a range of numbers, ordered by the first number
  • cluster223: "up" vs. "down" direction
  • cluster251: months of the year!!
  • cluster285: PCA dim 1 is Republican vs. Democrat

You can hover over a point on each scatter plot to see some context and the token (in bold) at whose position the activation vector (residual stream in layer 7) was taken.

Most clusters, however, don't seem obviously interesting. We also looked at ~2000 Mistral-7B clusters, and only the days of the week and months of the year clusters seemed clearly interesting. So at least for the LLMs we looked at, for the SAEs we had, and with our method for discovering interesting geometry, interesting geometry didn't seem ubiquitous. That said, it could just be that our methods are limited, or the LLMs and/or SAEs we used weren't large enough, or that there is interesting geometry but it's not obvious to us from PCA plots like the above.

Still, I think you're right that the basic picture of features as independent near-orthogonal directions from Toy Models of Superposition is wrong, as discussed by Anthropic in their Towards Monosemanticity post, and efforts to understand this could be super important. As mentioned by Joseph Bloom in his comment, understanding this better could inspire different SAE architectures which get around scaling issues we may already be running into.

Huh those batch size and learning rate experiments are pretty interesting!

I checked whether this token character length direction is important to the "newline prediction to maintain text width in line-limited text" behavior of pythia-70m. To review, one of the things that pythia-70m seems to be able to do is to predict newlines in places where a newline correctly breaks the text so that the line length remains approximately constant. Here's an example of some text which I've manually broken periodically so that the lines have roughly the same width. The color of the token corresponds to the probability pythia-70m gave to predicting a newline as that token. Darker blue corresponds to a higher probability. I used CircuitsVis for this:

We can see that at the last couple of tokens in most lines, the model starts placing nontrivial probability on a newline occurring there.

I thought that this "number of characters per token" direction would be part of whatever circuit implements this behavior. However, ablating that direction in embedding space seems to have little to no effect on the behavior. Going the other direction, manually adding this direction to the embeddings doesn't seem to significantly affect the behavior either!



Maybe there are multiple directions representing the length of a token? Here's the colab to reproduce: https://colab.research.google.com/drive/1HNB3NHO7FAPp8sHewnum5HM-aHKfGTP2?usp=sharing
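Separately from that colab, the two interventions described in this comment (ablating a direction from the embeddings, and adding it) can be sketched as below. This is a generic illustration, not the notebook's code: the embedding matrix `E`, the direction `v`, and the `scale` value are random or arbitrary placeholders.

```python
import numpy as np

def ablate_direction(embeddings, direction):
    """Remove the component of each embedding along `direction`."""
    d = direction / np.linalg.norm(direction)
    # Subtract each row's projection onto the unit vector d.
    return embeddings - np.outer(embeddings @ d, d)

def add_direction(embeddings, direction, scale):
    """Shift every embedding by `scale` units along the unit-normalized direction."""
    d = direction / np.linalg.norm(direction)
    return embeddings + scale * d

# Placeholder embedding matrix (vocab x d_model) and candidate direction.
rng = np.random.default_rng(0)
E = rng.normal(size=(50, 8))
v = rng.normal(size=8)

E_abl = ablate_direction(E, v)          # embeddings with no component along v
E_add = add_direction(E, v, scale=3.0)  # embeddings pushed 3 units along v
```

Note that projecting out a single direction leaves any other directions correlated with token length untouched, which is consistent with the "multiple directions" possibility raised above.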

 

Small point/question, Quintin -- when you say that you "can fully avoid grokking on modular arithmetic", in the colab notebook you linked to in that paragraph it looks like you just trained for 3e4 steps. Without explicit regularization, I wouldn't have expected your network to generalize in that time (it might take 1e6 or 1e7 steps for networks to fully generalize). What point were you trying to make there? By "avoid grokking", do you mean (1) avoid generalization or (2) eliminate the time delay between memorization and generalization? I'd be pretty interested if you achieved (2) while not using explicit regularization.

One of the authors of the paper here. Glad you found it interesting! In case people want to mess around with some of our results themselves, here are colab notebooks for reproducing a couple results:

  1. Delaying generalization (inducing grokking) on MNIST: https://colab.research.google.com/drive/1wLkyHadyWiZSwaR0skJ7NypiYKCiM7CR?usp=sharing
  2. Almost eliminating grokking (bringing train and test curves together) in transformers trained on modular addition: https://colab.research.google.com/drive/1NsoM0gao97jqt0gN64KCsomsPoqNlAi4?usp=sharing

Some miscellaneous comments:

  • On some level "just fix your weight norm and the model generalizes" sounds too simple to be true for all tasks -- I agree. I'd be pretty surprised if our result on speeding up generalization on modular arithmetic by constraining weight norm had much relevance to training large language models, for instance. But I haven't thought much about this yet!
  • In terms of relevance to AI safety, I view this work broadly as contributing to a scientific understanding of emergence in ML, cf. "More is Different for AI". It seems useful for us to understand mechanistically how/why surprising capabilities are gained with increasing model scale or training time (as is the case for grokking), so that we can better reason about and anticipate the potential capabilities and risks of future AI systems. Another AI safety angle could lie in trying to unify our observations with Nanda and Lieberum's circuits-based perspective on grokking. My understanding of that work is that networks learn both memorizing and generalizing circuits, and that generalization corresponds to the network eventually "cleaning up" the memorizing circuit, leaving the generalizing circuits. By constraining weight norm, are we just preventing the memorizing circuits from forming? If so, can we learn something about circuits, or auto-discover them, by looking at properties of the loss landscape? In our setup, does switching to polar coordinates factor the parameter space into things which generalize and things which memorize, with the radial direction corresponding to memorization and the angular directions corresponding to generalization? Maybe there are general lessons here.
  • Razied's comment makes a good point about weight L2 norm being a bizarre metric for generalization, since you can take a ReLU network which generalizes and arbitrarily increase its weight norm by multiplying a neuron's in-weights by α and its out-weights by 1/α without changing the function implemented by the network. The relationship between weight norm and generalization is an imperfect one. What we find empirically is simply this: when we initialize networks in a standard way, multiply all the parameters by α, and then constrain optimization to lie on that constant-norm sphere in parameter space, there is often an α-dependent gap in test and train performance for the solutions that optimizers find. For large α, optimization finds a solution on the sphere which fits the training data but doesn't generalize. For α in the right range, optimization finds a solution on the sphere which does generalize. So maybe the right statement about generalization and weight norm is more about the density of generalizing vs not generalizing solutions in different regions of parameter space, rather than their existence. I'll also point out that this gap between train and test performance as a function of α is often only present when we reduce the size of the training dataset. I don't yet understand mechanistically why this last part is true.
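The rescaling argument in the last bullet is easy to check numerically. The sketch below uses a toy two-layer ReLU network without biases and a positive scale factor called `alpha` in the code: scaling the in-weights by `alpha` and the out-weights by `1/alpha` leaves the outputs unchanged while the total L2 weight norm grows. The `project_to_sphere` helper shows one simple way a constant-norm constraint could be enforced (rescaling all parameters jointly, for example after each optimizer step); that is an assumption for illustration, not necessarily how the paper implements it.

```python
import numpy as np

def mlp(x, w_in, w_out):
    """Two-layer ReLU network without biases."""
    return np.maximum(x @ w_in, 0.0) @ w_out

def weight_norm(*params):
    """L2 norm of all parameters taken together."""
    return float(np.sqrt(sum((p ** 2).sum() for p in params)))

def project_to_sphere(params, radius):
    """Rescale all parameters jointly so their total L2 norm equals `radius`."""
    scale = radius / weight_norm(*params)
    return [p * scale for p in params]

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 4))
w_in = rng.normal(size=(4, 6))
w_out = rng.normal(size=(6, 3))

# Rescaling that preserves the function: ReLU(alpha * z) = alpha * ReLU(z) for alpha > 0.
alpha = 10.0
w_in_scaled, w_out_scaled = w_in * alpha, w_out / alpha
same_function = np.allclose(mlp(x, w_in, w_out), mlp(x, w_in_scaled, w_out_scaled))
norm_before = weight_norm(w_in, w_out)
norm_after = weight_norm(w_in_scaled, w_out_scaled)

# Constant-norm constraint: put the parameters on the sphere of radius alpha * (initial norm).
projected = project_to_sphere([w_in, w_out], radius=alpha * norm_before)
```

Here `same_function` is true even though `norm_after` is much larger than `norm_before`, which is why the existence of generalizing solutions at large norm says little by itself; the empirical claim above is about which solutions optimizers find on a given sphere.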