StefanHex

Stefan Heimersheim. Research Scientist at Apollo Research, Mechanistic Interpretability. The opinions expressed here are my own and do not necessarily reflect the views of my employer.

List of some larger mech interp project ideas (see also: short and medium-sized ideas). Feel encouraged to leave thoughts in the replies below!

Edit: My mentoring doc has more-detailed write-ups of some projects. Let me know if you're interested!

What is going on with activation plateaus: Transformer activation space seems to be made up of discrete regions, each corresponding to a certain output distribution. Most activations within a region lead to the same output, and the output changes sharply when you move from one region to another. The boundaries seem to correspond to bunched-up ReLU boundaries as predicted by grokking work. This feels confusing. Are LLMs just classifiers with finitely many output states? How does this square with the linear representation hypothesis, the success of activation steering, logit lens etc.? It doesn't seem in obvious conflict, but it feels like we're missing the theory that explains everything. Concrete project ideas:

  1. Can we in fact find these discrete output states? Of course we expect there to be a huge number, but maybe if we restrict the data distribution very much (a limited kind of sentence like "person being described by an adjective") we are in a regime with <1000 discrete output states. Then we could use clustering (K-means and such) on the model output, and see if the cluster assignments we find map to activation plateaus in model activations. We could also use a tiny model with hopefully fewer regions, but Jett found regions to be crisper in larger models.
  2. How do regions/boundaries evolve through layers? Is it more like additional layers split regions in half, or like additional layers sharpen regions?
  3. What's the connection to the grokking literature (such as the work mentioned above)?
  4. Can we connect this to our notion of features in activation space? To some extent "features" are defined by how the model acts on them, so these activation regions should be connected.
  5. Investigate what steering / linear representations look like through the activation plateau lens. On the one hand we expect adding a steering vector to smoothly change model output; on the other hand, the steering we did here to find activation plateaus looks very non-smooth.
  6. If in fact it doesn't matter to the model where in an activation plateau an activation lies, would end-to-end SAEs map all activations from a plateau to a single point? (Anecdotally we observed activations to mostly cluster in the centre of activation plateaus so I'm a bit worried other activations will just be out of distribution.) (But then we can generate points within a plateau by just running similar prompts through a model.)
  7. We haven't managed to make synthetic activations that match the activation plateaus observed around real activations. Can we think of other ways to try? (Maybe also let's make this an interpretability challenge?)
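A minimal sketch of how one might probe for a plateau, under loud assumptions: `toy_model` is a hypothetical stand-in for "the rest of the network" (a linear readout with a very sharp softmax, so plateaus exist by construction), and `plateau_profile` is an illustrative name, not code from the posts above. It interpolates from one activation towards another and records how far the output distribution has moved.

```python
import numpy as np

def softmax(logits):
    shifted = logits - logits.max()
    exp = np.exp(shifted)
    return exp / exp.sum()

def kl_divergence(p, q, eps=1e-12):
    # eps avoids log(0) for probabilities that underflow to zero
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps))))

def plateau_profile(model, act_a, act_b, n_steps=21):
    """KL between the output at act_a and the output along the line towards act_b."""
    base = model(act_a)
    alphas = np.linspace(0.0, 1.0, n_steps)
    kls = np.array(
        [kl_divergence(base, model((1 - a) * act_a + a * act_b)) for a in alphas]
    )
    return alphas, kls

# Hypothetical stand-in for the downstream model (an assumption, not a real LLM):
# a random linear readout followed by a very sharp softmax.
rng = np.random.default_rng(0)
readout = rng.normal(size=(5, 16))

def toy_model(act):
    return softmax(50.0 * (readout @ act))

act_a, act_b = rng.normal(size=16), rng.normal(size=16)
alphas, kls = plateau_profile(toy_model, act_a, act_b)
for alpha, value in zip(alphas, kls):
    print(f"alpha={alpha:.2f}  KL={value:.3f}")
```

On a real model, `model` would run the remaining layers from the patched activation. A plateau would show up as a flat stretch of near-zero KL followed by a sharp rise, rather than a smooth curve.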

Use sensitive directions to find features: Can we use the sensitivity of directions as a way to find the "true features", some canonical basis of features? In a recent post we found current SAE features to look less special than expected, so I'm a bit cautious about this. But especially after working on some toy models about computation in superposition I'd be keen to explore the error correction predictions made here (paper, comment).

Test if we can fully sparsify a small model: Try the full pipeline of training SAEs everywhere, or training Transcoders & Attention SAEs, and doing all that such that connections between features are sparse (such that every feature only interacts with a few other features). The reason we want that is so that we can have simple computational graphs, and find simple circuits that explain model behaviour.

I expect that, absent SAE improvements that find the "true feature" basis, you'll need to train them all together with a penalty for the sparsity of interactions. To be concrete, an inefficient thing you could do is the following: Train SAEs on every residual stream layer, with a loss term that L1-penalises interactions between adjacent SAE features. This is hard/inefficient because the matrix of SAE interactions is huge, plus you probably need attributions to get these interactions, which are expensive to compute (at every training step!). I think the main question for this project is to figure out whether there is a way to do this efficiently. Talk to Logan Smith and Callum McDougall; I expect there are a couple more people who are trying something like this.
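To make the penalty concrete, here is a rough sketch under a strong simplifying assumption: the computation between two adjacent SAEs is treated as a single linear map (`layer_map`, hypothetical), so the feature-to-feature interactions are just a matrix product. In a real transformer you would need attributions or Jacobians instead. All names and shapes are illustrative.

```python
import numpy as np

def interaction_matrix(dec_1, layer_map, enc_2):
    """Linearised effect of each layer-1 SAE feature on each layer-2 SAE pre-activation.

    dec_1:     (d_model, n_feat_1) decoder of the first SAE
    layer_map: (d_model, d_model)  assumed-linear stand-in for the layer in between
    enc_2:     (n_feat_2, d_model) encoder of the second SAE
    """
    return enc_2 @ layer_map @ dec_1  # shape (n_feat_2, n_feat_1)

def interaction_l1(dec_1, layer_map, enc_2):
    """L1 norm of the interaction matrix, to be added to the SAE training loss."""
    return float(np.abs(interaction_matrix(dec_1, layer_map, enc_2)).sum())

rng = np.random.default_rng(0)
d_model, n_feat_1, n_feat_2 = 8, 32, 32
dec_1 = rng.normal(size=(d_model, n_feat_1))
layer_map = rng.normal(size=(d_model, d_model))
enc_2 = rng.normal(size=(n_feat_2, d_model))

interactions = interaction_matrix(dec_1, layer_map, enc_2)
penalty = interaction_l1(dec_1, layer_map, enc_2)

# Sanity check: identity dictionaries and an identity layer give one
# interaction per feature, so the penalty equals the number of features.
identity_penalty = interaction_l1(np.eye(4), np.eye(4), np.eye(4))
print(interactions.shape, penalty, identity_penalty)
```

Even in this toy version the interaction matrix has `n_feat_1 * n_feat_2` entries, which is the scaling problem described above.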

List of some medium-sized mech interp project ideas (see also: shorter and longer ideas). Feel encouraged to leave thoughts in the replies below!

Edit: My mentoring doc has more-detailed write-ups of some projects. Let me know if you're interested!

Toy model of Computation in Superposition: The toy model of computation in superposition (CIS; Circuits-in-Sup, Comp-in-Sup post / paper) describes a way in which NNs could perform computation in superposition, rather than just storing information in superposition (TMS). It would be good to have some actually trained models that do this, in order (1) to check whether NNs learn this algorithm or a different one, and (2) to test whether decomposition methods handle this well.

This could be, in the simplest form, just some kind of non-trivial memorisation model, or AND-gate model. Just make sure that the task does in fact require computation, and cannot be solved without the computation. A flashier version could be a network trained to do MNIST and FashionMNIST at the same time, though this would be more useful for goal (2).

Transcoder clustering: Transcoders are a sparse dictionary learning method that e.g. replaces an MLP with an SAE-like sparse computation (basically an SAE that maps activations not to themselves but to the next layer). If the above model of computation / circuits in superposition is correct (every computation using multiple ReLUs for redundancy) then the transcoder latents belonging to one computation should co-activate. Thus it should be possible to use clustering of transcoder activation patterns to find meaningful model components (circuits in the circuits-in-superposition model). (Idea suggested by @Lucius Bushnaq, mistakes are mine!) There are two ways to do this project:

  1. Train a toy model of circuits in superposition (see project above), train a transcoder, cluster latent activations, and see if we can recover the individual circuits.
  2. Or just try to cluster latent activations in an LLM transcoder, either existing (e.g. TinyModel) or trained on an LLM, and see if the clusters make any sense.
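A minimal sketch of the clustering step, on synthetic data rather than a real transcoder. The assumptions: latents of one circuit switch on and off together, circuits fire independently, and "clustering" here means connected components of a thresholded co-activation correlation matrix (one simple choice among many). `coactivation_clusters` is an illustrative name.

```python
import numpy as np

def coactivation_clusters(latents, threshold=0.5):
    """Group latents whose on/off patterns correlate above `threshold`.

    latents: (n_samples, n_latents) non-negative latent activations.
    Returns an integer cluster label per latent (connected components).
    """
    active = (latents > 0).astype(float)
    corr = np.corrcoef(active.T)
    n_latents = corr.shape[0]
    labels = -np.ones(n_latents, dtype=int)
    current = 0
    for start in range(n_latents):
        if labels[start] >= 0:
            continue
        labels[start] = current
        stack = [start]
        while stack:  # depth-first search over strongly co-active latents
            i = stack.pop()
            for j in np.flatnonzero(corr[i] > threshold):
                if labels[j] < 0:
                    labels[j] = current
                    stack.append(j)
        current += 1
    return labels

# Synthetic stand-in for transcoder latents: 3 circuits with 4 latents each.
rng = np.random.default_rng(0)
n_samples, n_circuits, per_circuit = 2000, 3, 4
circuit_on = rng.random((n_samples, n_circuits)) < 0.1
magnitudes = rng.uniform(0.5, 1.5, size=(n_samples, n_circuits * per_circuit))
latents = np.repeat(circuit_on, per_circuit, axis=1) * magnitudes

labels = coactivation_clusters(latents)
print(labels)
```

Real transcoder latents will co-activate far less cleanly, so the threshold (or a softer method such as spectral clustering) would need tuning.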

Investigating / removing LayerNorm (LN): For GPT2-small I showed that you can remove LN layers gradually while fine-tuning without losing much model performance (workshop paper, code, model). There are three directions in which I want to follow up on this project.

  1. Can we use this to find out which tasks the model did use LN for? Are there prompts for which the noLN model is systematically worse than a model with LN? If so, can we understand how the LN acts mechanistically?
  2. The second direction for this project is to check whether this result is real and whether it scales. In particular I'm uncertain about the cost: given that training GPT2-small is possible in a few (10?) GPU-hours, does my method actually require compute on the order of training compute? Or can it be much more efficient (I have barely tried to make it efficient so far)? This project could demonstrate that the LayerNorm-removal process is tractable on a larger model (~Gemma-2-2B?), or that it can be done much faster on GPT2-small, something on the order of O(10) GPU-minutes.
  3. Finally, how much did the model weights change? Do SAEs still work? If it changed a lot, are there ways we can avoid this change (e.g. do the same process but add a loss to keep the SAEs working)?
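For context on what removing LN changes mathematically, here is a small sketch. It assumes one plausible replacement, dividing by a fixed constant instead of the per-input standard deviation. This is an illustration and is not claimed to be the exact procedure from the workshop paper. The point is that standard LN is invariant to rescaling its input (so it is non-linear), while the fixed-scale version is a plain linear map.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Standard LayerNorm without learned scale/bias: centre, divide by per-input std."""
    centered = x - x.mean(axis=-1, keepdims=True)
    return centered / np.sqrt(centered.var(axis=-1, keepdims=True) + eps)

def fixed_scale_norm(x, fixed_std):
    """Assumed replacement: centre, then divide by a constant. This is linear in x."""
    centered = x - x.mean(axis=-1, keepdims=True)
    return centered / fixed_std

rng = np.random.default_rng(0)
acts = rng.normal(size=(4, 16))   # hypothetical batch of residual-stream activations
fixed_std = acts.std(axis=-1).mean()  # constant estimated once from data

# LN ignores the scale of its input ...
ln_scale_invariant = np.allclose(layer_norm(2 * acts), layer_norm(acts), atol=1e-4)
# ... whereas the fixed-scale version simply scales with it.
fixed_is_linear = np.allclose(
    fixed_scale_norm(2 * acts, fixed_std), 2 * fixed_scale_norm(acts, fixed_std)
)
print(ln_scale_invariant, fixed_is_linear)
```

Prompts on which the noLN model is systematically worse are then candidates for cases where the original model relied on this scale invariance.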

List of some short mech interp project ideas (see also: medium-sized and longer ideas). Feel encouraged to leave thoughts in the replies below!

Edit: My mentoring doc has more-detailed write-ups of some projects. Let me know if you're interested!

Directly testing the linear representation hypothesis: make up a couple of prompts which contain a few concepts to various degrees, and test

  • Does the model indeed represent intensity as magnitude? Or are there separate features for differently intense versions of a concept? Finding the right prompts is tricky, e.g. it makes sense that friendship and love are different features, but maybe "my favourite coffee shop" vs "a coffee shop I like" are different intensities of the same concept.
  • Do unions of concepts indeed represent addition in vector space? I.e. is the representation of "A and B" vector_A + vector_B? I wonder if there's a way you can generate a big synthetic dataset here, e.g. variations of "the soft green sofa" -> "the [texture] [colour] [furniture]", and do some statistical check.

Mostly I expect this to come out positive, and not to be a big update, but seems cheap to check.
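A sketch of what the statistical check for additivity could look like. It builds the "the [texture] [colour] [furniture]" prompt grid, then asks how much activation variance an additive model (one vector per texture, colour and furniture item) explains. The activations here are synthetic placeholders; in a real experiment `acts` would hold the model's activations for each prompt. The word lists and function names are made up for illustration.

```python
import itertools
import numpy as np

textures = ["soft", "rough", "smooth"]
colours = ["green", "red", "blue", "grey"]
furniture = ["sofa", "chair", "table"]
sizes = (len(textures), len(colours), len(furniture))

combos = list(itertools.product(*(range(s) for s in sizes)))
prompts = [f"the {textures[t]} {colours[c]} {furniture[f]}" for t, c, f in combos]

def one_hot_design(combos, sizes):
    """One indicator column per attribute value (texture, colour, furniture)."""
    design = np.zeros((len(combos), sum(sizes)))
    offsets = np.cumsum([0] + list(sizes[:-1]))
    for row, combo in enumerate(combos):
        for offset, idx in zip(offsets, combo):
            design[row, offset + idx] = 1.0
    return design

def additive_r2(acts, design):
    """Fraction of activation variance explained by a sum of per-attribute vectors."""
    coef, *_ = np.linalg.lstsq(design, acts, rcond=None)
    resid = acts - design @ coef
    total = acts - acts.mean(axis=0)
    return 1.0 - (resid ** 2).sum() / (total ** 2).sum()

rng = np.random.default_rng(0)
design = one_hot_design(combos, sizes)
d_model = 16

# Placeholder 1: activations that really are vector_texture + vector_colour + vector_furniture.
concept_vectors = rng.normal(size=(sum(sizes), d_model))
acts_additive = design @ concept_vectors + 0.05 * rng.normal(size=(len(combos), d_model))
# Placeholder 2: an unrelated random vector for every prompt.
acts_unstructured = rng.normal(size=(len(combos), d_model))

r2_additive = additive_r2(acts_additive, design)
r2_unstructured = additive_r2(acts_unstructured, design)
print(prompts[0], r2_additive, r2_unstructured)
```

The unstructured baseline matters: with few prompts an additive fit explains some variance by chance, so the real number should be compared against it.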

SAEs vs Clustering: How much better are SAEs than (other) clustering algorithms? Previously I worried that SAEs are "just" finding the data structure, rather than features of the model. I think we could try to rule out some "dataset clustering" hypotheses by testing how much structure there is in the dataset of activations that one can explain with generic clustering methods. Will we get 50%, 90%, 99% variance explained?
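A minimal sketch of the "variance explained by generic clustering" measurement, using a plain k-means written in NumPy and synthetic activations (a Gaussian mixture, chosen only so the script runs standalone). On real data `acts` would be a matrix of model activations, and the numbers would be compared to an SAE's variance explained at a similar dictionary size.

```python
import numpy as np

def kmeans(points, k, n_iter=50, seed=0):
    """Plain Lloyd's algorithm; returns centroids and the final assignment."""
    rng = np.random.default_rng(seed)
    centroids = points[rng.choice(len(points), size=k, replace=False)].copy()
    for _ in range(n_iter):
        dists = ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        assign = dists.argmin(axis=1)
        for c in range(k):
            members = points[assign == c]
            if len(members) > 0:  # keep the old centroid if a cluster is empty
                centroids[c] = members.mean(axis=0)
    dists = ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return centroids, dists.argmin(axis=1)

def variance_explained(points, centroids, assign):
    """1 - (residual variance after replacing each point by its centroid) / (total variance)."""
    resid = points - centroids[assign]
    total = points - points.mean(axis=0)
    return 1.0 - (resid ** 2).sum() / (total ** 2).sum()

# Synthetic placeholder for a dataset of activations: 8 clusters in 16 dimensions.
rng = np.random.default_rng(0)
centres = 3.0 * rng.normal(size=(8, 16))
acts = np.repeat(centres, 100, axis=0) + rng.normal(size=(800, 16))

results = {}
for k in (2, 8, 32):
    centroids, assign = kmeans(acts, k)
    results[k] = variance_explained(acts, centroids, assign)
    print(f"k={k:3d}  variance explained={results[k]:.3f}")
```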

I think a second spin on this direction is to look at "interpretability" / "mono-semanticity" of such non-SAE clustering methods. Do clusters appear similarly interpretable? This would address the concern that many things look interpretable, and we shouldn't be surprised by SAE directions looking interpretable. (Related: Szegedy et al., 2013 look at random directions in an MNIST network and find them to look interpretable.)

Activation steering vs prompting: I've heard the view that "activation steering is just fancy prompting" which I don't endorse in its strong form (e.g. I expect it to be much harder for the model to ignore activation steering than to ignore prompt instructions). However, it would be nice to have a prompting-baseline for e.g. "Golden Gate Claude". What if I insert a "<system> Remember, you're obsessed with the Golden Gate bridge" after every chat message? I think this project would work even without the steering comparison actually.
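The prompting baseline is simple enough to sketch directly. The message format below (a list of dicts with `role` and `content`) is an assumption borrowed from common chat APIs, and `with_reminders` is a hypothetical helper name.

```python
def with_reminders(messages, reminder):
    """Return a copy of the conversation with a system reminder after every message."""
    out = []
    for message in messages:
        out.append(message)
        out.append({"role": "system", "content": reminder})
    return out

reminder = "Remember, you're obsessed with the Golden Gate bridge"
conversation = [
    {"role": "user", "content": "What should I cook tonight?"},
    {"role": "assistant", "content": "How about a simple pasta dish?"},
    {"role": "user", "content": "Any dessert ideas?"},
]

steered = with_reminders(conversation, reminder)
for message in steered:
    print(f"{message['role']:>9}: {message['content']}")
```

Whether a given chat API accepts system messages mid-conversation varies, so the reminder may need to be folded into the user turn instead.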

CLDR (Cross-layer distributed representation): I don't think Lee has written this up anywhere yet so I've removed this for now.

Also, just wanted to flag that the links on 'this picture' and 'motivation image' don't currently work.

Thanks for the flag! It's these two images, I realize now that they don't seem to have direct links

Images taken from AMFTC and Crosscoders by Anthropic.

Thanks for the comment!

I think this is what most mech interp researchers more or less think. Though I definitely expect many researchers would disagree with individual points, and it doesn't fairly weigh all views and aspects (it's very biased towards "people I talk to"). (Also this is in no way an Apollo / Apollo interp team statement, just my personal view.)

Thanks! You're right, totally mixed up local and dense / distributed. Decided to just leave out that terminology

Why I'm not too worried about architecture-dependent mech interp methods:

I've heard people argue that we should develop mechanistic interpretability methods that can be applied to any architecture. While this is certainly a nice-to-have, and maybe a sign that a method is principled, I don't think this criterion itself is important.

I think that the biggest hurdle for interpretability is to understand any AI that produces advanced language (>=GPT2 level). We don't know how to write a non-ML program that speaks English, let alone one that reasons, and we have no idea how GPT2 does it. I expect that doing this the first time is going to be significantly harder than doing it the 2nd time. Kind of like how "understand an alien mind" is much harder than "understand the 2nd alien mind".

Edit: Understanding an image model (say Inception V1 CNN) does feel like a significant step down, in the sense that these models feel significantly less "smart" and capable than LLMs.

Why I'm not that hopeful about mech interp on TinyStories models:

Some of the TinyStories models are open source, and manage to output sensible language while being tiny (say 64dim embedding, 8 layers). Maybe it'd be great to try and thoroughly understand one of those?

I am worried that those models simply implement a bunch of bigrams and trigrams, and that all their performance can be explained by boring statistics & heuristics. Thus we would not learn much from fully understanding such a model. Evidence for this is that the 1-layer variant, which due to its size can only implement bigrams & trigram-ish things, achieves a better loss than many of the tall smaller models (Figure 4). Thus it seems not implausible that most if not all of the performance of all the models could be explained by similarly simple mechanisms.

Folk wisdom is that the TinyStories dataset is just very formulaic and simple, and therefore models without any sophisticated methods can appear to produce sensible language. I haven't looked into this enough to understand whether e.g. TinyStories V2 (used by TinyModel) is sufficiently good to dispel this worry.
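One cheap way to get a feel for the "boring statistics" worry is to compute what loss a literal bigram table achieves. The sketch below uses whitespace tokens and a two-sentence toy corpus, both assumptions for illustration only. A real check would use the model's tokenizer and held-out TinyStories text, and would compare against the model's loss.

```python
import math
from collections import Counter, defaultdict

def train_bigram(tokens):
    """Count next-token frequencies for every previous token."""
    counts = defaultdict(Counter)
    for prev, nxt in zip(tokens, tokens[1:]):
        counts[prev][nxt] += 1
    return counts

def bigram_loss(counts, tokens, vocab_size, alpha=1.0):
    """Mean next-token cross-entropy in nats, with add-alpha smoothing."""
    total, n_predictions = 0.0, 0
    for prev, nxt in zip(tokens, tokens[1:]):
        row = counts.get(prev, Counter())
        prob = (row[nxt] + alpha) / (sum(row.values()) + alpha * vocab_size)
        total -= math.log(prob)
        n_predictions += 1
    return total / n_predictions

# Toy corpus (made up); evaluated on its own training text for brevity.
text = "once upon a time there was a cat . once upon a time there was a dog ."
tokens = text.split()
vocab_size = len(set(tokens))

counts = train_bigram(tokens)
loss = bigram_loss(counts, tokens, vocab_size)
uniform_loss = math.log(vocab_size)
print(f"bigram loss={loss:.3f}  uniform baseline={uniform_loss:.3f}")
```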

Collection of some mech interp knowledge about transformers:

Writing up folk wisdom & recent results, mostly for mentees and as a link to send to people. Aimed at people who are already a bit familiar with mech interp. I've just quickly written down what came to my head, and may have missed or misrepresented some things. In particular, the last point is very brief and deserves a much more expanded comment at some point. The opinions expressed here are my own and do not necessarily reflect the views of Apollo Research.

Transformers take in a sequence of tokens, and return logprob predictions for the next token. We think it works like this:

  1. Activations represent a sum of feature directions, each direction representing some semantic concept. The magnitude of directions corresponds to the strength or importance of the concept.
    1. These features may be 1-dimensional, but maybe multi-dimensional features make sense too. We can either allow for multi-dimensional features (e.g. circle of days of the week), acknowledge that the relative directions of feature embeddings matter (e.g. treating days of the week as individual features whose directions span a circle), or both. See also Jake Mendel's post.
    2. The concepts may be "linearly" encoded, in the sense that two concepts A and B being present (say with strengths α and β) are represented as α*vector_A + β*vector_B. This is the key assumption of the linear representation hypothesis. See Chris Olah & Adam Jermyn but also Lewis Smith.
  2. The residual stream of a transformer stores information the model needs later. Attention and MLP layers read from and write to this residual stream. Think of it as a kind of "shared memory", with this picture in your head, from Anthropic's famous AMFTC.
    1. This residual stream seems to slowly accumulate information throughout the forward pass, as suggested by LogitLens.
    2. Additionally, we expect there to be internally-relevant information inside the residual stream, such as whether the sequence of nouns in a sentence is ABBA or BABA.
    3. Maybe think of each transformer block / layer as doing a serial step of computation. Though note that layers don't need to be privileged points between computational steps, a computation can be spread out over layers (see Anthropic's Crosscoder motivation)
  3. Superposition. There can be more features than dimensions in the vector space, corresponding to almost-orthogonal directions. Established in Anthropic's TMS. You can have a mix as well. See Chris Olah's post on distributed representations for a nice write-up.
    1. Superposition requires sparsity, i.e. that only few features are active at a time.
  4. The model starts with token (and positional) embeddings.
    1. We think token embeddings mostly store features that might be relevant about a given token (e.g. words in which it occurs and what concepts they represent). The meaning of a token depends a lot on context.
    2. We think positional embeddings are pretty simple (in GPT2-small, but likely also other models). In GPT2-small they appear to encode ~4 dimensions worth of positional information, consisting of "is this the first token", "how late in the sequence is it", plus two sinusoidal directions. The latter three create a helix.
      1. PS: If you try to train an SAE on the full embedding you'll find this helix split up into segments ("buckets") as individual features (e.g. here). Pay attention to this bucket-ing as a sign of compositional representation.
  5. The overall Transformer computation is said to start with detokenization: accumulating context and converting the pure token representation into a context-aware representation of the meaning of the text. Early layers in models often behave differently from the rest. Lad et al. claim three more distinct stages but that's not consensus.
  6. There are a couple of common motifs we see in LLM internals, such as
    1. LLMs implementing human-interpretable algorithms.
      1. Induction heads (paper, good illustration): attention heads being used to repeat sequences seen previously in context. This can reach from literally repeating text to maybe being generally responsible for in-context learning.
      2. Indirect object identification, docstring completion. Importantly don't take these early circuits works to mean "we actually found the circuit in the model" but rather take away "here is a way you could implement this algorithm in a transformer" and maybe the real implementation looks something like it.
        1. In general we don't think this manual analysis scales to big models (see e.g. Tom Lieberum's paper)
        2. Also we want to automate the process, e.g. ACDC and follow-ups (1, 2).
        3. My personal take is that all circuits analysis is currently not promising because circuits are not crisp. With this I mean the observation that a few distinct components don't seem to be sufficient to explain a behaviour, and you need to add more and more components, slowly explaining more and more performance. This clearly points towards us not using the right units to decompose the model. Thus, model decomposition is the major area of mech interp research right now.
    2. Moving information. Information is moved around in the residual stream, from one token position to another. This is what we see in typical residual stream patching experiments, e.g. here.
    3. Information storage. Early work (e.g. Mor Geva) suggests that MLPs can store information as key-value memories; generally folk wisdom is that MLPs store facts. However, those facts seem to be distributed and non-trivial to localise (see ROME & follow-ups, e.g. MEMIT). The DeepMind mech interp team tried and wasn't super happy with their results.
    4. Logical gates. We think models calculate new features from existing features by computing e.g. AND and OR gates. Here we show a bunch of features that look like that is happening, and the papers by Hoagy Cunningham & Sam Marks show computational graphs for some example features.
    5. Activation size & layer norm. GPT2-style transformers have a layer normalization layer before every Attn and MLP block. Also, the norm of activations grows throughout the forward pass. Combined, this means old features become less important over time; Alex Turner has thoughts on this.
      1. There are hypotheses on what layer norm could be responsible for, but it can't do anything substantial since you can run models without it (e.g. TinyModel, GPT2_noLN)
  7. (Sparse) circuits agenda. The current mainstream agenda in mech interp (see e.g. Chris Olah's recent talk) is to (1) find the right components to decompose model activations, to (2) understand the interactions between these features, and to finally (3) understand the full model.
    1. The first big open problem is how to do this decomposition correctly. There's plenty of evidence that the current Sparse Autoencoders (SAEs) don't give us the correct solution, as well as conceptual issues. I'll not go into the details here to keep this short-ish.
    2. The second big open problem is that the interactions, by default, don't seem sparse. This is expected if there are multiple ways (e.g. SAE sizes) to decompose a layer, and adjacent layers aren't decomposed correspondingly. In practice this means that one SAE feature seems to affect many many SAE features in the next layers, more than we can easily understand. Plus, those interactions seem to be not crisp which leads to the same issue as described above.
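As a small numerical illustration of points 1 and 3 above (activations as sums of feature directions, more features than dimensions, and the need for sparsity), here is a sketch with random unit vectors standing in for feature directions. The sizes and the random directions are assumptions; real models learn their directions. The sketch reads every feature back off the activation with a dot product and measures the interference from all the other features.

```python
import numpy as np

def readout_interference(strengths, directions):
    """Error when reading each feature back off the summed activation via a dot product."""
    activation = strengths @ directions   # sum_i strength_i * direction_i
    readout = directions @ activation     # projection onto every feature direction
    return readout - strengths            # overlap contributed by the other features

rng = np.random.default_rng(0)
d_model, n_features = 512, 2048           # four times more features than dimensions
directions = rng.normal(size=(n_features, d_model))
directions /= np.linalg.norm(directions, axis=1, keepdims=True)

# Sparse case: only 3 features are active at once.
sparse_strengths = np.zeros(n_features)
active = rng.choice(n_features, size=3, replace=False)
sparse_strengths[active] = rng.uniform(1.0, 2.0, size=3)
sparse_error = np.abs(readout_interference(sparse_strengths, directions)).mean()

# Dense case: every feature is active at once.
dense_strengths = rng.uniform(1.0, 2.0, size=n_features)
dense_error = np.abs(readout_interference(dense_strengths, directions)).mean()

print(f"mean |interference|  sparse: {sparse_error:.3f}  dense: {dense_error:.3f}")
```

With only a few active features the interference stays small compared to the active strengths, while with everything active it swamps the signal. That is the sense in which superposition requires sparsity.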

Thanks for the nice writeup! I'm confused about why you can get away without interpretation of what the model components are:

In cases where we worry that our model learned a human-simulator / camera-simulator rather than actually predicting whether the diamond exists, wouldn't circuit discovery simply give us the human-simulator circuit? (And thus causal scrubbing doesn't save us.) I'm thinking in particular of cases where the human-simulator is easier to learn than the intended solution.

Of course if you had good interpretability, a way to realise whether your explanation is the human simulator is to look for suspicious human-simulator-related features. I would like to get away without interpretation, but it's not clear to me that this works.
