Review

This is the fifth post in the Google DeepMind mechanistic interpretability team’s investigation into how language models recall facts. This post is a bit tangential to the main sequence, and documents some interesting observations about how, in general, early layers of models somewhat (but not fully) specialise into processing recent tokens. You don’t need to believe these results to believe our overall results about facts, but we hope they’re interesting! And likewise you don’t need to read the rest of the sequence to engage with this.

Introduction

In this sequence we’ve presented the multi-token embedding hypothesis, that a crucial mechanism behind factual recall is that on the final token of a multi-token entity there forms an “embedding”, with linear representations of attributes of that entity. We further noticed that this seemed to be most of what early layers did, and that they didn’t seem to respond much to prior context (e.g. adding “Mr Michael Jordan” didn’t substantially change the residual).

We hypothesised the stronger claim that early layers (e.g. the first 10-20%), in general, specialise in local processing, and that the prior context (e.g. more than 10 tokens back) is only brought in in early-mid layers. We note that this is stronger than the multi-token embedding hypothesis in two ways: it’s a statement about how early layers behave on all tokens, not just the final tokens of entities about which facts are known; and it’s a claim that early layers are not also doing longer range stuff in addition to producing the multi-token embedding (e.g. detecting the language of the text). We find this stronger hypothesis plausible, because tokens are a pretty messy input format, and analysing individual tokens in isolation can be highly misleading, e.g. when a long word is split into many fragment tokens, suggesting that longer range processing should be left until some pre-processing on the raw tokens has been done, the idea of detokenization.[1]

We tested this by taking a bunch of arbitrary prompts from the Pile, taking residual streams on those, truncating the prompts to the most recent few tokens and taking residual streams on the truncated prompts, and looking at the mean centred cosine sim at different layers.

Our findings:

  • Early layers do, in general, specialise in local processing, but it’s a soft division of labour not a hard split.
    • There’s a gradual transition where more context is brought in across the layers.
  • Early layers do significant processing on recent tokens, not just the current token - this is not just a trivial result where the residual stream is dominated by the current token and slightly adjusted by each layer
  • Early layers do much more long-range processing on common tokens (punctuation, articles, pronouns, etc)

Experiments

The “early layers specialise in local processing” hypothesis concretely predicts that, for a given token X in a long prompt, if we truncate the prompt to just the most recent few tokens before X, the residual stream at X should be very similar at early layers and dissimilar at later layers. We can test this empirically by looking at the cosine sim of the original vs truncated residual streams, as a function of layer and truncated context length. Taking cosine sims of residual streams naively can be misleading, as there’s often a significant shared mean across all tokens, so we first subtract the mean residual stream across all tokens, and then take the cosine sim.

Set-Up

  • Model: Pythia 2.8B, as in the rest of our investigation
  • Dataset: Strings from the Pile, the Pythia pre-training distribution.
  • Metric: To measure how similar the original and truncated residual streams are we subtract the mean residual stream and then take the cosine sim.
    • We compute a separate mean per layer, across all tokens in random prompts from the Pile
  • Truncated context: We vary the number of tokens in the truncated context between 1 and 10 (this includes the token itself, so context=1 is just the token)
    • We include a BOS token at the start of the truncated prompt. (So context=10 means 11 tokens total).
      • We do this because models often treat the first token weirdly, e.g. having 20x the typical residual stream norm, so it can be used as a resting position for attention heads that don’t want to look at anything (attention must add up to 1, so it can’t be “off”). We didn’t want this to interfere with our results, especially for the context=1 case
  • We measure this at every layer, at the final residual stream in each block (i.e. after the attention and MLP).

Results

Early Layers Softly Specialise in Local Processing

In the graph below, we show the mean centred cosine sims between the full context and truncated residuals for a truncated context of length 5:

