Introduction

This post summarises some of the interesting facts I found while playing around with toy transformers on a synthetic task. Initially, I was trying to explore the sudden emergence of induction heads when you go from 1 to 2 layers. The full details of what I did can be found here.

Task

The task is to model a sequence of 24 letters where the sequence is made up of 4 repeated blocks of 6 letters. Induction heads should be able to predict all the letters after the first 6.

Example: ABCDEFABCDEFABCDEFABCDEF
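The post links the full code rather than inlining it, so here is a minimal sketch of how such sequences can be generated. The i.i.d. uniform sampling of the block letters is my assumption (it allows blocks with repeated letters, like AABBCC):

```python
import random
import string

BLOCK_LEN, N_REPEATS = 6, 4   # one 6-letter block repeated 4 times -> 24 letters

def make_sequence() -> str:
    """Sample one block uniformly (with replacement) and tile it."""
    block = "".join(random.choices(string.ascii_uppercase, k=BLOCK_LEN))
    return block * N_REPEATS

print(make_sequence())   # e.g. 'QKXWBNQKXWBNQKXWBNQKXWBN'
```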

I trained attention-only models with 1 or 2 layers and experimented with smeared keys*.

*The concept of smeared keys comes from Olsson et al. and is a way of forcing induction heads to form in 1-layer transformers.

Architecture 

I used a simplified version of the decoder-only transformer architecture.

Simplifications:

  • No layer normalization
  • No MLP blocks
  • No biases
  • Sinusoidal positional encodings are added to the Q and K matrices before the attention is computed (to avoid adding the positional encodings to the residual stream)
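The post doesn't include the implementation, so here is a minimal sketch of how I read this positional-encoding choice: the sinusoidal encodings enter only through the Q and K side of each head, never through V or the residual stream. All names and shapes below are my own, and adding the encodings to the inputs of the Q/K projections (rather than to the projected matrices) is an assumption:

```python
import math
import torch

def sinusoidal_pe(seq_len: int, d_model: int) -> torch.Tensor:
    """Standard sinusoidal positional encodings, shape (seq_len, d_model)."""
    pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
    div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                    * (-math.log(10000.0) / d_model))
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe

def attention_head(x, W_Q, W_K, W_V, W_O):
    """One causal attention head with no biases and no LayerNorm.
    Positional information enters only via Q and K, so the residual stream stays position-free."""
    seq_len, d_model = x.shape
    pe = sinusoidal_pe(seq_len, d_model)
    q = (x + pe) @ W_Q          # positions added before the Q projection
    k = (x + pe) @ W_K          # ...and before the K projection
    v = x @ W_V                 # values carry token information only
    scores = (q @ k.T) / math.sqrt(W_Q.shape[1])
    causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal_mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v @ W_O
```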

Hyperparameters:

  • 1 or 2 layer(s)
  • 4 heads
  • Inner dimension: 32
  • Learning rate: 0.001
  • Weight decay: 1.0
  • Steps: 10000
  • Batch size: 256
  • 5-10k parameters

The models were trained on randomly generated data drawn uniformly from the set of all valid sequences.

Observations

1) Toy transformers trained on this task have specialised circuits which don't exist in standard transformers, even when the circuits in more general models would be sufficient to solve the problem. 

  • 1A) The direct path prevents a letter from being repeated twice in a row. This makes sense because the probability of the next letter being X, given that the current letter is X, is only 1/26 (a quick empirical check of this appears just before the Evidence section below). This is completely different from the bigram statistics in more general models. 
  • 1B) "Anti-induction heads" seem to emerge which find positions which can't be the next token and prevent them from being repeated. 

2) By reverse-engineering the models, you can find an example of in-distribution failure and predict how they will behave out of distribution. 

  • 2A) Sequences made up of repeated letters (e.g. AABBCC...) have a much higher loss than average in all 3 models. I found this failure mode by analysing the weights. 
  • 2B) Out-of-distribution strings whose first 12 letters consist of two different 6-letter blocks (e.g. XYZRSTABCDEF) are treated differently by the 1-layer smeared key model and the 2-layer model, even though the two models perform nearly identically in distribution.

3) You can describe the algorithms used by the models in words. 

Note: When I refer to "standard" transformers I am referring to ones trained on a large corpus of text for a single epoch.
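As a quick sanity check of the 1/26 claim in 1A: under the data distribution assumed in the generation sketch above (block letters drawn i.i.d. uniformly), any two adjacent positions in a sequence correspond to two different positions of the underlying block, so the chance that a letter immediately repeats is 1/26 ≈ 0.038. A small empirical check:

```python
import random
import string

def make_sequence() -> str:
    """Same generation sketch as above: one random 6-letter block repeated 4 times."""
    block = "".join(random.choices(string.ascii_uppercase, k=6))
    return block * 4

# Fraction of adjacent positions where the letter repeats immediately.
pairs = repeats = 0
for _ in range(20_000):
    s = make_sequence()
    for a, b in zip(s, s[1:]):
        pairs += 1
        repeats += a == b
print(repeats / pairs, 1 / 26)  # both come out around 0.038
```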

Evidence:

The evidence comes from looking at attention patterns and eigenvalue analysis based on the work of Elhage et al. To detect copying behaviour, you can look at the sign of the eigenvalues of matrices that map vectors to the same space. A nice summary statistic to capture the sign of the eigenvalues is the sum of the eigenvalues divided by the sum of their absolute values, Σᵢ λᵢ / Σᵢ |λᵢ|. When this copying metric is close to 1.0 (i.e. all eigenvalues are positive), the circuit copies the tokens which are attended to. Values near -1.0 correspond to reducing the probability of those tokens. This task can be solved by keeping or suppressing previous tokens, so the imaginary part is very close to 0.
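As a concrete illustration (not the analysis code used for the post), the metric can be computed directly from the eigenvalues of a token-to-token map such as the direct path discussed below. The shapes and variable names here are my own assumptions:

```python
import numpy as np

def copy_metric(M: np.ndarray) -> complex:
    """Summary statistic from Elhage et al.: sum(eigenvalues) / sum(|eigenvalues|).
    Close to +1.0 means the circuit copies the tokens it reads; close to -1.0 means it suppresses them."""
    eig = np.linalg.eigvals(M)
    return eig.sum() / np.abs(eig).sum()

# Example: the direct path maps tokens straight from embedding to unembedding.
# Assumed shapes: W_E is (d_model, n_vocab), W_U is (n_vocab, d_model).
d_model, n_vocab = 32, 26
W_E = np.random.randn(d_model, n_vocab) * 0.02
W_U = np.random.randn(n_vocab, d_model) * 0.02
direct_path = W_U @ W_E              # (n_vocab, n_vocab) token -> logit map
print(copy_metric(direct_path))      # random init gives a value near 0; the trained models give ≈ -1.0
```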

1A + 2A)  Specialised circuits & in-distribution failure

The copy metric for the direct path (embedding matrix followed by unembedding matrix) in all 3 models yields a number very close to -1.0. This means that the same letter will never be repeated twice in a row. In addition, heads 1 and 4 in the 1-layer smeared key model and all the heads in layer 1 of the 2-layer model perform a similar function. However, some special sequences described in 2A, like AABBCC..., QQRRTT..., require the model to repeat the same letter twice in a row. 

                                                            1-layer smeared key model   2-layer model
Cross-entropy loss on random sequences from the test set    1.24                        1.40
Cross-entropy loss on sequences with repeated letters       2.31 (+86%)                 2.31 (+65%)

So the models perform much worse on average on this special subset of valid sequences than on sequences drawn uniformly at random.

1B) "Anti-induction" heads

This occurs most clearly in the 1-layer smeared key model. Heads 1 and 4 learn to attend to the previous few letters (see attention patterns) and suppress them due to the copy metric being fairly close to -1.0. My guess is that "anti-induction heads" emerge due to the ratio of heads to possible tokens being 4 to 6. Hence, the model can meaningfully improve by eliminating bad choices. In more general-purpose transformers, the ratio of heads to tokens is much smaller, so eliminating bad choices is not very important.

2B) Different out-of-distribution behaviour of seemingly similar models

Based on the attention patterns, the 2-layer model prevents the last ~4-6 letters from being copied and copies the first instance of the next token. 

How does it search backwards to find the first copy of the token after the current one?

It seems that, via K-composition in the second layer, heads are using the outputs of head 2 in the first layer to shift the tokens back. The eigenvalues of the QK circuit and the attention pattern of head 2 in layer 1 support this (again, see the link for details). In contrast, the smeared key transformer seems to mostly focus on the previous copy of the next token, not the first one.
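If you want to check for K-composition yourself, the score from Elhage et al. is a ratio of Frobenius norms: how large the product of a layer-2 head's QK circuit with a layer-1 head's OV circuit is, relative to the two circuits on their own. The sketch below uses random placeholder weights and my own shape conventions rather than the trained models' weights:

```python
import numpy as np

def k_composition(W_QK_2: np.ndarray, W_OV_1: np.ndarray) -> float:
    """K-composition score in the style of Elhage et al.:
    how much a layer-2 head's keys read from what a layer-1 head wrote."""
    return (np.linalg.norm(W_QK_2 @ W_OV_1)
            / (np.linalg.norm(W_QK_2) * np.linalg.norm(W_OV_1)))

# Placeholder weights; assumed convention: W_Q, W_K, W_V are (d_head, d_model), W_O is (d_model, d_head).
d_model, d_head = 32, 8
rng = np.random.default_rng(0)
W_Q2, W_K2 = rng.normal(size=(d_head, d_model)), rng.normal(size=(d_head, d_model))
W_O1, W_V1 = rng.normal(size=(d_model, d_head)), rng.normal(size=(d_head, d_model))

W_QK_2 = W_Q2.T @ W_K2   # (d_model, d_model) QK circuit of a layer-2 head
W_OV_1 = W_O1 @ W_V1     # (d_model, d_model) OV circuit of a layer-1 head
print(k_composition(W_QK_2, W_OV_1))  # random weights give a baseline; strongly K-composing heads score well above it
```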

So, based on this theory, given a sequence like XYZRSTABCDEF, the 2-layer model would predict XYZRSTXYZRST and the 1-layer model ABCDEFABCDEF. I'll call these options 1 and 2 respectively. Running the models on many such examples and computing the loss on the last 12 letters yields the following:

                                   1-layer smeared key model   2-layer model
Cross-entropy loss for option 1    0.99                        0.73 (-26% vs 1-layer)
Cross-entropy loss for option 2    0.73                        6.42 (+780% vs 1-layer)

So despite having very similar losses on the test set (1.24 vs 1.40), they behave completely differently on data outside the training distribution!
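For completeness, here is roughly how the option-1 / option-2 comparison could be set up. This is only a sketch: `model` stands in for a trained toy transformer with a hypothetical tokens-to-logits interface, and the A-to-0 tokenisation is my assumption, not the post's code:

```python
import random
import string
import torch
import torch.nn.functional as F

def two_block_prompt():
    """Two different random 6-letter blocks, e.g. 'XYZRST' and 'ABCDEF'."""
    b1 = "".join(random.choices(string.ascii_uppercase, k=6))
    b2 = "".join(random.choices(string.ascii_uppercase, k=6))
    return b1, b2

def loss_on_continuation(model, prompt: str, continuation: str) -> float:
    """Cross-entropy of the model's predictions for the 12 continuation letters.
    Assumes a hypothetical model(tokens) -> logits of shape (batch, seq_len, vocab)."""
    tokens = torch.tensor([ord(c) - ord("A") for c in prompt + continuation])
    logits = model(tokens[:-1].unsqueeze(0)).squeeze(0)   # next-token logits at each position
    targets = tokens[1:]
    return F.cross_entropy(logits[-12:], targets[-12:]).item()  # score only the last 12 letters

b1, b2 = two_block_prompt()
prompt = b1 + b2
option_1 = b1 + b1   # repeat the first block: what the 2-layer model seems to prefer
option_2 = b2 + b2   # repeat the most recent block: what the 1-layer smeared key model seems to prefer
# loss_on_continuation(model, prompt, option_1); loss_on_continuation(model, prompt, option_2)
```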

3) Algorithms in words

Algorithm 1) One-layer transformer

  • Copy the last ~10 letters but favour letters close to the current one (Heads)
  • Don't repeat the same letter twice in a row (Direct path)

Algorithm 2) One-layer transformer with smeared keys

  • Copy the letter 5 back from the current letter (Heads 2 and 3)
  • Don't repeat any of the ~3 previous letters (including the current letter) (Heads 1, 4 and the direct path)

Algorithm 3) Two-layer transformer (less confident)

  • Don't repeat the last ~4-6 letters
    • Don't repeat the same letter twice in a row (Direct path)
    • Don't repeat the last ~3 letters (Heads in layer 1)
    • Copy the last ~3 positions' logits (Heads in layer 2); because these logits come from layer 1, this stops the last ~4-6 letters from being repeated
  • Copy the first occurrence of the next letter (Heads in layer 2 using K-composition with head 2 in layer 1)

[The next section is only slightly better than a random guess.]

Implications

  • Fine-tuned models like Codex could work in different ways from more general models like GPT-3. 
  • Models learning specialized circuits, even if those circuits are not strictly needed, could make interpretability much harder.
  • Interpretability is fun! I've been doing this kind of work as a hobby for the past ~2 months and have really enjoyed the mathematical flavour of it.