
I'm writing this post to discuss solutions to the August challenge, and to present the challenge for this September. Apologies for it coming so late in the month (EAGxBerlin took up much of my focus in the first half of the month).

If you've not read the first post in this sequence, I'd recommend starting there - it outlines the purpose behind these challenges, and lists the recommended prerequisite material.

September Problem

The problem for this month (or at least as much of the month as remains!) is interpreting a model which has been trained to perform simple addition. The model was fed input in the form of a sequence of digits (plus special + and = characters with token ids 10 and 11), and was tasked with predicting each digit of the sum one sequence position before that digit appears. Cross entropy loss was only applied to these four token positions, so the model's output at all other sequence positions is meaningless.
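To make the setup concrete, here's a minimal sketch of how I imagine the data might be generated. The operand lengths and the zero-padding of the sum are my assumptions (the exact spec is on the Streamlit page) - only the token ids for + and = are given above, and a four-digit sum is at least consistent with the loss being applied at four positions.

```python
# Hypothetical data generation for the addition task. Assumes two 3-digit
# operands and a zero-padded 4-digit sum - check the Streamlit page for the
# real spec. Token ids: digits 0-9 are themselves, '+' is 10, '=' is 11.
import torch

PLUS, EQUALS = 10, 11

def make_batch(batch_size: int, seed: int = 0) -> tuple[torch.Tensor, torch.Tensor]:
    """Returns (tokens, loss_mask) for sequences like 1 1 7 + 1 9 6 = 0 3 1 3."""
    g = torch.Generator().manual_seed(seed)
    x = torch.randint(100, 1000, (batch_size,), generator=g)  # 3-digit operands
    y = torch.randint(100, 1000, (batch_size,), generator=g)
    s = x + y  # at most 1998, so 4 digits when zero-padded

    def digits(n: torch.Tensor, width: int) -> torch.Tensor:
        # Most-significant digit first, zero-padded to `width`.
        return torch.stack([(n // 10**i) % 10 for i in reversed(range(width))], dim=-1)

    tokens = torch.cat([
        digits(x, 3),
        torch.full((batch_size, 1), PLUS),
        digits(y, 3),
        torch.full((batch_size, 1), EQUALS),
        digits(s, 4),
    ], dim=-1)  # shape (batch, 12)

    # Loss only applies where the model predicts the four sum digits, i.e. at
    # the positions immediately before each sum token appears.
    loss_mask = torch.zeros_like(tokens, dtype=torch.bool)
    loss_mask[:, 7:11] = True
    return tokens, loss_mask
```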

The model is attention-only, with 2 layers and 3 heads per layer. It was trained with layernorm, weight decay, and an Adam optimizer with a linearly decaying learning rate.
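For reference, here's a sketch of how a model matching this description could be instantiated, assuming TransformerLens; the dimensions marked below are placeholder guesses, and the real config and trained weights are on the Streamlit page.

```python
# Sketch of a config matching the description above, assuming TransformerLens.
# d_model, d_head and n_ctx are placeholders, not the real values.
from transformer_lens import HookedTransformer, HookedTransformerConfig

cfg = HookedTransformerConfig(
    n_layers=2,
    n_heads=3,
    d_model=48,               # assumed
    d_head=16,                # assumed
    d_vocab=12,               # digits 0-9, plus '+' (10) and '=' (11)
    n_ctx=12,                 # assumed: 3 + 1 + 3 + 1 + 4 tokens
    attn_only=True,           # no MLPs
    normalization_type="LN",  # trained with layernorm
)
model = HookedTransformer(cfg)
```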

You can find more details on the Streamlit page. Feel free to reach out if you have any questions!

August Problem - Solutions

You can read full solutions on the Streamlit page, or on my personal website. Both of these sources host interactive charts (Plotly and Circuitsvis), so they're more suitable than a LessWrong/AF post for discussing the solutions in depth. However, I've presented a much shorter version of the solution below. If you're interested in these problems, I'd recommend having a try before you read on!


The key idea with this model is path decomposition (see the corresponding section of A Mathematical Framework for Transformer Circuits). There are several important paths in this model, with different interpretations & purposes, and we can broadly sort them into negative paths and positive paths. The negative paths are designed to suppress repeated tokens, and the positive paths are designed to boost tokens which are more likely to be the first unique token.

Let's start with the negative paths. Some layer 0 heads are duplicate token heads; they compose with layer 1 heads to cause those heads to attend to & suppress duplicated tokens. This is done both with K-composition (heads in layer 1 attend more to duplicated tokens, and hence suppress them), and V-composition (the actual outputs of the duplicate token heads are used as value input to heads in layer 1 to suppress duplicated tokens). Below is an example, where the second and third instances of a attend back to the first instance of a in head 0.2, and this composes with head 1.0, which attends back to (and suppresses) all the duplicated a tokens.
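For concreteness, here's a rough way one might test for duplicate token heads, where `model` (the trained HookedTransformer) and `tokens` (a batch of input sequences) are hypothetical names assumed to be loaded already. For each layer-0 head, we measure the average attention paid from each token to strictly-earlier occurrences of the same token; duplicate token heads should score high:

```python
# Rough duplicate-token-head detector. `model` and `tokens` are assumed to
# be loaded already; see the Streamlit page for the real analysis.
import torch

_, cache = model.run_with_cache(tokens)

# Mask of (query, key) pairs where the key is an earlier copy of the query token.
same_token = tokens[:, :, None] == tokens[:, None, :]               # (batch, q, k)
earlier = torch.ones_like(same_token).tril(diagonal=-1)             # strictly earlier
dup_mask = (same_token & earlier).float()

# Average attention each layer-0 head pays to earlier duplicates of its query.
pattern = cache["pattern", 0]                                       # (batch, head, q, k)
dup_score = (pattern * dup_mask[:, None]).sum(-1).mean(dim=(0, 2))  # per head
print(dup_score)  # if the analysis above is right, head 0.2 should stand out
```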

Now, let's move on to the positive paths. Heads in layer 0 will attend to early tokens which aren't the same as the current destination token, because both these bits of evidence (earliness, and not matching the destination) correlate with this token being the first unique token at this position (this is most obvious with the second token, since the first token is the correct answer here if and only if it doesn't equal the second token). Additionally, the outputs of heads in layer 0 are used as value input to heads in layer 1 to boost these tokens, i.e. as a virtual OV circuit. These paths aren't as obviously visible in the attention probabilities, because they're distributed: many tokens will weakly attend to some early token in a layer-0 head, and then all of those tokens will be weakly attended to by some layer-1 head. But the paths can be seen when we plot all the OV circuits, coloring each value by how much the final logits for that token are affected at the destination position:

[OV circuits plot - see the interactive version at the links above]
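To sketch the computation behind this plot: the full OV circuit for a virtual path embeds each token, passes it through both heads' OV matrices, and unembeds. A minimal version (ignoring LayerNorm for simplicity - the full solutions are more careful about this) might look like:

```python
# Full OV circuit for a virtual path through one layer-0 head and one layer-1
# head, ignoring LayerNorm for simplicity. The diagonal shows how attending to
# token t via this path boosts or suppresses t's own logit at the destination.
def full_virtual_OV(model, head0: int, head1: int):
    OV_0 = model.W_V[0, head0] @ model.W_O[0, head0]  # (d_model, d_model)
    OV_1 = model.W_V[1, head1] @ model.W_O[1, head1]  # (d_model, d_model)
    return model.W_E @ OV_0 @ OV_1 @ model.W_U        # (d_vocab, d_vocab)

path_02_10 = full_virtual_OV(model, head0=2, head1=0)
# Per the plot described above, we'd expect a negative diagonal entry for `a`
# (suppression) and a positive one for `c` (boosting) on the 0.2 -> 1.0 path.
print(path_02_10.diagonal())
```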

Another interesting observation - these paths aren't just split by head, they're also split by character. We can see that each path has some positive tokens or some negative tokens (or in a few cases, both!). To take the example at the start of this section, heads 0.2 and 1.0 were composing to suppress the a token when it was duplicated - this makes sense, given that the a-value for this path in the plot above is negative. But this path also has a positive value for c, and on inspection we find evidence that this is a positive path: head 0.2 will attend to early instances of c, and head 1.0 will pick up on these to boost the logits for c. As mentioned, this effect is harder to spot in the attention patterns, but it shows up clearly when we perform logit attribution for each path - consistently, the most positive paths for the correct token are the ones we'd expect from the plot above. For example, here's a case where c is the first token (and is the correct answer for much of the sequence), and this is reflected in the logit attribution:

Note - the absolute scale may be slightly off for this graph, but what matters is the relative scale: we can see that the paths which boost token c the most are exactly the ones we'd expect from the previous OV plot.

[Logit attribution plot - see the interactive version at the links above]
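For concreteness, here's a simplified sketch of the kind of per-head logit attribution described above. It ignores the final LayerNorm scale and only splits by layer-1 head (the full per-path version would further split each head's value input by which layer-0 head produced it); `tokens` and `correct_tokens` (the first unique token at each position) are hypothetical names assumed to be precomputed:

```python
# Simplified direct logit attribution for the layer-1 heads, ignoring the final
# LayerNorm scale. `tokens` is a (batch, seq) input batch and `correct_tokens`
# the first unique token at each position - both assumed to be precomputed.
import torch

_, cache = model.run_with_cache(tokens)

z = cache["z", 1]                                           # (batch, pos, head, d_head)
head_out = torch.einsum("bphd,hdm->bphm", z, model.W_O[1])  # each head's residual write

# Residual-stream direction whose dot product gives the correct token's logit.
U_dirs = model.W_U[:, correct_tokens]                       # (d_model, batch, pos)
dla = torch.einsum("bphm,mbp->bph", head_out, U_dirs)       # attribution per head

print(dla.mean(dim=(0, 1)))  # average boost each layer-1 head gives the answer
```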

Note that the heads' functionality is split across the whole vocabulary: head 1.0 handles boosting / suppression for [a, c], head 1.1 handles [d, e, f, j], and head 1.2 handles [b, g, h, i]. These sets are disjoint, and their union is the whole vocabulary. We might guess that the model would perform better with a more even split (i.e. 3-3-4 rather than 2-4-4). In fact, there are lots of ways this transformer could have implemented a better algorithm than the one it actually learned, but inductive biases also play a big role here.

You can see all the code used to generate these plots at either of the two links I provided at the start of this section.

Best Submissions

We received several excellent solutions to this problem. A special mention goes to the solutions by Andy Arditi and Connor Kissane, which I considered to be joint best as they both presented thorough and rigorous analysis which cut to the heart of what this model was doing. An honorable mention goes to Rick Goldstein.

Best of luck for this and future challenges!