We see that the cosine sims with a truncated context of length 5 are significantly higher in early layers. However, they aren’t actually at 1, so some information from the prior context is included, it’s a soft specialisation[2]. There’s a fairly gradual transition between layer 0 and 10, after which it somewhat plateaus. Interestingly, there’s an uptick in the final layer[3]. Even if we give a truncated context of length 10, it’s still normally not near 1.

One possible explanation for these results is that the residual stream is dominated by the current token, and that each layer is a small incremental update - of course truncation won’t do anything! This doesn’t involve any need for layers to specialise - later residuals will have had more incremental updates and so have higher difference. However, we see that this is false by contrasting the blue and red lines - truncating to the five most recent tokens has much higher cosine sim than truncating to just the current token (and a BOS token), even just after layer 0, suggesting early layers genuinely do specialise into nearby tokens.

Error Analysis: Which Tokens Have Unusually Low Cosine Sim?

In the previous section we only analysed the median of the mean centred cosine sim between the truncated context and full context residual. Summary statistics can be misleading, so it’s worthwhile to also look at the full distribution, where we see a long negative tail! What’s up with that?

When inspecting the outlier tokens, we noticed two important clusters: punctuation tokens, and common words. We sorted into a few categories and looked at the cosine sim for each of these:

  • Is_newline, is_full_stop, is_comma - whether it’s the relevant punctuation character

  • Is_common: whether it’s one of a hand-created list of common words[4], possible prepended with a space

  • Is_alpha: whether it’s both not a common word, and is made up of letters (possibly prepended with a space, any case allowed)

  • Is_other: the rest

Even after just layer 0 with context length of 10, we see that punctuation is substantially lower, common words and other are notably lower, while alpha is extremely high.

Our guess is that this is a result of a mix of several mechanisms:

  • Word fragments (in the is_alpha category) are more likely to be part of multi-token words and detokenization before much processing can be done, while many of the other categories have a clear meaning without needing to refer to recent prior tokens[5]. This means that long-range processing can start earlier

  • Early full stops or newlines are sometimes used as “resting positions” with very high norm, truncating the context may turn them from normal punctuation to resting positions

  • Pronouns may be used to track information about the relevant entity (their name, attributes about them, etc)

  • Commas have been observed to summarise sentiment of the current clause, which may be longer than 10 tokens, and longer range forms of summarisation seem likely.

  • More eclectic hypotheses:

    • On e.g. a full stop or newline, the model may want to count how many have come before, e.g. to do variable binding and identify the current sentence.

  1. But it would be pretty surprising if literally no long-range processing happens in early layers, e.g. we know GPT-2 Small has a duplicate token head in layer 0. ↩︎

  2. It’s a bit hard to intuitively reason about about cosine sims, our best intuition is to look at the squared cosine sim (fraction of norm explained). If there’s 100 independently varying pieces of information in the residual stream, and a cosine sim of 0.9, then the fraction of norm explained is 0.81, suggesting about 81 of those 100 pieces of information is shared. ↩︎

  3. Our guess is that this is because the residual stream on a token is both used to literally predict the next token, and to convey information to future tokens to predict their next token (e.g. the summarisation motif). Plausibly there’s many tokens where predicting the literal next token mostly requires the local context (e.g. n-grams) but there’s longer-range context that’s useful for predicting future tokens. We’d expect the longer-range stuff to happen in the middle, so by the end the model can clean up the long-range stuff and focus on just the n-grams. We’re surprised the uptick only happens in the final layer, rather than final few, as our intuition is that the last several layers are spent on just the next token prediction. ↩︎

  4. The list ["and", "of", "or", "in", "to", "that", "which", "with", "for", "the", "a", "an", "they", "on", "is", "their", "but", "are", "its", "i", "we", "it", "at"]. We made this by hand by repeatedly looking at tokens with unusually low cosine sim, and filtering for common words ↩︎

  5. This isn’t quite true, e.g. “.” means something very different at the end of a sentence vs “Mr.” vs “C.I.A” ↩︎

New Comment