*This is a more detailed look at our work applying **causal scrubbing** to induction heads. The results are also summarized **here**.*

In this post, we’ll apply the causal scrubbing methodology to investigate how induction heads work in a particular 2-layer attention-only language model.^{[1]} While we won’t end up reaching hypotheses that are fully specific or fully human-understandable, causal scrubbing will allow us to validate claims about which components and computations of the model are important.

We'll first identify the induction heads in our model and the distribution over which we want to explain these heads' behavior. We'll then show that an initial naive hypothesis about how these induction heads mechanistically work—similar to the one described in __the Induction Heads paper__ (Olsson et al 2022)—only explains 35% of the loss. We’ll then go on to verify that a slightly more general, but still reasonably specific, hypothesis can explain 89% of the loss. It turns out to be the case that induction heads in this small model use information that is flowing through a variety of paths through the model – not just previous token heads and the embeddings. However, the important paths can be constrained considerably – for instance, only a small number of sequence positions are relevant and the way that attention varies in layer 0 is not that important.

As with the paren balance checker post, this post is mostly intended to be pedagogical. You can treat it as an initial foray into developing and testing a hypothesis about how induction heads work in small models. We won't describe in detail the exploratory work we did to produce various hypotheses in this document; we mostly used standard techniques (such as looking at attention patterns) and causal scrubbing itself (including looking at internal activations from the scrubbed model such as log-probabilities and attention patterns).

The experiments prescribed by causal scrubbing in this post are roughly equivalent to performing resampling ablations on the parts of the rewritten model that we claim are unimportant. For each ‘actual’ datum we evaluate the loss on, we’ll always use a single ‘other’ datum for this resampling ablation.

Throughout this document we measure hypothesis quality using the percentage of the loss that is recovered under a particular hypothesis. This percentage may exceed 100% or be negative, it’s not actually a fraction. See __the relevant section in the appendix post__ for a formal definition.

Note that, in these examples, we’re not writing out formal hypotheses, as defined in __our earlier post__, because the hypotheses are fairly trivial while also being cumbersome to work with. In brief, our interpretation graph is identical to the model’s computational graph with all the edges we say don’t matter removed, and with every node computing the identity function.

We studied a 2-layer attention-only model with 8 heads per layer. We use L.H as a notation for attention heads where L is the zero-indexed layer number and H is the zero-indexed head number.

Further details about the model architecture (which aren’t relevant for the experiments we do) can be found in the appendix.

Induction heads, originally described in __A Mathematical Framework for Transformer Circuits__ (Elhage et al 2021), are attention heads which empirically attend from some token `[A]` back to earlier tokens `[B]` which follow a previous occurrence of `[A]`. Overall, this looks like `[A][B]...[A]`, where the head attends back to `[B]` from the second `[A]`.
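One simple way to quantify this behavior (a hypothetical diagnostic, not part of the model or our actual tooling) is to score what fraction of a head's attention mass lands on "induction targets", i.e. positions that immediately follow an earlier occurrence of the current token:

```python
def induction_score(tokens, attn):
    """Fraction of attention mass that lands on induction targets.

    tokens: list of token ids, length n.
    attn: n x n attention pattern; attn[q][k] is attention from query
          position q to key position k (lower-triangular for a causal model).
    For each query position q, an induction target is any position k > 0
    such that tokens[k - 1] == tokens[q], i.e. k follows an earlier copy of
    the current token (the [A][B]...[A] -> attend-to-[B] pattern).
    """
    total, hit = 0.0, 0.0
    n = len(tokens)
    for q in range(n):
        for k in range(1, q + 1):  # position 0 has no preceding token
            total += attn[q][k]
            if tokens[k - 1] == tokens[q]:
                hit += attn[q][k]
    return hit / total if total > 0 else 0.0
```

A head that always attends from the second `[A]` back to `[B]` scores near 1, while a head that only attends to the current token scores 0.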

Our first step was to identify induction heads. We did this by looking at the attention patterns of layer 1 heads on some text where there are opportunities for induction. These heads often either attend to the first token in a sequence, if the current token doesn’t appear earlier in the context, or look at the token following the previous occurrence of the current token.

Here are all the attention patterns of the layer 1 heads on an example sequence targeted at demonstrating induction: “Mrs. Dursley, Mr. Dursley, Dudley Dursley”

Two heads seem like possible induction heads: 1.5 and 1.6. We can make this more clear by looking more closely at their attention patterns: for instance, zooming in on the attention pattern of 1.6, we find that it attends to the sequence position corresponding to the last occurrence of "`[ley]`":

Within this, let’s specifically look at the attention from the last ‘urs’ token (highlighted in the figure above).

A closer look at the attention pattern of head 1.5 showed similar behavior.

Previous token heads are heads that consistently attend to the previous token. We picked out the head that we thought was a previous token head by eyeballing the attention patterns for the layer zero heads. Here are the attention patterns for the layer zero heads on a short sequence from OpenWebText:

Here’s a plot of 0.0’s attention pattern:

So you can see that 0.0 mostly attends to the previous token, though sometimes attends to the current token (e.g. on “ to”) and sometimes attends substantially to the `[BEGIN]` token (e.g. from “ Barcelona”).

Let's define a "next-token prediction example" to be a context (a list of tokens) and a next token; the task is to predict the next token given the context. (Normally, we train autoregressive language models on all the prefixes of a text simultaneously, for performance reasons. But equivalently, we can just think of the model as being trained on many different next-token prediction examples.)

We made a bunch of next-token prediction examples in the usual way (by taking prefixes of tokenized OWT documents), then filtered to the subset of these examples where the last token in the context was in a particular whitelist of tokens.
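A minimal sketch of this construction (the whitelist contents come from the procedure described in the appendix; the helper name and the placeholder whitelist here are illustrative):

```python
def make_filtered_examples(docs, whitelist):
    """Build (context, next_token) examples whose last context token is whitelisted.

    docs: iterable of tokenized documents (lists of token ids).
    whitelist: set of token ids chosen so that hard induction helps over
               bigrams (see appendix for how this set was actually picked).
    """
    examples = []
    for doc in docs:
        # every prefix of the document is a next-token prediction example
        for i in range(1, len(doc)):
            context, next_token = doc[:i], doc[i]
            # filter ONLY on the last context token -- not on the next
            # token, and not on whether induction actually helps here
            if context[-1] in whitelist:
                examples.append((context, next_token))
    return examples
```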

We chose this whitelist by following an approach which is roughly 'select tokens such that hard induction is very helpful over and above bigrams'--see __the appendix__ for further details. Code for this token filtering can be found in the appendix and the exact token list is linked. Our guess is that these results will be fairly robust to different ways of selecting the token whitelist.

So, we didn't filter based on whether induction was a useful heuristic on this particular example, or on anything about the next-token; we only filtered based on whether the last token in the context was in the whitelist.

For all the hypotheses we describe in this post, we’ll measure the performance of our scrubbed models on just this subset of next-token prediction examples. The resulting dataset is a set of sequences whose last token is somewhat selected for induction being useful. Note that evaluating hypotheses on only a subset of a dataset, as we do here, is equivalent to constructing hypotheses that make no claims on tokens other than our “inductiony” tokens, and then evaluating these weaker hypotheses on the whole dataset.

We want to explain the performance of our two-layer attention-only model. Its performance is measured by the following computational graph:

We’re taking the token embeddings (emb) and running them through the model, then calculating the log-loss of the model on the actual next token. The model is composed of two attention layers (with layernorm), which we’re writing as a0 and a1.

To start out our investigation, let’s see how much performance is damaged if we replace the induction head’s outputs with their outputs on random other sequences. To measure this, we rewrite our model to separate out the induction heads from the other layer 1 heads:

Now we consider passing in different inputs into the induction heads and the other heads.

We run the model *without the induction heads* on the (emb, next token) pairs that correspond to sequences in our filtered next-token-prediction dataset, while running the induction heads on a different sequence (encoded as emb2).

Note that if we ran this computational graph with emb2 = emb, we’d exactly recover the performance of the original model–we’ve rewritten the model to be exactly mathematically equal, except that we’re now able to pass in inputs that are different for different paths through the model.

If the induction heads don’t contribute at all to model performance on this task, this experiment would return the same loss as the original model.

When you run the scrubbed computation, the loss is 0.213. The original loss on this dataset was 0.160, and the difference between these losses is 0.053. This confirms that the induction heads contribute significantly to the performance of the original model for this subset of tokens.

Going forward, we'll report the fraction of this 0.053 loss difference that is restored under various scrubs.

For every experiment in this post, we use the same choice of emb2 for each (emb, next token) pair. That is, every dataset example is paired with a single other sequence^{[2]} that we’ll patch in as required; in different experiments, the way we patch in the other sequence will be different, but it will be the same other sequence every time. We do this to reduce the variance of comparisons between experiments.
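A sketch of this fixed pairing (a hypothetical helper, not our actual pipeline): shuffle the examples once with a fixed seed, then pair each example with its successor in the shuffled order, so the pairing is a fixed derangement reused across all experiments.

```python
import random

def pair_with_other(examples, seed=0):
    """Assign each example a single fixed 'other' sequence (the emb2 source).

    Every experiment reuses the same pairing, so differences between
    experiments reflect the hypothesis being tested rather than fresh
    sampling noise.
    """
    rng = random.Random(seed)
    order = list(range(len(examples)))
    rng.shuffle(order)
    # pair each index with the next one in the shuffled cycle, so no
    # example is ever paired with itself (for len(examples) > 1)
    partner = {}
    for j, i in enumerate(order):
        partner[i] = order[(j + 1) % len(order)]
    return [(examples[i], examples[partner[i]]) for i in range(len(examples))]
```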

This is the standard picture of induction:

- We have a sequence like “Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Durs”. “Dursley” is tokenized as | D|urs|ley|. And so a good prediction from the end of this sequence is “ley”. (We’ll refer to the first “urs” token as A, the first “ley” token as B, and the second “urs” token as A’.)
- There’s a previous-token head in layer 0 which copies the value at A onto B.
- The induction head at A’ attends to B because of an interaction between the token embedding at A’ and the previous-token head output at B.
- The induction head then copies the token embedding of B to its output, and therefore the model proposes B as the next token.

To test this, we need to break our induction heads into multiple pieces that can be given inputs separately. We first expand the node (highlighted in pink here):

So we’ve now drawn the computation for the keys, queries, and values separately. (We’re representing the multiplications by the output matrix and the value matrix as a single “OV” node, for the same reasons as described in the __“Attention Heads are Independent and Additive”__ section of A Mathematical Framework for Transformer Circuits.)

Our hypothesis here involves claims about how the queries, keys, and values are formed:

- values for the induction head are produced only from the token embeddings via the residual stream with no dependence on a0
- queries are also produced only from the token embeddings
- keys are produced only by the previous-token head

Before we test them together, let’s test them separately.

The hypothesis claims that the values for the induction head are produced only from the token embeddings via the residual stream, with no dependence on a0. So, it shouldn’t affect model behavior if we rewrite the computation such that the a1 induction OV path is given the a0 output from emb2, and so it only gets the information in emb via the residual connection around a0:

When we do this scrub, the measured loss is 90% of the way from the baseline ablated model (where we ran the induction heads on emb2) to the original unablated model. So the part of the hypothesis where we said only the token embeddings matter for the value path of the induction heads is somewhat incorrect.

We can similarly try testing the “the queries for induction heads are produced only from the token embeddings” hypothesis, with the following experiment:

The fraction of the loss restored in this experiment is 51%, which suggests that this part of the hypothesis was substantially less correct than the part about how the induction head values are produced.

Finally, we want to test the final claim in our hypothesis; that the key used by the induction head is produced only by the previous-token head.

To do this, we first rewrite our computational graph so that the induction key path takes the previous-token head separately from the other layer zero heads.

This experiment here aims to evaluate the claim that the only input to the induction heads that matters for the keys is the input from the previous-token head.

However, this experiment wouldn’t test that the previous-token head is actually a *previous* token head. Rather, it just tests that this particular head is the one relied on by the induction heads.

We can make a strong version of this previous token head claim via two sub-claims:

- The attention pattern is unimportant (by which we mean that the relationship between the attention pattern and the OV is unimportant, as discussed in __this section__ of our earlier post)
- All that matters for the OV is the previous sequence position

We’ll implement these claims by rewriting the model to separate out the parts which we claim are unimportant and then scrubbing these parts. Specifically, we’re claiming that this head always operates on the previous token through its OV (so we connect that to “emb”); and its attention pattern doesn’t depend on the current sentence (so we connect that to “emb2”). We also connect the OV for tokens that are not the previous one to “emb2”.

The resulting computation for the previous-token head is as follows:

So we’ve run the OV circuit on both emb and emb2, and then we multiply each of these by a mask so that we only use the OV result from emb for the previous token. Prev mask is a matrix that is all zeros except for the row below the diagonal (corresponding to attention to the previous token). Non prev mask is the difference between prev mask and the lower triangular mask that we normally use to enforce that attention only looks at previous sequence positions.
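As a concrete sketch, these masks can be constructed as follows (pure Python for illustration; in practice they would be tensor operations):

```python
def prev_mask(n):
    """All zeros except the subdiagonal: position q may use only position q-1."""
    return [[1.0 if k == q - 1 else 0.0 for k in range(n)] for q in range(n)]

def causal_mask(n):
    """Standard lower-triangular mask: each position sees itself and earlier."""
    return [[1.0 if k <= q else 0.0 for k in range(n)] for q in range(n)]

def non_prev_mask(n):
    """Causal positions other than the previous token: causal minus prev."""
    prev, causal = prev_mask(n), causal_mask(n)
    return [[causal[q][k] - prev[q][k] for k in range(n)] for q in range(n)]
```

The scrubbed previous-token head then uses the OV result from emb through `prev_mask` and the OV result from emb2 through `non_prev_mask`, with the attention pattern itself computed on emb2.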

And so, our overall experiment is as follows, where the nodes of the model corresponding to the previous token head are shown in pink:

The fraction of the loss restored by this experiment is 79%.

Next we want to scrub all these paths (i.e. do all these interventions) simultaneously.

The fraction of the loss that this restores is 35%.

Using causal scrubbing, we’ve found that our initial naive hypothesis is quite incorrect for these induction heads.

To recap the results, the fractions of loss restored are:

- Scrubbing all of the input to Q except the embeddings: 51%.
- Scrubbing all of the input to K, except the previous token part of the previous-token head: 79%
- Scrubbing all of the input to V except the embeddings: 90%
- Doing all of these at once: 35%

These numbers weren’t very surprising to us. When we described this experiment to some of the authors of the induction heads paper, we asked them to guess the proportion of loss that this would recover, and their answers were also roughly in the right ballpark.

How might our previous hypothesis be missing important considerations? Or, to put it differently, what important information are we scrubbing away?

One possibility is that it’s common for attention heads to attend substantially to the current sequence position (you’ll see this if you look at the attention patterns included in the “Identification” section). This attention results in the token’s representation being transformed in a predictable way. And so, when the induction heads are learning to e.g. copy a token value, they’ll probably set up their V matrix to take into account the average attention-to-current-token of the layer zero heads.

We would like to express the hypothesis that the induction head interacts with all the layer zero heads, but through their average attention-to-current-token. That is, we hypothesize that the induction head’s behavior isn’t importantly relying on the ways that a0 heads vary their attention depending on context; it’s just relying on the effect of the a0 head OV pathway, ignoring correlation with the a0 attention pattern.

Similarly, there might be attention heads other than the previous token head which, on average, attend substantially to the previous token; the previous hypothesis also neglects this, but we’d like to represent it.

Here’s the complete experiment we run. Things to note:

- We’ve drawn the “emb” and “emb2” nodes multiple times. This is just for ease of drawing–we’ll always use the same value the two places we drew an emb node.
- The main point of this experiment is that the layer zero attention patterns used by the induction heads always come from emb2, so the induction heads can’t be relying on any statistical relationship between the layer zero attention pattern and the correct next token.

Running parts of this individually (that is, just scrubbing one of Q, K, or V in the induction heads, while giving the others their value on emb) and all together (which is what is pictured) yields the following amounts of loss recovered:

Q: 76%

K: 86%

V: 97%

All: 62%

So, we've captured V quite well with this addition, but we haven’t yet captured much of what’s happening with K and Q.

One theory for what could be going wrong with Q and K is that we need to take into account other sequence positions. Specifically, maybe there's some gating where K only inducts on certain 'B' tokens in AB...A, and maybe the induction heads fire harder on patterns of the form XAB...XA, where there are two matching tokens (for example, in the earlier Dursley example, note that the two previous tokens | D| and |urs| both matched). This is certainly not a novel idea—prior work has mentioned fuzzy matching on multiple tokens.

So, we'll considerably expand our hypothesis by including 'just the last 3 tokens' for K and Q (instead of just previous and just current). (By last three, we mean current, previous, and previous to previous.)

It’s getting unwieldy to put all this in the same diagram, so we’ll separately draw how to scrub K, Q, and V. The OV activations are produced using the current token mask, and the Q and K are produced using the “last 3 mask”. Both use the direct path from emb rather than emb2.
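A sketch of the mask family involved (`k = 1` recovers the current-token mask used for the OV path; `k = 3` gives the "last 3 mask" used here for Q and K):

```python
def last_k_mask(n, k):
    """Mask allowing each query position q to use positions q, q-1, ..., q-k+1.

    Rows are query positions, columns are the positions whose contribution
    is kept from the real input (emb); everything else comes from emb2.
    """
    return [[1.0 if q - k < j <= q else 0.0 for j in range(n)] for q in range(n)]
```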

Given these, we can do the experiments for this hypothesis by substituting in those scrubbed activations as desired:

And the numbers are:

Q: 87%

K: 91%

V: 97% (same as previous)

All: 76%

This improved things considerably, but we're still missing quite a bit. (We tested using different subsets of the relative sequence positions for Q and K; using the last three for both was the minimal subset which captures nearly all of the effect.)

If you investigate what heads in layer 0 do, it turns out that there are some heads which often almost entirely attend to occurrences of the current token, even when it occurred at earlier sequence positions.

The below figure shows the attention pattern of 0.2 for the query at the last ' Democratic' token:

So you can see that 0.2 attended to all the copies of “ Democratic”.

Because this is a layer zero head, the input to the attention head is just the token embedding, and so attending to other copies of the same token leads to the same output as the head would have had if it had just attended to the current token. But it means that on any particular sequence, this head’s attention pattern is quite different from its attention pattern averaged over sequences. Here is that head’s attention pattern at that sequence position, averaged over a large number of sequences:

On average, this head attends mostly to the current token, a bit to the `[BEGIN]` token, and then diffusely across the whole sequence. This is the average attention pattern because tokens that match the current token are roughly equally likely to appear anywhere in the context.

These heads have this kind of attend-to-tokens-that-are-the-same-as-the-current-token behavior for most of the tokens in the subset of tokens that we picked (as described in “Picking out inductiony tokens”). This is problematic for our strategy where we scrub the attention probabilities because the expected attention probability on tokens matching the current token might be 0.3, even though the model always only attends to tokens matching the current token.

There are two layer 0 heads which most clearly have this behavior, 0.2 and 0.5, as well as 0.1, which has this behavior to a lesser extent.

(These heads don't *just* do this. For instance, in the attention pattern displayed above, 0.2 also attends to ' Democratic' and ' Party' from the ' GOP' token. We hypothesize this is related to 'soft induction'^{[3]}, though it probably also has other purposes – for instance, directly making predictions from bigrams and being used by other layer 1 heads.)

In addition to this issue with the self-attending heads, the previous token head also sometimes deviates from attending to the previous token, and this causes additional noise when we try to approximate it by its expectation. So, let’s try the experiment where we run the previous token head and these self-attending heads with no scrubbing or masking.

So we’re computing the queries and keys for the induction heads as follows:

And then we use these for the queries and keys for the induction heads. We use the same values for the induction heads as in the last experiment. Our experiment graph is the same as the last experiment, except that we’ve produced Q and K for the induction heads in this new way.

Now we get:

Q: 98%

K: 97%

V: 97% (same as previous)

All: 91%

We’re happy with recovering this much of the loss, but we aren’t happy with the specificity of our hypothesis (in the sense that a more specific hypothesis makes more mechanistic claims and permits more extreme intervention experiments). Next, we’ll try to find a hypothesis that is more specific while recovering a similar amount of the loss.

So we’ve observed that the self-attending heads in layer zero are mostly just attending to copies of the same token. This means that even though these heads don’t have an attention pattern that looks like the identity matrix, they should behave very similarly to how they’d behave if their attention pattern *was* the identity matrix. If we can take that into account, we should be able to capture more of how the queries are formed.

To test that, we rewrite the self-attending heads (0.1, 0.2, 0.5) using the identity, where “identity attention” means the identity matrix attention pattern:

This is equal to calculating 0.1, 0.2, and 0.5 the normal way, but it permits us to check the claim “The outputs of 0.1, 0.2, and 0.5 don’t importantly differ from what they’d be if they always attended to the current token”, by using just the left hand side from the real input and calculating the “error term” using the other input.
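As a toy sketch of this residual rewrite (with scalar OV outputs standing in for the real vector-valued activations, and hypothetical helper names):

```python
def head_output(attn, ov_out):
    """Output of an attention head: pattern-weighted mixture of OV outputs.

    attn: n x n attention pattern; ov_out: list of n OV-circuit outputs
    (scalars here for simplicity).
    """
    n = len(ov_out)
    return [sum(attn[q][k] * ov_out[k] for k in range(n)) for q in range(n)]

def residual_rewrite(attn, ov_out):
    """Rewrite a head's output as identity-attention term + error term.

    identity term: each position uses only its own OV output (identity
                   attention pattern);
    error term:    full head output minus the identity term.
    Summing the two terms is exactly the original head, but causal scrubbing
    can now feed the error term activations computed on a different input.
    """
    identity_term = list(ov_out)
    full = head_output(attn, ov_out)
    error_term = [f - i for f, i in zip(full, identity_term)]
    return identity_term, error_term
```

The scrubbed version then takes the identity term from the real input and the error term from the other input (emb2).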

Let’s call this “0.1, 0.2, 0.5 with residual rewrite”.

So now we have a new way of calculating the queries for the induction heads:

Aside from the queries for the induction heads, we run the same experiment as refined hypothesis 2.

And this yields:

Q: 97%

K: 91% (same as refined hypothesis 2)

V: 97% (same as refined hypothesis 2)

All: 86%

We've now retained a decent fraction of the loss while simultaneously testing a reasonably specific hypothesis for what is going on in layer 0.

In this post, we used causal scrubbing to test a naive hypothesis about induction heads, and found that it was incorrect. We then iteratively refined four hypotheses using the scrubbed expectation as a guide. Hopefully, this will serve as a useful example of how causal scrubbing works in simple settings.

Key takeaways:

- We were able to use causal scrubbing to narrow down what model computations are importantly involved in induction.
- In practice, induction heads in small models take into account information from a variety of sources to determine where to attend.

- The model uses the __shortformer positional encodings__ (which means that the positional embeddings are added into the Q and K values before every attention pattern is computed, but are not provided to the V)
- The model has layer norms before the attention layers
- The model was trained on the __openwebtext dataset__
- Its hidden dimension is 256
- We ran causal scrubbing on validation data from openwebtext with sequence length 300

- Choose beta and a threshold.
- Then, for all sequential pairs of tokens AB in the corpus, we compute:
  - The log loss of the bigram probabilities (via the full bigram matrix).
  - The log loss of 'the beta-level induction heuristic' probabilities, which we compute as:
    - Find all prior occurrences of A.
    - Count the number of these prior occurrences which are followed by B; call this the matching count. Let the number of remaining occurrences be the not-matching count.
    - Starting from the bigram statistics, add a beta-scaled amount based on the matching count to the logit of B, and a corresponding beta-scaled amount based on the not-matching count to the not-B logit.
    - Then, compute the log loss of these probabilities.
- Finally, we compute the average log loss for each A token under each of these heuristics. If the bigram loss is larger than the induction heuristic loss by at least some small threshold for a given A token, we include that token.
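As a rough sketch of this heuristic in code (the exact beta-scaled amounts added to the logits are not spelled out above, so the `beta * matching` and `beta * not_matching` terms below are an assumed form rather than the exact expression we used):

```python
import math

def induction_heuristic_log_probs(bigram_logits, matching, not_matching, beta):
    """Beta-level induction heuristic over the outcomes {B, not-B}.

    bigram_logits: dict with 'B' and 'not_B' logits from the bigram matrix.
    matching / not_matching: counts of prior occurrences of A that are /
    are not followed by B. NOTE: the beta-scaled update below is an
    assumed form for illustration.
    """
    logits = {
        "B": bigram_logits["B"] + beta * matching,
        "not_B": bigram_logits["not_B"] + beta * not_matching,
    }
    # normalize to log probabilities over the two outcomes
    z = math.log(sum(math.exp(v) for v in logits.values()))
    return {k: v - z for k, v in logits.items()}
```

The log loss on the pair AB is then just the negative of the resulting log probability of B.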

Our previous hypothesis improved the Q pathway considerably, but we’re missing quite a bit of loss due to scrubbing the attention pattern of the previous token head for K. This is due to cases where the previous token head deviates from attending to the previous token. If we sample an alternative sequence which fails to attend to the previous token at some important location, this degrades the induction loss. We’ll propose a simple hypothesis which partially addresses this issue.

Consider the following passage of text:

`[BEGIN] I discovered Chemis artwork a few weeks ago in Antwerp (Belgium) I’ve been fascinated by the poetry of his murals and also by his fantastic technique. He kindly accepted to answer a few questions for StreetArt360. Here’s his interview.\n\nHello Dmitrij, great to meet you.`

We’ll look at the attention pattern for the previous token head as weighted lines from q tokens to k tokens.

For instance, here’s the start of the sequence:

Some utf-8 characters consist of multiple bytes and the bytes are tokenized into different tokens. In this case, those black diamonds are from an apostrophe being tokenized into 2 tokens.

Note that the previous token head exhibits quite different behavior on ` I’ve` (where the apostrophe is tokenized into those 2 black diamonds). Specifically, `ve` skips over the apostrophe to attend to ` I`. The token ` been` also does this to some extent.

Another common case where the previous token head deviates is on punctuation. Specifically, on periods it will often attend to the current token more than typical.

While there are a variety of different contexts in which the previous token head importantly varies its attention, we’ll just try to explain a bit of this behavior. We’ll do this by identifying a bunch of cases where the head has atypical behavior.

It turns out that the model special-cases a variety of different quote characters. We’ll specifically reference `'`, `“`, ` “`, and `”`.

It’s a bit hard to read this, so here’s a zoomed in version with a font that makes the different characters more obvious.

These quote characters each consist of two tokens and the head has atypical behavior if any of those tokens is the previous one.

It turns out the model also sometimes has atypical behavior if the previous token is the ` a` or ` an` token:

And, it has atypical behavior if the current token is `.` or ` to`.

Overall, here is a simple classifier for whether or not a given token is atypical:

- Is the previous token one of the bytes from any of `’`, `“`, ` “`, or `”`? If so, atypical.
- Is the previous token one of ` a` or ` an`? If so, atypical.
- Is the current token one of `.` or ` to`? If so, atypical.
- Otherwise, typical.
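The classifier is easy to write down directly. Here is a sketch operating on token strings (the real implementation works on byte-level token ids, since the quote characters are each split into multiple byte tokens; the string forms below are illustrative):

```python
# previous tokens that make a position atypical: quote-character bytes
# (shown here as whole characters for readability) plus ' a' / ' an'
ATYPICAL_PREV = {"’", "“", " “", "”", " a", " an"}
# current tokens that make a position atypical
ATYPICAL_CURR = {".", " to"}

def is_atypical(prev_token, curr_token):
    """Classify whether the previous-token head behaves atypically here."""
    return prev_token in ATYPICAL_PREV or curr_token in ATYPICAL_CURR
```

Causal scrubbing then resamples the attention pattern at a given location only from other sequences whose location has the same typicality label.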

And we’ll propose that it only matters whether or not the attention from the current location should be ‘typical’ or ‘atypical’. Then, we can test this hypothesis with causal scrubbing by sampling the attention pattern from the current location from a different sequence which has the same ‘typicality’ at that location.

This hypothesis is clearly somewhat wrong – it doesn’t only matter whether or not the current location is ‘typical’! For instance, the current token being `.` or ` to` results in attending to the previous token, while the previous token being ` a` results in attending two back. Beyond this issue, we’ve failed to handle a bunch of cases where the model has different behavior. This hypothesis would be easy to improve, but we’ll keep it as is for simplicity.

We’ll apply this augmented hypothesis for the previous token head to just the K pathway. This yields:

Q: 97% (same as refined hypothesis 4)

K: 94%

V: 97% (same as refined hypothesis 2)

All: 89%

So compared to refined hypothesis 4, this improved the loss recovered by 3% for both K individually (91% -> 94%) and everything together (86% -> 89%). This brings us considerably closer to the overall loss from hypothesis 3, which was 91%.

^{^}These heads are probably doing some things in addition to induction; we’ll nevertheless refer to them as induction heads for simplicity and consistency with earlier work.

^{^}Note that we opt to use a single input rather than __many__.

^{^}This is where seeing AB updates you towards thinking that tokens similar to A are likely to be followed by tokens like B.


*This is a more detailed look at our work applying **causal scrubbing** to an algorithmic model. The results are also summarized **here**.*

In earlier work (unpublished), we dissected a tiny transformer that classifies whether a string of parentheses is balanced or unbalanced.^{[1]} We hypothesized the functions of various parts of the model and how they combine to solve the classification task. The result of this work was a qualitative explanation of how this model works, but one that made falsifiable predictions and thus qualified as an informal hypothesis. We summarize this explanation __below__.

We found that the high-level claims in this informal hypothesis held up well (88-93% loss recovered, see the __Methodology__). Some more detailed claims about how the model represents information did not hold up as well (72%), indicating there are still important pieces of the model’s behavior we have not explained. See the experiments __summary section__ for an explanation of each hypothesis refinement.

Causal scrubbing provides a language for expressing explanations in a formal way. A formal hypothesis is an account of the information present at every part of the model, how this information is combined to produce the output, and (optionally) how this information is represented. In this work, we start by testing a simple explanation, then iterate on our hypothesis either by improving its accuracy (which features of the input are used in the model) or specificity (what parts of the model compute which features, and optionally how). This iterative process is guided by the informal hypothesis that we established in prior work.

For a given formal hypothesis, the causal scrubbing algorithm automatically determines the set of interventions to the model that would not disturb the computation specified by the hypothesis. We then apply a random selection of these interventions and compare the performance to that of the original model.

Using causal scrubbing to evaluate our hypotheses enabled us to identify gaps in our understanding, and provided trustworthy evidence about whether we filled those gaps. The concrete feedback from a quantitative correctness measure allowed us to focus on quickly developing alternative hypotheses, and finely distinguish between explanations with subtle differences.

We hope this walk-through will be useful for anyone interested in developing and evaluating hypothesized explanations of model behaviors.

The model architecture is a three layer transformer with two attention heads and pre layernorm:

There is no causal mask in the attention (bidirectional attention). The model is trained to classify sequences of up to 40 parentheses. Shorter sequences are padded, and the padding tokens are masked so they cannot be attended to.

The training dataset consists of 100k sequences along with labels indicating whether the sequence is balanced. An example input sequence is `()())()`, which is labeled *unbalanced*. The dataset is a mixture of randomly generated sequences and adversarial examples.^{[2]} We prepend a `[BEGIN]` token at the start of each sequence, and read off the classification above this token (therefore, the parentheses start at sequence position 1).

For the experiments in this writeup we only use random, non-adversarial inputs, on which the model is almost perfect (loss of 0.0003, accuracy 99.99%). For more details on the dataset see the __appendix__.

Our hypothesis is that the model is implementing an algorithm that is approximately as follows:

- Scan the sequence from right to left and track the nesting depth at each position. That is, the nesting depth starts at 0 and then, as we move across the sequence, increments at each `)` and decrements at each `(`. You can think of this as an "elevation profile" of the nesting level across the sequence, which rises or falls according to which parenthesis is encountered.

Important note: scanning from left to right and scanning from right to left are obviously equally effective. The specific model we investigated scans from right to left (likely because we read off the classification at position 0).

- Check two conditions:
  - **The Equal Count Test (aka the count test)**: Is the elevation back to 0 at the left-most parenthesis? This is equivalent to checking whether there are the same number of open and close parentheses in the entire sequence.
  - **The Above Horizon Test (aka the horizon test)**: Is the elevation non-negative at every position? This fails exactly when there is at least one open parenthesis `(` that is not closed by a later `)` (see the third example below).
- If either test fails, the sequence is unbalanced. If both pass, the sequence is balanced.

(As an aside, this is a natural algorithm; it's also similar to what Codex or GPT-3 generate when prompted to balance parentheses.)
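For concreteness, the two-test algorithm can be sketched in a few lines of Python. This is our illustration of the hypothesized algorithm, not code the model literally runs:

```python
def is_balanced(seq: str) -> bool:
    """Classify a parenthesis sequence via the elevation profile,
    scanning right to left as the model is hypothesized to do."""
    elevation = 0
    for ch in reversed(seq):
        # ')' raises the elevation, '(' lowers it (right-to-left scan)
        elevation += 1 if ch == ")" else -1
        if elevation < 0:
            return False  # Above Horizon Test fails
    return elevation == 0  # Equal Count Test
```

For example, `is_balanced("()())()")` is `False`, matching the label of the example sequence above.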

Again - this is close to the algorithm we hypothesize the model uses, but not exactly the same. The algorithm that the model implements, according to our hypothesis, has two differences from the one just described.

**1. It uses proportions instead of ‘elevation’.** Instead of computing ‘elevation’ by incrementing and decrementing a counter, we believe the model tracks the proportion of parentheses that are open in each suffix-substring (i.e. in every substring that ends with the rightmost parenthesis). This proportion contains the same information as ‘elevation’. We define:

p_{i}: the proportion of open parentheses `(` in the suffix-substring starting at position i, i.e. from i to the rightmost parenthesis

Put in terms of proportions, the Equal Count Test is whether this is exactly 0.5 for the entire string (p_{1} == 0.5). The Not Beneath Horizon Test is whether this is less than or equal to 0.5 for each suffix-substring (p_{i} <= 0.5 for all i); if the proportion exceeds 0.5 at any point, this test is failed.

**2. It uses a combined test of "is the first parenthesis open, and does the sequence pass Equal Count?"** Call this the Start-Open-and-Equal-Count test, aka the count^{(} test.

Consider the Start-Open component. Sequences that start with a closed parenthesis instead of an open one cannot be balanced: they inevitably fail *at least* one of Equal Count or Not Beneath Horizon, so a separate Start-Open check is logically redundant. However, the model actually computes Start-Open separately!

As the model detects Start-Open in the circuit that computes the ‘equal count’ test, we’ve lumped them together for cleaner notation:

count^{(}: Is the elevation back to 0 at the left-most parenthesis (i.e. the Equal Count Test), and does the sequence start with an open parenthesis? i.e. count^{(} := (first parenthesis is open) & (passes Equal Count Test)

We can use these variables to define a computational graph that will compute if any sequence of n parentheses is balanced or not.

Note that we define count^{(}, horizon_{i}, and horizon_{all} to be booleans that are true if the test passes (implying that the sequence might be balanced) and false if it fails (implying the sequence definitely isn’t balanced).
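These boolean features follow directly from the definitions above. A minimal sketch, using 1-indexed positions as in the post (the helper name is ours, not the model's):

```python
def paren_features(seq: str) -> dict:
    """Compute p_i, the count^( test, and the horizon tests for a
    parenthesis-only sequence (positions are 1-indexed, as in the post)."""
    n = len(seq)
    # p[i]: proportion of '(' in the suffix-substring seq[i-1:]
    p = {i: seq[i - 1:].count("(") / (n - i + 1) for i in range(1, n + 1)}
    count_open = (seq[0] == "(") and (p[1] == 0.5)   # count^( test
    horizon = {i: p[i] <= 0.5 for i in p}            # horizon_i tests
    return {
        "count_open": count_open,
        "horizon_all": all(horizon.values()),
        "balanced": count_open and all(horizon.values()),
    }
```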

We will reference the features of the above graph to make claims about what particular components of the model compute.

Our prior interpretability work suggested that the model implements this algorithm in roughly the following manner:

- Head 0.0 has an upper triangular attention pattern (recall that the model uses bidirectional attention): at every query position it pays roughly-equal attention to all following sequence positions and writes in opposite directions at open and close parentheses. These opposite directions are analogous to “up” and “down” in the __elevation profile__. Thus, head 0.0 computes every p_{i} and writes this in a specific direction.
- The MLPs in layers 0 and 1 then transform the p_{i} into binary features. In particular, at position 1 they compute the count^{(} test, and at every sequence position they compute the horizon test for that position.
- Heads 1.0 and 2.0 both copy the information representing the count^{(} test from position 1 (the first parenthesis token) to position 0 (the [BEGIN] token where the classifier reads from).
- Head 2.1 checks that the horizon test passed at all positions and writes this to position 0.

A consequence of this hypothesis is that at position 0, head 2.0 classifies a sequence as balanced if it passes the count^{(} test, and classifies it as unbalanced if it fails the test. Head 2.1 does the same for the horizon test. As some evidence for this hypothesis, let’s look at an attribution experiment.

We run the model on points sampled from the random dataset, which may each pass or fail either or both of the tests. We can measure the predicted influence on the logits from the output of heads 2.0 and 2.1.^{[3]}

For each data point, we plot these two values on the x and y axes. If each head independently and perfectly performs its respective test, we should expect to see four clusters of data points:

- Those that pass both tests (i.e. are balanced) are in the top right: both heads classify them as balanced, so their x and y positions are positive.
- Unbalanced sequences, which fail both tests, are points in the bottom left.
- Sequences that pass only one of the tests should be in the top left or bottom right of the plot.

This is the actual result:

The result roughly matches what we expected, but not entirely.

The part that matches our expectations: the green (balanced) points are consistently classified as balanced by the two heads, and the orange (count^{(} failure only) points are consistently classified as balanced by 2.1 and unbalanced by 2.0.

However, the picture for the other clusters does not match our expectations; this shows that our hypothesis is flawed or incomplete. The pink points fail only the horizon test, and should be incorrectly classified as balanced by 2.0 and correctly classified as unbalanced by 2.1. In reality, 2.0 often ‘knows’ that these sequences are unbalanced, as evidenced by about half of these points having negative x values. It must therefore be doing something other than the count^{(} test, which these sequences pass. The purple points, which fail both the count^{(} and horizon tests, are sometimes incorrectly thought to be balanced by 2.1, so head 2.1 cannot be perfectly performing the horizon test. In Experiment 3, we’ll show that causal scrubbing can help us detect that this explanation is flawed, and then derive a more nuanced explanation of head 2.1’s behavior.

We use the causal scrubbing algorithm in our experiments. To understand this algorithm, we advise reading __the introduction post__. Other posts are not necessary to understand this post, as we’ll be talking through the particular application to our experiments in detail.^{[4]}

Following the causal scrubbing algorithm, we rewrite our model, which is a computational DAG, into a tree that does the same computation when provided with multiple copies of the input. We refer to the rewritten model as the *treeified model*. We perform this rewrite so we can provide separate inputs to different parts of the model–say, a reference input to the branch of the model we say is important, and a random input to the branch we say is unimportant. We’ll select sets of inputs randomly conditional on them representing the same information, according to our hypothesis (see the experiments for how we do this), run the treeified model on these inputs, and observe the loss. We call the treeified model with the separate inputs assigned according to the hypothesis the “scrubbed model”.

Before anything else, we record the loss of the model under two trivial hypotheses: “everything matters” and “nothing matters”. If a hypothesis we propose is perfect, we expect that the performance of the scrubbed model is equal to that of the baseline, unperturbed model. If the information the hypothesis specifies is unrelated to how the model works, we expect the model’s performance to go down to randomly guessing. Most hypotheses we consider are somewhere in the middle, and we express this as a *% loss recovered* between the two extremes. For more information on this metric, refer to the relevant section __here__.

We run a series of experiments to test different formalizations of (parts of) the informal hypothesis we have about this model.

We start with a basic claim about our model: that there are only three heads whose direct contribution to the classifier head is important: 1.0 and 2.0 which compute count^{(}, and 2.1 which computes the horizon test. We then improve this claim in two ways:

**Specificity:**Making a more*specific*claim about how one of these pathways computes the relevant test. That is, we claim a more narrow set of features of the input are important, and therefore*increase*the set of allowed interventions. This is necessary if we want to claim to understand the complete computation of the model, from inputs to outputs. However, it generally increases the loss of the scrubbed model if the additions to the hypothesis are imperfect.**Accuracy:***Improving*our hypothesis to more accurately match what the model computes. This often involves adjusting the features computed by our interpretation $I$. If done correctly this should decrease the loss of the scrubbed model.

A third way to iterate the hypothesis would be to make it more *comprehensive*, either by including paths through the model that were previously claimed to be unimportant or by being more restrictive in the swaps allowed for a particular intermediate. This should generally decrease the loss. We don’t highlight this type of improvement in the document, although it was a part of the research process as we discovered which pathways were necessary to include in our explanation.

Our experiments build upon one another in the following way:

The results of the experiments are summarized in this table, which may be helpful to refer back to as we discuss the experiments individually.

| # | Summary of claimed hypothesis | Loss ± Std. Error | % loss recovered | Accuracy |
|---|---|---|---|---|
| 0a | The normal, unscrubbed, model | 0.0003 | 100% | 100% |
| 0b | Randomized baseline | 4.30 ± 0.12 | 0% | 61% |
| 1a | 1.0 and 2.0 compute the count test, 2.1 computes the horizon test, they are ANDed | 0.52 ± 0.04 | 88% | 88% |
| 1b | 1a but using the count^{(} test | 0.30 ± 0.03 | 93% | 91% |
| 2a | More specific version of 1b, where we specify the inputs to 1.0 and 2.0 | 0.55 ± 0.04 | 88% | 87% |
| 2b | 2a but using the ɸ approximation for the output of 0.0 | 0.53 ± 0.04 | 88% | 87% |
| 3a | More specific version of 1b, where we break up the inputs to 2.1 by sequence position | 0.97 ± 0.06 | 77% | 85% |
| 3b | 3a but using p_{adj} | 0.68 ± 0.05 | 84% | 88% |
| 3c | 3a plus specifying the inputs to 2.1 at each sequence position | 0.69 ± 0.05 | 84% | 87% |
| 3d | 3a but sampling a1 at each sequence position randomly | 0.81 ± 0.05 | 81% | 87% |
| 4 | Including both 2b and 3b | 1.22 ± 0.07 | 72% | 82% |

(% loss recovered is defined to be 1 - (experiment loss - 0a loss) / 0b loss. This normalizes the loss to be between 0% and 100%, where higher numbers are better.)
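Concretely, the metric is a small helper reflecting the formula above (the default baseline numbers are the ones from rows 0a and 0b):

```python
def pct_loss_recovered(scrubbed_loss: float,
                       unscrubbed_loss: float = 0.0003,
                       random_loss: float = 4.30) -> float:
    """% loss recovered, normalizing the scrubbed loss between the
    unscrubbed model (100%) and the randomized baseline (0%)."""
    return 1 - (scrubbed_loss - unscrubbed_loss) / random_loss
```

For instance, experiment 1b's loss of 0.30 gives roughly 0.93, i.e. 93% loss recovered.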

All experiments are run on 2000 scrubbed inputs, sampled according to the algorithm from 100,000 sequences of parentheses.

Running the model itself results in a loss of 0.0003 (100% accuracy) on this dataset. If you shuffle the labels randomly, this results in a loss of 4.30 (61% accuracy – recall the dataset is mostly unbalanced).

These can both be formalized as trivial hypotheses, as depicted in the diagram below. We hypothesize an interpretation with a single node, which corresponds (via the correspondence) to the entire model. The computational graph of I, labeled with the model component it corresponds to, is shown below in black. The proposed feature computed by the node of I is annotated in red.

Note that in both cases we don’t split up our model into paths (aka ‘treeify’ it), meaning we will not perform any internal swaps.

In experiment 0a, we claim the output of the entire model encodes information about whether a given sequence is balanced. This means that we can swap the output of the model only if the label agrees: that is, the output on one balanced sequence for another balanced sequence. This will of course give the same loss as running the model on the dataset.

For 0b, we no longer claim *any* correspondence for this output. We thus swap the outputs randomly among the dataset. This is equivalent to shuffling the labels before evaluating the loss. We call such nodes (where any swap of their output is permitted) ‘unimportant’ and generally don’t include them in correspondence diagrams.

These experiments are useful baselines, and are used to calculate the % loss recovered metric.

We __claimed__ that the output of 1.0 and 2.0 each correspond to the count^{(} test, and the output of 2.1 corresponds to the horizon test. Let’s check this now. In fact, we will defend a slightly more specific claim: that the direct connection^{[5]} of these heads to the input of the final layer norm corresponds to the count^{(} test.

To start, we’ll first test a simple hypothesis: that 1.0 and 2.0 just implement the simple Equal Count Test (notably, not the count^{(} test) and 2.1 implements the horizon test, without checking whether the sequence starts with an open parenthesis.

We can draw this claimed hypothesis in the following diagram (for the remainder of this doc we won’t be drawing the inputs explicitly, to reduce clutter. Any node drawn as a leaf will have a single upstream input.):

How do we apply causal scrubbing to test this hypothesis? Let’s walk through applying it for a single data point (a batch size of 1). We apply the causal scrubbing algorithm to this hypothesis and our model. This will choose 5 different input data points from __the dataset described above__, which we will use to run the treeified model on as shown below:

- **x_{ref}**, the reference input. We compute the loss of the scrubbed model from the true label of x_{ref}. However, we will never run the scrubbed model on it; all inputs to the scrubbed model will be replaced with one of the other sampled inputs.
- Our hypothesis claims that, if we replace the output of 1.0 or of 2.0 with its output on some input x’ that agrees with x_{ref} on the count test, then the output will agree with x_{ref} on the balanced test. Therefore we **sample random x_{1.0} and x_{2.0}** which each agree with x_{ref} on the count test. (Note that this means x_{1.0} and x_{2.0} agree with *each other* on this test as well, despite being separate inputs.)
- Similarly, **x_{2.1}** is sampled randomly conditional on agreeing with x_{ref} on the horizon test.
- **x_{rand}** is a random dataset example. The subtree of the model which is rooted at the output and omits the branches included in the hypothesis–that is, the branches going directly to 1.0, 2.0, and 2.1–is run on this example.

We perform the above sample-and-run process many times to calculate the overall loss. We find that the scrubbed model recovers 88% of the original loss. The scrubbed model is very biased towards predicting unbalanced, with loss of 0.25 on unbalanced samples and 1.31 on balanced samples.
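The sample-and-run procedure above can be sketched as follows. This is a simplified illustration: `count_test` and `horizon_test` stand for whatever boolean feature functions the hypothesis specifies, and running the treeified model itself is elided:

```python
import random

def sample_inputs_1a(x_ref, dataset, count_test, horizon_test):
    """Sample replacement inputs for experiment 1a's scrubbed model."""
    same_count = [x for x in dataset if count_test(x) == count_test(x_ref)]
    same_horizon = [x for x in dataset if horizon_test(x) == horizon_test(x_ref)]
    return {
        "x_1.0": random.choice(same_count),    # branch feeding head 1.0
        "x_2.0": random.choice(same_count),    # branch feeding head 2.0
        "x_2.1": random.choice(same_horizon),  # branch feeding head 2.1
        "x_rand": random.choice(dataset),      # everything claimed unimportant
    }
```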

This, however, was still testing if 1.0 and 2.0 care about only the equal count test! As described above we believe it is more accurate to say that they check the count^{(} test, testing that the first parenthesis is open as well as performing the Equal Count Test.

Consider the set of inputs that pass the equal count test but fail the count^{(} test.^{[6]} Let us call these the fails-start-open set. If we return to the attribution results from the __informal hypothesis__ we can get intuition about the model’s behavior on these inputs:

The only difference from the left subfigure to the right subfigure is that the points in the fails-start-open set fail the more specific count^{(} test. We see that indeed, the output of head 2.0 is unbalanced on these inputs. Comparing the right and left diagrams we see the count^{(} test more cleanly predicts the output of head 2.0. The results for head 1.0 are similar.

Here is an updated hypothesis claim:

Compared to experiment 1a, the intervention is different in two ways:

- When the reference input is balanced, we may no longer sample x_{1.0} or x_{2.0} from the fails-start-open set. To the extent that our hypothesis is right and such inputs x_{1.0} or x_{2.0} cause 1.0 or 2.0 to output more “evidence of unbalance”, this change will improve our loss on balanced sequences. Eyeballing the plots above, we do expect this to happen.
- When the reference input is unbalanced, we may now sample x_{1.0} and x_{2.0} from the fails-start-open set. To the extent that our hypothesis is wrong and such inputs cause 1.0 or 2.0, respectively, to output *less* “evidence of unbalance”, this change will harm our loss on unbalanced sequences. Eyeballing the plots above, it is somewhat hard to tell whether we should expect this: these points do have some evidence of unbalance, but it is unclear how the magnitude compares to that of the fail-count set.

The scrubbed model recovers 93% of the loss. In particular the loss on balanced goes way down (1.31 -> 0.65) while the loss on unbalanced is slightly lower (0.25 -> 0.18). Thus this experiment supports our previous interpretability result that 1.0 and 2.0 detect whether the first parenthesis is open.

Comparing experiments 1a and 1b makes it clear the count^{(} test is an improvement. However, it is worth noticing that if one had only the results of experiment 1a, it would not be clear that such an improvement needed to be made.^{[7]} In general, causal scrubbing does not legibly punish you when the feature you claimed correspondence with is highly correlated with the ‘true feature’ that the component of the model is in fact picking up. We expect that all our claims, while highly correlated with the truth, will miss some nuance in the exact boundaries represented by the model.

To make the above hypothesis more specific, we’ll explain how 1.0 and 2.0 compute the count^{(} test: in particular, they use the output of 0.0 at position 1. To test this, we update the hypothesis from 1b to say that 2.0 and 1.0 only depend on whether the first parenthesis is open and whether 0.0 is run on an input that passes the count^{(} test. The other inputs to the subtrees rooted at 1.0 and 2.0 don’t matter. We aren’t stating *how* the output of 0.0 and the embedding reach those later heads; we’re considering the indirect effect, i.e. via all possible paths, rather than just the direct effect that passes through no other nodes in the model.

Our claimed hypothesis is shown below. Recall that we do not show unimportant nodes, we annotate the nodes of I with the nodes of G that they correspond to, and we annotate the edges with the feature that that node of I computes:^{[8]}

This hypothesis will result in the following treeified model:

How do we determine the 5 input data points annotated in blue? Following the causal scrubbing algorithm, we first fix x_{ref} at the output. We then recursively move through the hypothesis diagram, sampling a dataset example for *every* node such that it agrees with the downstream nodes on the labeled features. The non-leaf node data points can then be discarded, as they are not used as inputs in the scrubbed model. This is depicted below:

In particular, we will first choose a dataset sample x_{ref} whose label we will use to evaluate the loss of the scrubbed model. Then we will select the input datasets as follows:

- x_{out} agrees with x_{ref} on the balanced test.
- x_{2.1} agrees with x_{ref} on the never beneath horizon test (as before).
- Both x_{1.0} and x_{2.0} agree with x_{ref} on the count^{(} test.
- x_{0.0→1.0} agrees with x_{1.0} on the count^{(} test, and similarly for 2.0.
- x_{emb→1.0} agrees with x_{1.0} on whether the sequence starts with `(`, and similarly for 2.0.

Note that we do not require any other agreement between inputs! For example, x_{emb → 1.0} could be an input that fails the count test.
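The recursive sampling procedure can be sketched as follows. The `HypNode` representation is our own illustration: each hypothesis node carries the feature function on which its sample must agree with its parent's sample:

```python
import random
from dataclasses import dataclass, field
from typing import Callable, List

@dataclass
class HypNode:
    name: str
    feature: Callable              # feature this node's sample must agree on
    children: List["HypNode"] = field(default_factory=list)

def sample_leaf_inputs(node: HypNode, parent_x, dataset) -> dict:
    """Sample an input for each hypothesis node, agreeing with its parent's
    sample on the node's feature; keep only the leaf samples (non-leaf
    samples are discarded, as in the post)."""
    candidates = [x for x in dataset
                  if node.feature(x) == node.feature(parent_x)]
    x = random.choice(candidates)
    if not node.children:
        return {node.name: x}
    leaves = {}
    for child in node.children:
        leaves.update(sample_leaf_inputs(child, x, dataset))
    return leaves
```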

The 5 inputs in orange are claimed to be unimportant by our hypothesis. This means we will sample them randomly. We do, however, use the *same* random value for all unimportant inputs to a particular node in our model. For instance, there are many ‘unimportant’ inputs to the final layer norm: all three mlps, attention layer 0, and head 1.1. All of these are sampled together. Meanwhile, we sample these nodes *separately* from unimportant inputs to other nodes (e.g. the non-position 1 inputs to head 2.0); see the __appendix__ for some discussion of this.

The scrubbed model recovers 88% of the loss. Compared to experiment 1b, the loss recovered is significantly lower: a sign that we lost some explanatory power with this more specific hypothesis. By asserting these more specific claims, however, we still recover a large portion of the original loss. Overall, we think this result provides evidence that our hypothesis is a reasonable approximation for the 2.0 circuit. (To rule out the possibility that 0.0 is not important at all in these paths, we also ran an experiment replacing its output with that on a random input; this recovered 84% of the loss which is significantly less).

In fact, we believe that the proportion of open parentheses is encoded linearly in a particular direction. For more details and a precise definition of how we think this is done, see the __appendix__. The takeaway, however, is that we have a function ɸ which maps a value of p to a predicted activation of 0.0. We can thus rewrite the output of 0.0 in our model as the sum of two terms: ɸ(p) and the residual (the error of this estimate). We then claim that the ɸ(p) term is the important one. In essence, this allows swapping around the residuals between any two inputs, while ɸ(p) can only be swapped between inputs that agree on the count^{(} test. As a hypothesis diagram, this is:

Which leads to the following treeified model (again, with unimportant nodes in orange):

This results in a loss of 0.56, with accuracy 87%. This is basically unchanged from experiment 2a, giving evidence that we were correct in how p values are translated into the output of 0.0. Importantly, however, if we were somewhat wrong about *which* p values head 0.0 outputs on each input, this would have already hurt us in experiment 2a. Thus this result shouldn’t increase our confidence on that account.
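Mechanically, this rewrite is just an exact decomposition of the head's output. A sketch, where `phi` stands for the map from p-values to predicted activations described in the appendix:

```python
import numpy as np

def decompose_head_output(head_out: np.ndarray, p: float, phi):
    """Rewrite head 0.0's output as phi(p) plus a residual.

    Under the hypothesis, phi(p) may only be swapped between inputs that
    agree on the count^( test, while the residual may be swapped freely.
    """
    predicted = phi(p)
    residual = head_out - predicted
    return predicted, residual  # predicted + residual == head_out exactly
```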

For this experiment, we are not including the breakdown of 2.0 from experiment 2. We will add these back in for experiment 4, but it is simpler to consider them separately for now.

From __previous interpretability work__, we think that 0.0 computes the proportion of open parentheses in the suffix-substring starting at each query position. Then, mlp0 and mlp1 check the not-beneath-horizon test at that particular position. This means that 2.1 needs to ensure the check passes at every position (in practice, the attention pattern will focus on failed positions, which cause the head to output in an unbalanced direction).

We test this by sampling the input to head 2.1 at every sequence position separately (with some constraints, discussed below). This corresponds to the following hypothesis:

where x_{2}[i] denotes the input to attention 2 at position i, and n is the number of parentheses in the sequence. In particular, we fix n per example when we choose the dataset x_{2.1}. We additionally require that x_{2}[i] be at least i parentheses long for all i <= n, to avoid OOD edge cases that we didn’t wish to make claims about, e.g. samples including multiple `[END]` tokens (possibly a weaker constraint would be sufficient, but we have not experimented with that).

One other subtlety is what to do with the last sequence position, where the input is a special `[END]` token. We discovered that this position of the input to 2.1 carries some information about the last parenthesis.^{[9]} We allow interchanges between different `[END]` positions as long as they agree on the last parenthesis. This is equivalent to requiring agreement on both the horizon_{n} test and that the sequence is exactly length n.

The causal scrubbing algorithm is especially strict when testing this hypothesis. Since 2.1 checks for *any* failure, a failure at a single input sequence position should be enough to cause it to output unbalance. In fact our horizon_{i} condition is not quite true to what the model is doing, i.e. 2.1 is able to detect unbalanced sequences based on the input at position i even if the horizon_{i} test passes. Even if the horizon_{i} condition is most of what is going on, we are likely to sample at least one of these alternative failure detections because we sample up to 40 independent inputs, leading head 2.1 to output unbalance most of the time!
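To get intuition for how punishing this is: if each independently resampled position carries a spurious "unbalance" signal with even a small probability q, the chance that at least one of the n positions does grows quickly with n (illustrative numbers, not measurements of the model):

```python
def p_any_spurious_failure(q: float, n: int) -> float:
    """Probability that at least one of n independently resampled
    positions carries a spurious unbalance signal."""
    return 1 - (1 - q) ** n
```

For example, with q = 0.05 per position and n = 40 positions, this is about 0.87, so the scrubbed head would signal unbalance on most samples.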

The overall loss recovered from doing this scrubbing is 77%. The model is again highly skewed towards predicting unbalanced, with a loss of 3.61 on balanced labels.

We can improve this performance somewhat by shifting our notion of horizon_{i} to one closer to what the model computes. In particular, our current notion assumes the attention pattern of 0.0 is perfectly upper triangular (each query position pays attention evenly across all later key positions). Instead, it is somewhat more accurate to describe it as ‘quasi upper triangular’: it pays *more* attention to the upper triangular positions, but not exclusively. This relaxed assumption gives rise to a new “adjusted p” value that we can substitute for p in our interpretation; see the __appendix__. It turns out the new $I$ still correctly computes whether an input is balanced or not.

Using this new hypothesis improves our loss recovery to 84%, a notable increase from experiment 3a. Breaking up the loss by balanced and unbalanced reference sequences, we see that the loss decreased specifically on the balanced ones.

We additionally ran experiments where we split up the input x_{2}[i] into terms and specified how it was computed by the MLPs reading from a0 (similar to experiment 2). Counterintuitively, this *decreases* the loss.

In general, causal scrubbing samples inputs separately when permitted by the hypothesis. This, however, is a case where sampling *together* is worse for the scrubbed performance.

More detail on these experiments, and their implications, can be found in __the appendix__.

We can combine our hypotheses about what 2.0 is doing (experiment 2b) and what 2.1 is doing (experiment 3b) into a single hypothesis:

This results in 72% loss recovered. We note that the *loss* is roughly additive: the loss of the scrubbed model in this experiment is roughly the sum of the losses of the two previous experiments.

There are still many ways our hypothesis could be improved. One way would be to make it more comprehensive, by understanding and incorporating additional paths through the model. For example, we have some initial evidence that head 1.1 can recognize some horizon failures and copy this information to the residual stream at position 1, causing head 2.0 to output the sequence is unbalanced. This path is claimed to be unimportant in Experiment 2, which likely causes some of the loss increase (and corresponding decrease in % loss recovered).

Additionally, the hypothesis could be made more specific. For instance in the __appendix__ we make more specific claims about exactly how head 0.0 computes p; these claims would be possible to test with causal scrubbing, although we have not done so. Similarly, it would be possible to test very specific claims about how the count or horizon_{i} test is computed from head 0.0, even at the level of which neurons are involved. In particular, the current hypothesized explanation for 2.1’s input is especially vague; replicating the techniques from experiment 2 on these inputs would be a clear improvement.

Another direction in which we could expand on this work would be to prioritize the *accuracy* of our hypothesis more heavily, even if it comes at the cost of interpretability. In this project we have kept to a more abstract and interpretable understanding of the model’s computation. In particular for head 0.0 we have approximated its attention pattern, assuming it is (mostly) upper triangular. We could also imagine moving further in the direction we did with p_{adj} and estimating the attention probabilities 0.0 will have position by position. This would more accurately match the (imperfect) heuristics the model depends on, which could be useful for predicting adversarial examples for the model. For an example of incorporating heuristics into a causal scrubbing hypothesis, see __our results on induction in a language model__.

Overall, we were able to use causal scrubbing to get some evidence validating our original interpretability hypothesis, and recover the majority of the loss. We were also able to demonstrate that some very specific scrubs are feasible in practice, for instance rewriting the output of 0.0 at position 1 as the sum of ɸ(p) and a residual.

Using causal scrubbing led us to a better understanding of the model. Improving our score required refinements like using the adjusted open proportion or including that the end-token sequence position can carry evidence of unbalance to 2.1 in our hypothesis.

This work also highlighted some of the challenges of applying causal scrubbing. One recurring challenge was that scores are not obviously good or bad, only better or worse relative to others. For example, in our dataset there are many features that could be used to distinguish balanced and unbalanced sequences; this correlation made it hard to notice when we specified a subtly wrong feature of our dataset, as discussed when comparing experiments 1a and 1b, since the score was not obviously bad. This is not fundamentally a problem–we did in fact capture a lot of what our model was doing, and our score was reflective of that–but we found these small imprecisions in our understanding added up as we made our hypothesis more specific.

We also saw how, in some cases, our intuitions about how well we understood the model did not correspond to the loss recovered by our scrubbed model. Sometimes the scrubbed model’s loss was especially sensitive to certain parts (for instance, unbalanced evidence in the input to head 2.1 at a single sequence position) which can be punishing if the hypothesis isn’t perfectly accurate. Other times we would incorporate what we expected to be a noticeable improvement and find it made little difference to the overall loss.

Conversely, for experiments 3c and 3d (discussed in the __appendix__) we saw the scrubbed model’s loss decrease for what we ultimately believe are unjustified reasons, highlighting the need for something like adversarial validation in order to have confidence that a hypothesis is actually good.

In general, however, we think that these results provide some evidence that the causal scrubbing framework can be used to validate interpretability results produced by more ad-hoc methods. While formalizing the informal claims into testable hypotheses takes some work, the causal framework is remarkably expressive.

Additionally, even if the causal scrubbing method only validates claims, instead of producing them, we are excited about the role it will play in future efforts to explain model behaviors. Having a flexible but consistent language to express claims about what a model is doing has many advantages for easily communicating and checking many different variations of a hypothesis. We expect these advantages to only increase as we build better tools for easily expressing and iterating on hypotheses.

This model was trained with binary cross-entropy loss on a class-balanced dataset of 100k sequences of open and close parens, with labels indicating whether the sequence was balanced. The tokenizer has 5 tokens: `( ) [BEGIN] [END] [PAD]`. The token at position 0 is always `[BEGIN]`, followed by up to 40 open or close paren tokens, then `[END]`, then padding until length 42.
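The input format above can be sketched as follows (our own illustration, not the actual tokenizer code; the function name is ours):

```python
def tokenize(parens: str) -> list[str]:
    """Map a paren string to the model's fixed-length (42) token sequence."""
    tokens = ["[BEGIN]"] + list(parens) + ["[END]"]
    assert len(tokens) <= 42, "at most 40 paren tokens fit"
    return tokens + ["[PAD]"] * (42 - len(tokens))

assert tokenize("(())") == ["[BEGIN]", "(", "(", ")", ")", "[END]"] + ["[PAD]"] * 36
```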

The original training dataset was a mixture of several different datasets:

- Most of the training data (~92%) was randomly-generated sequences with an even number of parentheses. Balanced sequences were upsampled to be about 26% of the dataset.
- Special case datasets (the empty input, odd length inputs)
- Tricky sequences, which were adversarial examples for other models.

For the experiments in this writeup we will only use the first dataset of randomly generated inputs. We are attempting to explain the behavior of “how does this model get low cross-entropy loss on this randomly-generated dataset.” This may require a subtly different explanation than how it predicts more difficult examples.

We plan to release our code and will link it here when available. Note that the computation depends on our in-house tensor-computation library, so expect it to be time-consuming to understand the details of what is being computed. Feel free to get in contact if it is important for you to understand such things.

In our previous post, we __discussed__ reasons for sampling unimportant inputs to a node in our model (specifically, in ) *separately* from unimportant inputs to other nodes in our correspondence.

In this work, this was very important for reasoning about what correlations exist between different inputs and interpreting the results of our experiments. Consider the hypotheses __3c and 3d__. If we claim that the inputs to 2.1 at each position carry information about the horizon_{i} test, then each a1_{i} will be sampled separately. If we claimed instead that only the mlps and a0 had that job, and a1 was unimportant, we would still like each a1_{i} to be sampled separately! That way the two claims differ *only* in whether a1 is sampled conditional on the horizon_{i} test, and not in whether the a1_{i} are drawn from the same input.

In those experiments we discuss how the correlation between inputs hurt the loss of our scrubbed model. In fact, experimentally we found that if we ran 3d but sampled a1 across positions together, it hurt our scrubbed model’s loss. If we had run this experiment alone, without running 3d, the effects of “sampling a1[i] separately from the other terms at position i” and “sampling the a1[i] all together” would be confounded.

So, sampling unimportant inputs separately is especially important for comparing the swaps induced by hypotheses 3c and 3d cleanly. The choice makes minimal difference elsewhere.

The attention pattern of layer zero heads is a relatively simple function of the input. Recall that for every (query, key) pair of positions we compute an attention score. We then divide by a constant and take a softmax over the key axis, so that the attention paid by each query position sums to one. For layer 0 heads, each attention score is simply a function of four inputs: the query token, the query position, the key token, and the key position:

One pattern that is noticeable is that if the query token is an open parenthesis, the attention score does not change based on whether the key token is an open or close parenthesis. That is, the score for an open-paren key equals the score for a close-paren key, for all possible query and key positions.

This means that the attention pattern at an open-parenthesis query will only depend on the query position and the length of the entire sequence. The expected attention pattern (after softmax) for sequences of length 40 is displayed below:

And focusing on three representative query positions (each one a row in the above plot):

Some things to notice:

- To a first approximation, the attention is roughly upper triangular.
- The attention before the query position is non-zero, but mostly flat. This will motivate our definition of p_{adj} for experiment 3b.
- There are various imperfections. We expect these are some of the reasons our model has non-perfect performance.

As a simplifying assumption, let us assume that the attention pattern *is* perfectly upper triangular. That is, every query position pays uniform attention to itself and all later positions in the sequence. What then would the head output?

One way to compute the output of an attention head is to first multiply each input position by the V and O matrices, and then take a weighted average of these with weights given by the attention probabilities. It is thus useful to consider these per-position values before the weighted average is taken.

It turns out that this value depends strongly on whether k_{tok} is an open or close parenthesis, but doesn't depend on the position i. That is, we can define v_{open} and v_{close} to be the mean values across positions for open- and close-paren keys respectively. All the open-paren values point in the direction of v_{open} (minimum cosine similarity is 0.994), and all the close-paren values point in the direction of v_{close} (minimum cosine similarity is 0.995). v_{open} and v_{close}, however, point in opposite directions (cosine similarity -0.995).

We can combine what we have learned about the attention and the effect of the O and V matrices to give a reasonable understanding of the behavior of this head. Let us assume the attention pattern is perfectly upper triangular. Then at query position i, p_{i} of the attention will be on open-parenthesis positions, and (1-p_{i}) of the attention will be on close-parenthesis positions.^{[10]} Then the output of the head will be well approximated by p_{i} v_{open} + (1-p_{i}) v_{close}, where v_{open} and v_{close} are the mean values for open- and close-paren keys.

Since these terms point in nearly-opposing directions, we can well approximate the activation within a rank-one subspace, writing it as an affine function ɸ(p_{i}) along a single direction.

This shows how 0.0 computes p_{i} at open-parenthesis positions. We also directly test this ɸ function in experiment 2b.
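The algebra above can be sanity-checked numerically. The dimension and vectors below are made up (only the cosine similarities mimic the measurements); this is a sketch of why near-opposite mean values make the head's output approximately rank-one and affine in p_{i}:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 56  # hypothetical residual-stream width

u = rng.normal(size=d)
u /= np.linalg.norm(u)

def unit_orth(rng, u):
    """A unit vector orthogonal to u."""
    w = rng.normal(size=u.shape)
    w -= (w @ u) * u
    return w / np.linalg.norm(w)

# Hypothetical mean OV-values for open/close keys: nearly opposite
# directions, matching the ~0.995 cosine similarities reported above.
v_open = 1.0 * u + 0.10 * unit_orth(rng, u)    # cos(v_open, u)  ~ +0.995
v_close = -0.9 * u + 0.09 * unit_orth(rng, u)  # cos(v_close, u) ~ -0.995

for p in np.linspace(0.0, 1.0, 11):
    out = p * v_open + (1 - p) * v_close  # head output under the approximation
    # The component along u is affine in p: this is how 0.0 encodes p_i.
    assert np.isclose(out @ u, 1.9 * p - 0.9)
    # The part outside span(u) stays small, so a rank-one phi(p) suffices.
    assert np.linalg.norm(out - (out @ u) * u) <= 0.11
```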

In the previous appendix section we assumed the attention pattern of 0.0 is perfectly upper triangular. We did note, however, that 0.0 pays non-zero attention to positions before the query.

Fix some query position q in an input of length n. We can split the string into a prefix and a suffix, where the prefix is positions [1, q-1] and the suffix is positions [q, n]. If 0.0 had a perfectly upper triangular attention pattern, it would pay 1/len(suffix) attention to every key position in the suffix.

Instead, however, let us assume that it pays b_{q,n} attention to the prefix, leaving only (1-b_{q,n}) attention for the suffix. Then it pays b_{q,n}/len(prefix) attention to every position in the prefix, and (1-b_{q,n})/len(suffix) attention to every position in the suffix.

We calculate every b_{q,n} based on analysis of the attention pattern. Note these depend only on q and n, not on the tokens in the sequence. Two important facts are true about these values:

- b_{1,n} = 0. That is, at position 1 no attention is paid to the prefix, since no prefix exists.
- b_{q,n} < (q-1)/n. This implies that at every position, for every sequence length, more attention is paid to a given position in the suffix than in the prefix.

We then define p_{adj,q} based on this hypothesized attention. If p_{prefix} and p_{suffix} are the proportions of open parentheses in the respective substrings, then p_{adj,q} = b_{q,n} p_{prefix} + (1 - b_{q,n}) p_{suffix}.

The count test is unchanged, since fact 1 above implies p_{adj,1} = p_{1}. The never-beneath-horizon test is altered: we now test horizon_{adj,i}, which is defined to be true if p_{adj,i} ≤ 0.5. While this doesn't agree with the original test on all sequences, we will show it does agree for sequences that pass the count test. This is sufficient to show that our new interpretation always computes whether a given input is balanced (since the value of the horizon test is unimportant if the count test fails).

Thus, to complete the proof, we will fix some input that passes the count test and a query position q. We will show that the adjusted horizon test at q passes exactly if the normal horizon test at q passes.

We can express both p_{1} and p_{adj,q} as weighted averages of p_{prefix} and p_{suffix}. In particular, p_{1} = ((q-1)/n) p_{prefix} + ((n-q+1)/n) p_{suffix}, while p_{adj,q} = b_{q,n} p_{prefix} + (1 - b_{q,n}) p_{suffix}.

However, b_{q,n} < (q-1)/n, so p_{adj,q} puts more weight on the suffix than p_{1} does. Thus, p_{adj,q} > p_{1} exactly when p_{suffix} > p_{prefix}. Since the input passes the count test, p_{1} = 0.5, which implies only one of p_{suffix} and p_{prefix} can be greater than 0.5. Thus, a horizon failure at q ⇔ p_{suffix} > 0.5 ⇔ p_{suffix} > p_{prefix} ⇔ p_{adj,q} > 0.5 ⇔ an adjusted horizon failure at q.

This shows the horizon tests agree at every position of any input that passes the count test. This ensures they agree on whether any input is balanced, and our new causal graph is still perfectly accurate.
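The agreement argument above can also be checked numerically. This is a sketch with made-up b_{q,n} values respecting fact 2 (b_{q,n} < (q-1)/n); sequences are encoded as 1 = open, 0 = close:

```python
import random

def p_adj(seq, q, b):
    """Adjusted open-proportion at 1-indexed query position q, assuming
    attention mass b on the prefix [1, q-1] and 1-b on the suffix [q, n]."""
    prefix, suffix = seq[:q - 1], seq[q - 1:]
    return b * (sum(prefix) / len(prefix)) + (1 - b) * (sum(suffix) / len(suffix))

random.seed(0)
for _ in range(1000):
    n = 2 * random.randint(2, 20)
    seq = [1] * (n // 2) + [0] * (n // 2)  # passes the count test: p_1 = 0.5
    random.shuffle(seq)
    q = random.randint(2, n)
    b = 0.5 * (q - 1) / n  # any value below the (q-1)/n bound works
    suffix = seq[q - 1:]
    if 2 * sum(suffix) != len(suffix):  # skip the exactly-0.5 edge case
        # horizon failure (p_suffix > 0.5) iff adjusted failure (p_adj > 0.5)
        assert (sum(suffix) / len(suffix) > 0.5) == (p_adj(seq, q, b) > 0.5)
```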

For some evidence that the adjusted proportion more closely matches what 2.1 uses, we can return to __our measure of the logit difference to 2.1__. We might hope that the maximum value of p_{i} across the sequence has a clear correspondence with the output of 2.1. However, it turns out there are many sequences that end in an open parenthesis (and thus have p_{n}=1) for which 2.1 does not clearly output an unbalanced signal, as can be seen in the left subplot below:

In practice, these are often sequences with many more close parentheses than open parentheses. Thus, even at the last position, 0.0's attention will mostly be spread among positions with close parentheses. While this means 2.1 may not pick up on the failure, 2.0 will be able to detect these sequences as unbalanced.

This type of dynamic is captured in our definition for p_{adj}. We can see that the maximum adjusted proportion has a much clearer relationship with the attribution to head 2.1.

The plot above does not explain why our scrubbed model performs better when using p_{adj}; the lower loss comes from positions that do not attain the maximum p or p_{adj} for the sequence. In particular, the attribution plot has clearer separation of classes because we remove false negatives of the original horizon test at the sequence level (horizon_{all} fails but 2.1 does not say the input is unbalanced; these are removed because horizon_{adj} passes). The main reason the scrubbed loss improves, however, is that we remove false positives at the position level (horizon_{i} passes but 2.1 treats the input as a failure; these are removed because horizon_{adj,i} fails).

Examples where horizon_{i} passes but horizon_{adj,i} fails are ones where there is a horizon failure somewhere in the prefix. Thus, there aren't sequence-level false positives of the horizon test (when compared to the adjusted horizon test). In practice the shortcomings of the normal horizon test seem not to be a problem for experiments 1 and 2.^{[11]} It is notably worse for experiment 3, however, where sampling a single x_{2}[i] that has unbalanced evidence is enough to cause the model to flip from a confidently balanced to a confidently unbalanced prediction.

In order to make our hypothesis 3a more specific we can claim that the only relevant parts of x_{2}[i] are the terms from attention 0, mlp0, and mlp1. We sample each of these from a separate sequence, where all three agree on horizon_{i}. The rest of the sum (attention 1 and the embeddings) will thus be sampled on a random input.

Surprisingly, this causes the loss recovered by the scrubbed model to improve significantly when compared to experiment 3a, to 84%. Why is this? Shouldn’t claiming a more specific hypothesis result in lower loss recovered?

We saw in experiment 3a that certain inputs which pass the horizon test at i still carry unbalanced signal within x_{2}[i]. However, instead of sampling a single input for x_{2}[i] we are now sampling four different inputs: one each for a0, mlp0, and mlp1, which all agree on the horizon_{i} test, and a final random input for both the embedding and a1. Sampling the terms of x_{2}[i] separately is enough to 'wash out' this unbalanced signal in some cases.

In fact, it is sufficient to just sample x_{2}[i] as the sum of two terms. Consider the intermediate hypothesis that all that matters is the *sum* of the outputs of a0, mlp0, and mlp1:

By not including the emb + a1 term, this hypothesis implicitly causes them to be sampled from a separate random input.

The % loss recovered is 81%, between that of 3a and 3c. As a summary, the following table shows which terms are sampled together:

In red are the claimed-unimportant terms, sampled from a random input. All other inputs agree with horizon_{i}. Note also that in each case, all inputs are independently drawn between positions.

Are the results of experiment 3c and 3d legitimate then? We think not. One way to think about this is that a hypothesis makes claims about what scrubbing should be legal. For instance, the hypothesis in 3d claims that it would be okay to sample x_{2}[i] separately, term by term. However, the hypothesis also implies that it would be okay to sample them together!

One way to address this sort of problem is to introduce an __adversary__, who can request that terms are sampled together (if this is allowable by the proposed hypothesis). The adversary would then request that the hypotheses from 3c and 3d are run with every term of x_{2}[i] sampled together. This would result in the same experiment as we ran in 3a.

^{^}In most of the interpretability research we’re currently doing, we focus on tasks that a simple algorithm can solve to make it easier to reason about the computation implemented by the model. We wanted to isolate the task of finding the clearest and most complete explanation possible, and the task of validating it carefully. We believe this is a useful step towards understanding models that perform complex tasks; that said, interpretation of large language models involves additional challenges that will need to be surmounted.

^{^}We limit ourselves to just the random dataset mostly to make our lives easier. In general, it is also easier to explain a behavior that the model in fact has. Since the model struggles on the adversarial datasets, it would be significantly more difficult to explain ‘low loss on the full training distribution’ than ‘low loss on the random subset of the training distribution.’

It would be cleaner if the model was also exclusively trained on the random dataset. If redoing our investigation of this model, we would train it only on the random dataset.

^{^}In particular, we can write the output of the model as f(x_{2.0}[0] + y[0]), where f is the final layer norm and linear layer which outputs the log-probability that the input is balanced, x_{2.0}[0] is the output of head 2.0 at position 0, and y[0] is the sum of all other terms in the residual stream. Then we compute the attribution for 2.0 as f(x_{2.0}[0] + y'[0]), where y' is sampled by computing the sum of other terms on a random dataset sample. We do the same to get an attribution score for head 2.1. Other attribution methods such as linearizing layer norms give similar results.

^{^}The same algorithm is applied to a small language model on induction __here__, but keep in mind that some conventions in the notation are different. For example, in this post we more explicitly express our interpretation as a computational graph, while in that post we mostly hypothesize which paths in the treeified model are important; both are valid ways to express claims about what scrubs you ought to be able to perform without hurting the performance of the model too much. Additionally, since our hypothesis is that important inputs need not be equal, our treeified model is run on many more distinct inputs.

^{^}“Direct connection” meaning the path through the residual stream, not passing through any other attention heads or MLPs.

^{^}These will be inputs like `)(()` with equal amounts of open and close parentheses, but a close parenthesis first.

^{^}One technique that can be helpful to discover these sorts of problems is to perform a pair of experiments where the only difference is whether a particular component is scrubbed or not. This is a way to tell which scrubbed inputs were especially harmful – for instance, the fails-count-test inputs being used for 2.0 or 1.0 hurting the loss in 1a.

^{^}We exclude paths through attention layer 0 when creating the indirect emb node.

^{^}We originally theorized this by performing a “minimal-patching experiment” where we patched only a single sequence position at a time and looked for patterns in the set of input datapoints that caused the scrubbed model to get high loss. In general this can be a useful technique to understand the flaws of a proposed hypothesis. Adding this fact to our hypothesis decreased our loss by about 2 SE.

^{^}This does ignore attention on the [BEGIN] and [END] positions, but in practice this doesn’t change the output noticeably.

^{^}Using the adjusted horizon test in Experiment 1b slightly increases the loss to 0.33, not a significant difference. It is perhaps somewhat surprising the loss doesn’t decrease. In particular, we should see some improvement when we want to sample an output from 2.0 and 1.0 that passes the count test, but an output from 2.1 that fails the horizon test, as we no longer sample false negatives for 2.1 (where the output has no unbalanced evidence). This is a rare scenario, however: there aren’t many inputs in our dataset that are horizon failures but pass the count test. We hypothesize that this is why the improvement doesn’t appear in our overall loss.


*An appendix to **this post**.*

As mentioned above, our method allows us to explain quantitatively measured model behavior, operationalized as the expectation of a function f over an input distribution D.

Note that no part of our method distinguishes between the part of the input or computational graph that belongs to the “model” vs the “metric.”^{[1]}

It turns out that you can phrase a lot of mechanistic interpretability in this way. For example, here are some results obtained from attempting to explain how a model has low loss:

- Nanda and Lieberum’s __analysis of the structure of a model that does modular addition__ explains the observation that their model gets low loss on the validation dataset.
- The __indirect object identification circuit__ explains the observation that the model gets low loss on the indirect object identification task, as measured on a synthetic distribution.
- Induction circuits (as described in __Elhage et al. 2021__) explain the observation that the model gets low loss when predicting tokens that follow the heuristic: “if AB has occurred before, then A is likely to be followed by B”.

That being said, you can set up experiments using other metrics besides loss as well:

- Cammarata et al. identify __curve detectors in the Inception vision model__ by using the response of various filters on synthetic datasets to explain the correlation between: 1) the activation strength of some neuron, and 2) whether the orientation of an input curve is close to a reference angle.

If you’re trying to explain the expectation of f, we always consider it a valid move to suggest an alternative function f' if f'(x) = f(x) on every input (__“extensional equality”__), and then explain f' instead. In particular, we’ll often start with our model’s computational graph and a simple interpretation, and then perform “algebraic rewrites” on both graphs to naturally specify the correspondence.

Common rewrites include:

- When the output of a single component of the model is used in different ways by different paths, we’ll duplicate that node in G, such that each copy can correspond to a different part of I.
- When multiple components of the model compute a single feature we can either:
  - duplicate the node in I, to sample the components *separately*; or
  - combine the nodes of G into a single node, to sample the components *together.*
- Sometimes, we want to test claims of the form “this subspace of the activation contains the feature of interest”. We can express this by rewriting the output as a sum of the activation projected into the subspace and the orthogonal component. We can then propose that only the projected subspace encodes the feature.
- An even more complicated example is when we want to test a theorized function ɸ that maps from an input to a predicted activation of a component. We can then rewrite the output as the sum of two terms: ɸ(input) and the residual (the error of the estimate), and then claim only the ɸ term contains important information. If your estimate is bad, the error term will be large in important ways. This is especially useful to test hypotheses about scalar quantities (instead of categorical ones).^{[2]}
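The last rewrite can be sketched minimally as follows (the component and the estimate ɸ here are made up): adding the residual back makes the rewritten output extensionally equal to the original, so the rewrite is always valid, and the hypothesis then claims only the ɸ term matters.

```python
import numpy as np

def activation(x):
    """A made-up component output we want to explain."""
    return 2.0 * x + 0.01 * np.sin(x)

def phi(x):
    """A theorized estimate of that activation."""
    return 2.0 * x

def rewritten(x):
    estimate = phi(x)
    residual = activation(x) - estimate  # the error of the estimate
    return estimate + residual           # extensionally equal to activation

xs = np.linspace(-3.0, 3.0, 101)
assert np.allclose(rewritten(xs), activation(xs))
```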

Note that there are many trivial or unenlightening algebraic rewrites. For example, you could always replace f’ with a lookup table of f, and in cases where the model performs perfectly, you can also replace f with the constant zero function. Causal scrubbing is *not* intended to generate mechanistic interpretations or ensure that only mechanistic interpretations are allowed, but instead to check that a given interpretation is faithful. We discuss this more in the limitations section of the main post.

We allow hypotheses at a wide variety of levels of specificity. For example, here are two potential interpretations of the same computation:

These interpretations correspond to the same input-output mappings, but the hypothesis on the right is more specific, because it's saying that there are three separate nodes in the graph expressing this computation instead of one. So when we construct the correspondence we would need three different activations that we claim are important in different ways, instead of just one. In interpretability, we all-else-equal prefer more specific explanations, but defining that is out of scope here–we’re just trying to provide a way of looking at the predictions made by hypotheses, rather than expressing any a priori preference over them.

In both of these results posts, in order to measure the similarity between the scrubbed and unscrubbed models, we use *% loss recovered*.

As a baseline we use the ‘randomized loss’, defined as the loss when we shuffle the connection between the correct labels and the model’s output. Note this randomized loss will be higher than the loss for a calibrated guess with no information. We use randomized loss as the baseline since we are interested in explaining why the model makes the guesses it makes. If we had no idea, we could propose the trivial correspondence that the model’s inputs and outputs are unrelated, for which the scrubbed loss equals the randomized loss.

Thus we define: % loss recovered = (randomized loss − scrubbed loss) / (randomized loss − model loss).

This percentage can exceed 100% or be negative. It is not very meaningful as a fraction, and is rather an arithmetic aid for comparing the magnitude of expected losses under various distributions. However, it is the case that hypotheses with a “% loss recovered” closer to 100% result in predictions that are more consistent with the model.
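Concretely, % loss recovered compares the scrubbed loss against the randomized-loss baseline and the original model’s loss (the function and argument names here are ours):

```python
def pct_loss_recovered(scrubbed_loss: float, model_loss: float, randomized_loss: float) -> float:
    """100% iff the scrubbed model matches the original model's loss;
    0% iff it does no better than the randomized-labels baseline."""
    return 100 * (randomized_loss - scrubbed_loss) / (randomized_loss - model_loss)

# e.g. a scrubbed loss of 0.2, between a model loss of 0.1 and a
# randomized loss of 0.9, recovers 87.5% of the loss
assert abs(pct_loss_recovered(0.2, 0.1, 0.9) - 87.5) < 1e-9
```

Note that a scrubbed loss below the model’s loss gives a value above 100%, and a scrubbed loss above the randomized baseline gives a negative value, matching the caveat above.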

Above, we rate our hypotheses using the distance between the expectation of the loss under the dataset and under the scrubbed distribution.^{[3]}

You could instead rate hypotheses by comparing the full distributions of input-output behavior. That is, the difference between the distribution of the model’s output as a random variable under the dataset, and under the scrubbed distribution.

In this work, we prefer the expected loss. Suppose that one of the drivers of the model’s behavior is noise: trying to capture the full distribution would require us to explain what causes the noise. For example, you’d have to explain the behavior of a randomly initialized model despite the model doing ‘nothing interesting’.

Earlier, we noted our preference for “resampling ablation” of a component of a model (patch an activation of that component from a randomly selected input in the dataset) over zero or mean ablation of that component (set that component’s activation to 0 or its mean over the entire dataset, respectively) in order to test the claim “this component doesn’t matter for our explanation of the model”. We also mentioned three specific problems we see with using zero or mean ablation to test this claim. Here, we’ll discuss these problems in greater detail.

**1) Zero and mean ablations take your model off distribution in an unprincipled manner.**

The first problem we see with these ablations is that they destroy various properties of the distribution of activations in a way that seems unprincipled and could lead to the ablated model performing either worse or better than it should.

As an informal argument, imagine we have a module whose activations are in a two dimensional space. In the picture below we’ve drawn some of its activations as gray crosses, the mean as a green cross, and the zero as a red cross:

It seems to us that zero ablating takes your model out of distribution in an unprincipled way. (If the model was trained with dropout, it’s slightly more reasonable, but it’s rarely clear how a model actually handles dropout internally.) Mean ablating also takes the model out of distribution because the mean is not necessarily on the manifold of plausible activations.

**2) Zero and mean ablations can have unpredictable effects on measured performance.**

Another problem is that these ablations can have unpredictable effects on measured performance. For example, suppose that you’re looking at a regression model that happens to output larger answers when the activation from this module is at its mean activation (which, let’s suppose, is off-distribution and therefore unconstrained by SGD). Also, suppose you’re looking at it on a data distribution where this module is in fact unimportant. If you’re analyzing model performance on a data subdistribution where the model generally guesses too high, then mean ablation will make it look like ablating this module harms performance. If the model generally guesses too low on the subdistribution, mean ablation will improve performance. Both of these failure modes are avoided by using random patches, as resampling ablation does, instead of mean ablation.

**3) Zero and mean ablations remove variation that your model might depend on for performance.**

The final problem we see with these ablations is that they neglect the variation in the outputs of the module. Removing this variation doesn’t seem reasonable when claiming that the module doesn't matter.

For an illustrative toy example, suppose we’re trying to explain the performance of a model with three modules M1, M2, and M3. This model has been trained with dropout and usually only depends on components M1 and M2 to compute its output, but if dropout is active and knocks out M2, the model uses M3 instead and can perform almost as well as if it were able to use M1 and M2.

If we zero/mean ablate M2 (assume mean 0), it will look like M2 wasn't doing anything at all and our hypothesis that it wasn't relevant will be seemingly vindicated. If instead we resample ablate M2, the model will perform significantly worse (exactly how much worse is dependent on exactly how the output of M2 is relevant to the final output).

This example, while somewhat unrealistic, hopefully conveys our concern here: sometimes the variation in the outputs of a component is important to your model and performing mean or zero ablation forces this component to only act as a fixed bias term, which is unlikely to be representative of its true contribution to the model’s outputs.

We think these examples provide sufficient reasons to be skeptical about the validity of zero or mean ablation and demonstrate our rationale for preferring resampling ablation.
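The third concern can be made concrete with a toy sketch (our own construction, not from the post): a downstream computation that depends only on the variance of a module’s output, whose mean is zero.

```python
import random

random.seed(0)
acts = [random.choice([-1.0, 1.0]) for _ in range(1000)]  # module outputs; mean is 0

def downstream(a):
    return a * a  # depends on the variation of a, not on its mean

originals = [downstream(a) for a in acts]
mean_ablated = [downstream(0.0) for _ in acts]  # 0.0 is both the mean and zero ablation
resampled = [downstream(random.choice(acts)) for _ in acts]  # resampling ablation

assert all(o == 1.0 for o in originals)
assert all(o == 0.0 for o in mean_ablated)  # behavior destroyed by mean/zero ablation
assert all(o == 1.0 for o in resampled)     # behavior preserved by resampling
```

Under mean or zero ablation the module looks irrelevant even though the downstream computation depends entirely on its variation; resampling keeps the activations on-distribution and preserves the behavior.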

Suppose we have the following hypothesis where I maps to the nodes of G in blue:

There are four activations in G that we claim are unimportant.

Causal scrubbing requires performing a resampling ablation on these activations. When doing so, should we pick one data point to get all four activations on? Two different data points, one for R and S (which both feed into V) and a different one for X and Y? Or four different data points?

In our opinion, all are reasonable experiments that correspond to subtly different hypotheses. This may not be something you considered when proposing your informal hypothesis, but following the causal scrubbing algorithm forces you to resolve this ambiguity. In particular, the more we sample unimportant activations independently, the more specific the hypothesis becomes, because it allows you to make strictly more swaps. It also sometimes makes it easier for the experimenter to reason about the correlations between different inputs. For a concrete example where this matters, see __the paren balance checker experiment__.

And so, in the pseudocode above we sample the pairs (R, S) and (X, Y) separately, although we allow hypotheses that require all unimportant inputs throughout the model to be sampled together.^{[4]}

Why not go more extreme, and sample every single unimportant node separately? One reason is that it is not well-defined: we can always rewrite our model to an equivalent one consisting of a different set of nodes, and this would lead to completely different sampling! Another is that we don’t actually intend this: we do believe it’s important that the inputs to our treeified model be “somewhat reasonable”, i.e. have some of the correlations that they usually do in the training distribution, though we’re not sure exactly which ones matter. So if we started from saying that all nodes are sampled separately, we’d immediately want to hypothesize something about them needing to be sampled together in order for our scrubbed model to not get very high loss. Thus this default makes it simpler to specify hypotheses.
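The three sampling granularities discussed above can be sketched on the R, S, X, Y example (toy data and names are ours):

```python
import random

dataset = [{"R": i, "S": i, "X": i, "Y": i} for i in range(100)]  # toy data

def sample_unimportant(policy, rng):
    """Pick source data for the four claimed-unimportant activations."""
    if policy == "all_together":
        d = rng.choice(dataset)
        return {"R": d["R"], "S": d["S"], "X": d["X"], "Y": d["Y"]}
    if policy == "per_branch":  # the pseudocode's default: (R, S) and (X, Y)
        d1, d2 = rng.choice(dataset), rng.choice(dataset)
        return {"R": d1["R"], "S": d1["S"], "X": d2["X"], "Y": d2["Y"]}
    if policy == "fully_separate":
        return {k: rng.choice(dataset)[k] for k in ("R", "S", "X", "Y")}

rng = random.Random(0)
pick = sample_unimportant("per_branch", rng)
# within-branch correlations are preserved; cross-branch ones are broken
assert pick["R"] == pick["S"] and pick["X"] == pick["Y"]
```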

In general we don’t require hypotheses to be surjective, meaning not all nodes of G need to be mapped onto by the correspondence, nor do we require that the image of I contains all edges of G. This is convenient for expressing claims that some nodes (or edges) of G are unimportant for the behavior. It leaves a degree of freedom, however, in how to treat these unimportant nodes, as discussed in the preceding section.

It is possible to remove this ambiguity by requiring that the correspondence be an isomorphism between I and G. In this section we’ll demonstrate how to do this in a way that is consistent with the pseudocode presented, by combining all the unimportant parents of each important node.

In the example below, both R and S are unimportant inputs to the node V, and both X and Y are unimportant inputs to the node Z. We make the following rewrites in the example below:

- If a single important node has multiple unimportant inputs, we combine them. This forms the new node (X, Y) in G_{2}. We also combine all upstream nodes, such that there is a single path from the input to this new combined node, forming (T, U) which (X, Y) depends on. This ensures we’ll only sample one input for all of them in the treeified model.
- We do the same for (R, S) into node V.
- Then we extend I with new nodes to match the entirety of the rewritten graph G_{2}. For all of these new nodes that correspond to unimportant nodes (or nodes upstream of unimportant nodes), our interpretation says that all inputs map to a single value (the __unit type__). This ensures that we can sample any input.
- While we also draw the edges to match the structure of the rewritten graph, we will not have other nodes in I be sensitive to the values of these unit nodes.

If you want to take a different approach to sampling the unimportant inputs, you can rewrite the graphs in a different way (for instance, keeping X and Y as separate nodes).

One general lesson from this is that rewriting the computational graphs G and I is extremely expressive. In practice, we have found that with some care it allows us to run the experiments we intuitively wanted to.

Suppose we have a function f to which we want to apply the causal scrubbing algorithm. Consider an isomorphic (see above) __treeified hypothesis__ h = (G, I, c) for f. In this appendix we will show that causal scrubbing preserves the joint distribution of inputs to each node of I (Lemma 1). Then we show that the distribution of *inputs* induced by causal scrubbing is the maximum entropy distribution satisfying this constraint (Theorem 2).

Let X be the domain of f and D be the input distribution for f (a distribution on X). Let D' be the distribution given by the causal scrubbing algorithm (so the domain of D' is X^n, where n is the number of times that the input is repeated in the treeified graph).

We find it useful to define two sets of random variables: one set for the values of wires (i.e. edges) in I when I is run on a consistent input drawn from D (i.e. on (x, ..., x) for some x ~ D); and one set for the values of wires in I induced by the causal scrubbing algorithm:

**Definition** (D-consistent random variables): For each edge of I, we call the “D-consistent random variable” of that edge the result of evaluating the interpretation I on (x, ..., x), for a random input x ~ D. For each node n of I, we will speak of the joint distribution of its input wires, and call the resulting random variable the “D-consistent inputs (to n)”. We also refer to the value of the wire going out of n as the “D-consistent output (of n)”.

**Definition** (scrubbed random variables): Suppose that we run I on an input drawn from D'. In the same way, this defines a set of random variables, which we call the *scrubbed* random variables (and use the terms "scrubbed inputs" and "scrubbed output" accordingly).

**Lemma 1:** For every node n of I, the joint distribution of scrubbed inputs to n is equal to the joint distribution of D-consistent inputs to n.

**Proof:** Recall that the causal scrubbing algorithm assigns a datum in X to every node of I, starting from the root and moving up. The key observation is that for every node n of I, the distribution of the datum of n is exactly D. We can see this by induction. Clearly this is true for the root. Now, consider an arbitrary non-root node n and assume that this claim is true for the parent m of n. Consider the equivalence classes on X defined as follows: x and x' are equivalent if n takes the same value when I is run on each input. Then the datum of n is chosen by sampling from D subject to being in the same equivalence class as the datum of m. Since (by assumption) the datum of m is distributed according to D, so is the datum of n.

Now, by the definition of the causal scrubbing algorithm, for every node n, the scrubbed inputs to n are equal to the inputs to n when I is run on the datum of n. Since the datum of n is distributed according to D, it follows that the joint distribution of scrubbed inputs to n is equal to the joint distribution of D-consistent inputs to n.

**Theorem 2:** The joint distribution of (top-level) scrubbed inputs is the maximum-entropy distribution on X^n, subject to the constraints imposed by Lemma 1.

**Proof:** We proceed by induction on a stronger statement: consider any way to "cut" through I in a way that separates all of the inputs to I from the root (and does so minimally, i.e. if any edge is un-cut then there is a path from some leaf to the root). (See below for an example.) Then the joint scrubbed distribution of the cut wires has maximal entropy subject to the constraints imposed by Lemma 1 on the joint distribution of scrubbed inputs to all nodes lying on the root's side of the cut.

Our base case is the cut through the input wires to the root (in which case Theorem 2 is vacuously true). Our inductive step will take any cut and move it up through some node n, so that if previously the cut passed through the output of n, it will now pass through the inputs of n. We will show that if the original cut satisfies our claim, then so will the new one.

Consider any cut and let n be the node through which we will move the cut up. Let A denote the vector of inputs to n, B be the output of n (so B = n(A)), and C denote the values along all cut wires besides B. Note that A and C are independent conditional on B; this follows by conditional independence rules on Bayesian networks (A and C are d-separated by B).

Next, we show that this distribution is the maximum-entropy distribution. The following equality holds for *any* random variables A, B, C such that B is a function of A:

H(A, C) = H(A, B, C)
        = H(B, C) + H(A | B, C)
        = H(B, C) + H(A | B) − I(A; C | B)
        = H(B, C) + H(A) − H(B) − I(A; C | B)

Where I(·;·) is mutual information. The first step follows from the fact that B is a function of A. The second step follows from the chain rule H(X, Y) = H(Y) + H(X | Y). The third step follows from the identity H(A | B, C) = H(A | B) − I(A; C | B). The last step follows from the fact that H(A | B) = H(A) − H(B), again because B is a function of A.
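This chain of identities can be checked numerically on a small discrete example (a sketch of ours; entropies in nats, computed from an explicit joint table, with B a deterministic function of A):

```python
import numpy as np

def entropy(p):
    """Shannon entropy (nats) of a probability vector, ignoring zeros."""
    p = p[p > 0]
    return -np.sum(p * np.log(p))

# Arbitrary joint distribution over (A, C), with A in {0,1,2} and C in {0,1}.
rng = np.random.default_rng(0)
p_ac = rng.random((3, 2))
p_ac /= p_ac.sum()

f = np.array([0, 1, 1])  # B = f(A): B is a function of A

# Build the joint over (A, B, C); indices are [a, b, c].
p_abc = np.zeros((3, 2, 2))
for a in range(3):
    p_abc[a, f[a], :] = p_ac[a, :]

H_AC = entropy(p_ac.ravel())
H_BC = entropy(p_abc.sum(axis=0).ravel())   # marginal over (B, C)
H_A = entropy(p_ac.sum(axis=1))
H_B = entropy(p_abc.sum(axis=(0, 2)))
H_AB = entropy(p_abc.sum(axis=2).ravel())
H_ABC = entropy(p_abc.ravel())

# I(A; C | B) = H(A,B) + H(B,C) - H(B) - H(A,B,C)
I_AC_given_B = H_AB + H_BC - H_B - H_ABC

lhs = H_AC
rhs = H_BC + H_A - H_B - I_AC_given_B
assert abs(lhs - rhs) < 1e-9
```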

Now, consider all possible distributions of (A, C) subject to the constraints imposed by Lemma 1 on the joint distribution of scrubbed inputs to all nodes lying on the root's side of the updated cut. The lemma specifies the distribution of A and (therefore) B. Thus, subject to these constraints, H(A, C) is equal to H(B, C) − I(A; C | B) plus H(A) − H(B), which is a constant. By the inductive hypothesis, H(B, C) is as large as possible subject to the lemma's constraints. Mutual information is non-negative, so it follows that if I(A; C | B) = 0, then H(A, C) is as large as possible subject to the aforementioned constraints. Since A and C are independent conditional on B, this is indeed the case.

This concludes the induction. So far we have only proven that the joint distribution of scrubbed inputs is *some* maximum-entropy distribution subject to the lemma's constraints. Is this distribution unique? Assuming that the space of possible inputs is finite (which it is if we're doing things on computers), the answer is yes: entropy is a strictly concave function and the constraints imposed by the lemma on the distribution of scrubbed inputs are convex (linear, in particular). A strictly concave function has a unique maximum on a convex set. This concludes the proof.

**Fun Fact 3:** The entropy of the joint distribution of scrubbed inputs is equal to the entropy of the output of I, plus the sum over all nodes n of the information lost by n (i.e. the entropy of the joint input to n minus the entropy of the output of n). (By Lemma 1, this number does not depend on whether we imagine n being fed D-consistent inputs or scrubbed inputs.) By direct consequence of the proof of Theorem 2, we have H(A, C) = H(B, C) + H(A) − H(B) (with A, B, C as in the proof of Theorem 2, using I(A; C | B) = 0). Proceeding by the same induction as in Theorem 2 yields this fact.

In __our polysemanticity toy model paper__, we introduced an analytically tractable setting where the optimal model represents features in superposition. In this section, we’ll analyze this model using causal scrubbing, as an example of what it looks like to handle polysemantic activations.

The simplest form of this model is the two-variable, one-neuron case, where we have independent variables x1 and x2 which both have zero expectation and unit variance, and we are choosing the parameters c and d to minimize loss in the following setting:

Where y_tilde is our model, c and d are the parameters we’re optimizing, and a and b are part of the task definition. As discussed in our toy model paper, in some cases (when you have some combination of a and b having similar values, and x1 and x2 having high kurtosis (e.g. because they are usually equal to zero)), c and d will both be set to nonzero values, and so the neuron can be thought of as a superposed representation of both x1 and x2.

To explain the performance of this model with causal scrubbing, we take advantage of function extensionality and expand y_tilde:

And then we explain it with the following hypothesis:

When we sample outputs using our algorithm here, we’re going to sample the interference term from random other examples. And so the scrubbed model will have roughly the same estimated loss as the original model–the errors due to interference will no longer appear on the examples that actually suffer from interference, but the average effect of interference will be approximately reproduced.

In general, this is our strategy for explaining polysemantic models: we do an algebraic rewrite on the model so that the model now has monosemantic components and an error term, and then we say that the monosemantic components explain why the model is able to do the computation that it does, and we say that we don’t have any explanation for the error term.

This works as long as the error is actually unstructured–if the model was actively compensating for the interference errors (as in, doing something in a way that correlates with the interference errors to reduce their cost), we’d need to describe that in the explanation in order to capture the true loss.
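To illustrate this numerically (a toy construction of ours, not the exact setting of the toy paper): if the interference term is unstructured, resampling it across examples preserves the expected squared error; but if the model is partially cancelling its own interference, scrubbing the pieces separately overstates the noise:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1_000_000
feature = rng.normal(size=n)             # the label, perfectly recovered by the signal path
interference = 0.3 * rng.normal(size=n)  # unstructured error term

# Unstructured case: prediction = feature + interference.
orig = np.mean((feature + interference - feature) ** 2)
scrubbed = np.mean((feature + rng.permutation(interference) - feature) ** 2)
# orig ≈ scrubbed ≈ 0.09: resampling the unexplained term preserves the loss.

# Compensating case: the model cancels 90% of its own interference.
pred = feature + interference - 0.9 * interference
orig_comp = np.mean((pred - feature) ** 2)  # ≈ (0.1 * 0.3)**2 ≈ 0.0009
scrubbed_comp = np.mean(
    (feature + rng.permutation(interference) - 0.9 * interference - feature) ** 2
)
# ≈ 0.09 * (1 + 0.81) ≈ 0.163: scrubbing the two pieces separately
# misses the cancellation, so the scrubbed loss is far from the true loss.
```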

This strategy also works if you have more neurons and more variables–we’ll again write our model as a sum of many monosemantic components and a residual. And it’s also what we’d do with real models–we take our MLP or other nonlinear components and make many copies of the set of neurons that are required for computing a particular feature.

This strategy means that we generally have to consider an explanation that’s as large as the model would be if we expanded it to be monosemantic. But it’s hard to see how we could have possibly avoided this.

Note that this isn’t a solution to *finding* a monosemantic basis - we’re just claiming that if you had a hypothesized monosemantic reformulation of the model you could test it with causal scrubbing.

This might feel vacuous–what did we achieve by rewriting our model as if it was monosemantic and then adding an error term? We claim that this is actually what we wanted. The hypothesis explained the loss because the model actually was representing the two input variables in a superposed fashion and resigning itself to the random error due to interference. The success of this hypothesis reassures us that the model isn’t doing anything more complicated than that. For example, if the model was taking advantage of some relationship between these features that we don’t understand, then this hypothesis would not replicate the loss of the model.

Now, suppose we rewrite the model from the form we used above:

To the following form:

Where we’ve split the noise term into two pieces. If we sample these two parts of the noise term independently, we will have effectively reduced the magnitude of the noise, for the usual reason that averages of two samples from a random variable have lower variance than single samples. And so if we ignore this correlation, we’ll estimate the cost of the noise to be lower than it is for the real model. This is another mechanism by which ignoring a correlation can cause the model to seem to perform better than the real model does; as before, this error gives us the opportunity to neglect some positive contribution to performance elsewhere in the model.
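The variance-shrinking effect is easy to see numerically (a minimal sketch; we model resampling as a random permutation of the noise values):

```python
import numpy as np

rng = np.random.default_rng(0)
eps = rng.normal(size=1_000_000)  # the original noise term, variance ≈ 1

# Scrubbing the whole term: one resample, full variance.
whole = rng.permutation(eps)

# Rewriting the term as eps/2 + eps/2 and scrubbing each half
# independently: two independent resamples, averaged.
halves = 0.5 * rng.permutation(eps) + 0.5 * rng.permutation(eps)

# whole.var() ≈ 1.0, but halves.var() ≈ 0.5 — the scrubbed model
# sees less noise than the real model does.
```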

We can construct cases where the explanation can make the model look better by sneaking in information. For example, consider the following setting:

The model’s input is a tuple of a natural number and the current game setting, which is either EASY or HARD (with equal frequency). The model outputs the answer either “0”, “1”, or “I don’t know”. The task is to guess the last bit of the hash of the number.

Here’s the reward function for this task:

| Game mode | Score if model is correct | Score if model is incorrect | Score if model says “I don’t know” |
|---|---|---|---|
| EASY | 2 | -1 | 0 |
| HARD | 10 | -20 | 0 |

If the model has no idea how to hash numbers, its optimal strategy is to guess when in EASY mode and say “I don’t know” in HARD mode.
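We can confirm the claimed optimal strategy with a quick expected-value computation over the reward table (a sketch; the helper name is ours):

```python
# Expected score per game mode, assuming a 50/50 guess at the hash bit.
def expected_score(p_correct, win, lose):
    return p_correct * win + (1 - p_correct) * lose

easy_guess = expected_score(0.5, 2, -1)    # 0.5: guessing pays on EASY
hard_guess = expected_score(0.5, 10, -20)  # -5.0: guessing loses on HARD
idk = 0.0                                  # "I don't know" always scores 0

assert easy_guess > idk > hard_guess  # guess on EASY, abstain on HARD
```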

Now, suppose we propose the hypothesis that claims that the model outputs:

- on an EASY mode input, what the model would guess; and
- on a HARD mode input, the correct answer.

To apply causal scrubbing, we consider the computational graph of both the model and the hypothesis to consist of the input nodes and a single output node. In this limited setting, the projected model runs the following algorithm:

- Replace the input with a random input that would give the same answer according to the hypothesis; and
- Output what the model outputs on that random input.

Now consider running the projected model on a HARD case. According to the hypothesis, we output the correct answer, so we replace the input

- half the time with another HARD mode input (with the same answer), on which the model outputs “I don’t know”; and
- half the time with an EASY mode input chosen such that the model will guess the correct answer.

So, when you do causal scrubbing on HARD cases, the projected model will now guess correctly half the time, because half its “I don’t know” answers will be transformed into the correct answer. The projected model’s performance will be worse on the EASY cases, but the HARD cases mattered much more, so the projected model’s performance will be much better than the original model’s performance, even though the explanation is wrong!
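The size of the inflation is easy to compute under the 50/50 replacement just described (a sketch of the expected HARD-mode score):

```python
# Original model on HARD inputs: it always answers "I don't know".
original_hard_score = 0.0

# Scrubbed model on HARD inputs: half the time the input is replaced by
# another HARD input ("I don't know", score 0); half the time by an EASY
# input on which the model guesses the correct answer (score +10).
scrubbed_hard_score = 0.5 * 0 + 0.5 * 10

assert scrubbed_hard_score > original_hard_score  # the false hypothesis looks better
```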

In examples like this one, hypotheses can cheat and get great scores while being very false.

(Credit for the ideas in this section is largely due to ARC.)

We might have hoped that we’d be able to use causal scrubbing as a check on our hypotheses analogous to using a proof checker like Lean or Coq to check our mathematical proofs, but this doesn’t work. Our guess is that it’s probably impossible to have an efficient algorithm for checking interpretability explanations which always rejects false explanations. This is mostly because we suspect that interpretability explanations should be regarded as an example of __defeasible reasoning__. Checking interpretations in a way that rejects all false explanations is probably NP-hard, and so we want to choose a notion of checking which is weaker.

We aren’t going to be able to check hypotheses by treating as uncorrelated everything that the hypotheses claimed wasn’t relevantly correlated. This would have worked if ignoring correlations could only harm the model. But as shown above, we have several cases where ignoring correlations helps the model.

So we can’t produce true explanations by finding hypotheses subject to the constraint that they predict the observed metrics. As an alternative proposal, we can check if hypotheses are comprehensive by seeing if any adversarial additions to the hypothesis would cause the predicted metric to change considerably. In all of the counterexamples above, the problem is that the metric was being overestimated because there were important correlations that were being neglected and which would reduce the estimated metric if they were included. If we explicitly check for additional details to add to our hypotheses which cause the estimated metric to change, all the counterexamples listed above are solved.

To set up this adversarial validation scheme, we need some mechanism for combining the original hypothesis with adversarially constructed additions. One way of thinking about this is that we want a function `join` which is a binary operation on hypotheses, taking two hypotheses to the hypothesis which preserves all structure in the model that either of the two hypotheses preserved.

Here are two ways of defining this operation:

**Swap-centric.** You can think of a hypothesis as a predicate on activation swaps (of the same activation on two different inputs). From this perspective, you can define join(h1, h2) to be the hypothesis which permits a swap iff h1 and h2 both permit it.

**Computation graph centric.** You can equivalently construct the joined hypothesis by the following process. First, ensure that each of the correspondences is a bijection, and that both rewritten model graphs have the same shape, adding extra no-op nodes as necessary. Now we can define the interpretation graph of the joined hypothesis to be the graph where every node contains the tuple of the values from the two earlier interpretations.
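Under the swap-centric view, `join` is a one-line predicate intersection (a sketch, with hypotheses modeled as boolean predicates on proposed swaps; the toy swap encoding is ours):

```python
def join(h1, h2):
    """The joined hypothesis permits an activation swap iff both permit it."""
    return lambda swap: h1(swap) and h2(swap)

# Toy usage: a swap is (node, input_a, input_b); each hypothesis only
# permits swaps that preserve the feature it claims the node represents.
h1 = lambda swap: (swap[1][0] > 3) == (swap[2][0] > 3)  # same "first coord > 3" feature
h2 = lambda swap: swap[1][-1] == swap[2][-1]            # same label

joined = join(h1, h2)
assert joined(("A", (5, 6, 7, True), (8, 0, 4, True)))
assert not joined(("A", (5, 6, 7, True), (8, 0, 4, False)))
```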

The main failure of the algorithm listed above is that we don’t know how to handle cases where the adversary wants to rewrite f to an extensionally-equal function in a way which is mutually incompatible with the original hypothesis (for example, because their computational graphs have different shapes and there’s no way to splice the two computational graphs together). This is a pretty bad problem because the function extensionality move seems very important in practice. ARC has worked on basically this problem for a while and hasn’t yet solved it, regrettably.

Some other questions that we haven’t answered:

- How do we incentivize specific explanations? We don’t know (but haven’t thought about it that much). Our current proposals look something like having a budget for how much hypotheses can reduce entropy.
- The explanations produced by this process will probably by default be impossible for humans to understand; is there some way to fix this? We also don’t have good ideas here. (Note that this isn’t a failure that’s specific to causal scrubbing; it seems fundamentally challenging to generate human-understandable interpretations for complicated superhuman models.) That being said, a lot of our optimism about interpretability comes from applications where the interpretability tools are used by AIs or by human-coded algorithms, rather than by humans, so plausibly we’re fine even if humans can’t understand the interpretability results.

Overall, it seems plausible that these problems can be overcome, but they are definitely not currently solved. We hold out hope for an interpretability process which has validity properties which allow us to use powerful optimization inside it and still trust the conclusions, and hope to see future work in this direction.

^{^}This is also true when you’re training models with an autodiff library–you construct a computational graph that computes loss, and run backprop on the whole thing, which quickly recurses into the model but doesn’t inherently treat it differently.

^{^}This allows for testing out human-interpretable approximations to neural network components: ‘__Artificial Artificial Neural networks__’. We think it’s more informative to see how the model performs with the residual of this approximation resampling-ablated as opposed to zero-ablated.

^{^}In general, you could have the output be non-scalar with any distance metric to evaluate the deviation of the scrubbed expectation, but we’ll keep things simple here.

^{^}Another way of thinking about this is: when we consider the adversarial game setting, we would like each side to be able to request that terms are sampled together. By default therefore we would like terms (even random ones!) to be sampled separately.


Summary: This post introduces causal scrubbing, a principled approach for evaluating the quality of mechanistic interpretations. The key idea behind causal scrubbing is to test interpretability hypotheses via *behavior-preserving resampling ablations*. We apply this method to develop a refined understanding of how a small language model implements induction and how an algorithmic model correctly classifies if a sequence of parentheses is balanced.

A question that all mechanistic interpretability work must answer is, “how well does this interpretation explain the phenomenon being studied?”. In the __many__ __recent__ __papers__ __in mechanistic interpretability__, researchers have generally relied on ad-hoc methods to evaluate the quality of interpretations.^{[1]}

This *ad hoc* nature of existing evaluation methods poses a serious challenge for scaling up mechanistic interpretability. Currently, to evaluate the quality of a particular research result, we need to deeply understand both the interpretation and the phenomenon being explained, and then apply researcher judgment. Ideally, we’d like to find the interpretability equivalent of __property-based testing__—automatically checking the correctness of interpretations, instead of relying on grit and researcher judgment. More systematic procedures would also help us scale up interpretability efforts to larger models, to behaviors with subtler effects, and to larger teams of researchers. To help with these efforts, we want a procedure that is both powerful enough to finely distinguish better interpretations from worse ones, and general enough to be applied to complex interpretations.

In this work, we propose **causal scrubbing**, a systematic ablation method for testing precisely stated hypotheses about how a particular neural network^{[2]} implements a behavior on a dataset. Specifically, given an informal hypothesis about which parts of a model implement the intermediate calculations required for a behavior, we convert this to a formal correspondence between a computational graph for the model and a human-interpretable computational graph. Then, causal scrubbing starts from the output and recursively finds all of the invariances of parts of the neural network that are implied by the hypothesis, and then replaces the activations of the neural network with the *maximum entropy*^{[3]} distribution subject to certain natural constraints implied by the hypothesis and the data distribution. We then measure how well the scrubbed model implements the specific behavior.^{[4]} Insofar as the hypothesis explains the behavior on the dataset, the model’s performance should be unchanged.

Unlike previous approaches that were specific to particular applications, causal scrubbing aims to work on a large class of interpretability hypotheses, including almost all hypotheses interpretability researchers propose in practice (that we’re aware of). Because the tests proposed by causal scrubbing are mechanically derived from the proposed hypothesis, causal scrubbing can be incorporated “in the inner loop” of interpretability research. For example, starting from a hypothesis that makes very broad claims about how the model works and thus is consistent with the model’s behavior on the data, we can iteratively make hypotheses that make more specific claims while monitoring how well the new hypotheses explain model behavior. We demonstrate two applications of this approach in later posts: first on a parenthesis balance checker, then on the induction heads in a two-layer attention-only language model.

We see our contributions as the following:

- We formalize a notion of interpretability hypotheses that can represent a large, natural class of mechanistic interpretations;
- We propose an algorithm, *causal scrubbing*, that tests hypotheses by systematically replacing activations in all ways that the hypothesis implies should not affect performance; and
- We demonstrate the practical value of this approach by using it to investigate two interpretability hypotheses for small transformers trained in different domains.

This is the main post in a four post sequence, and covers the most important content:

- What is causal scrubbing? Why do we think it’s more principled than other methods? (sections 2-4)
- A summary of our results from applying causal scrubbing (section 5)
- Discussion: Applications, Limitations, Future work (sections 6 and 7).

In addition, there are three posts with information of less general interest. __The first__ is a series of appendices to the content of this post. Then, a pair of posts covers the details of what we discovered applying causal scrubbing to __a paren-balance checker__ and __induction in a small language model__.^{[5]} They are collected in a sequence here.

**Ablations for Model Interpretability:** One commonly used technique in mechanistic interpretability is the “ablate, then measure” approach. Specifically, for interpretations that aim to explain why the model achieves low loss, it’s standard to remove parts that the interpretation identifies as important and check that model performance suffers, or to remove unimportant parts and check that model performance is unaffected. For example, in __Nanda and Lieberum’s Grokking__ work, to verify the claim that the model uses certain key frequencies to compute the correct answer to modular addition questions, the authors confirm that zero ablating the key frequencies greatly increases loss, while zero ablating random other frequencies has no effect on loss. In __Anthropic’s Induction Head paper__, they remove the induction heads and observe that this reduces the ability of models to perform in-context learning. In the __IOI mechanistic interpretability project,__ the authors define the behavior of a transformer subcircuit by mean-ablating everything except the nodes from the circuit. This is used to formulate criteria for validating that the proposed circuit preserves the behavior they investigate and includes all the redundant nodes performing a similar role.

Causal scrubbing can be thought of as a generalized form of the “ablate, then measure” methodology.^{[6]} However, unlike the standard zero and mean ablations, we ablate modules by resampling activations from *other* inputs (which we’ll justify in the next post). In this work, we also apply causal scrubbing to more precisely measure different mechanisms of induction head behavior than in the Anthropic paper.

**Causal Tracing:** Like causal tracing, causal scrubbing identifies computations by patching activations. However, causal tracing aims to *identify* a specific path (“trace”) that contributes causally to a particular behavior by corrupting all nodes in the neural network with noise and then iteratively denoising nodes. In contrast, causal scrubbing tries to solve a different problem: systematically *testing* hypotheses about the behavior of a whole network by removing (“scrubbing away”) every causal relationship that should not matter according to the hypothesis being evaluated. In addition, causal tracing patches with (homoscedastic) Gaussian noise and not with the activations of other samples. Not only does this take your model off distribution, it might have no effect in cases where the scale of the activation is much larger than the scale of the noise.

**Heuristic explanations:** This work takes a perspective on interpretability that is strongly influenced by __ARC__’s __work on “heuristic explanations” of model behavior__. In particular, causal scrubbing can be thought of as a form of __defeasible reasoning__: unlike mathematical proofs (where if you have a proof for a proposition P, you’ll never see a better proof for the negation of P that causes you to overall believe P is false), we expect that in the context of interpretability, we need to accept arguments that might be overturned by future arguments.

We assume a dataset D over a domain X and a function f which captures a behavior of interest. We will then explain the expectation of this function on our dataset, E_{x ~ D}[f(x)].

This allows us to explain behaviors of the form “a particular model M gets low loss on a distribution D.” To represent this we include the labels y in D and both the model and a loss function in f, i.e. f(x, y) = Loss(M(x), y):

We also want to explain behaviors such as “if the prompt contains some bigram `AB` and ends with the token `A`, then the model is likely to predict `B` follows next.” We can do this by choosing a dataset where each datum has the prompt `...AB...A` and expected completion `B`. For instance:

We then propose a hypothesis about how this behavior is implemented. Formally, a *hypothesis* h for f is a tuple (G, I, c) of three things:

- A computational graph G,^{[7]} which implements the function f.
  - We require G to be __extensionally equal__ to f (equal on *all* of X).
- A computational graph I, intuitively an ‘interpretation’ of the model.
- A correspondence function c from the nodes of I to the nodes of G.
  - We require c to be an injective __graph homomorphism__: that is, if there is an edge (n, m) in I then the edge (c(n), c(m)) must exist in G.

We additionally require G and I to each have a single input and output node, where c maps input to input and output to output. All input nodes are of type X, which allows us to evaluate both G and I on all of D.

Here is an example hypothesis:

In this figure, we hypothesize that f works by having A compute whether the first component of the input is greater than 3, B compute whether the second component is greater than 3, and then ORing those values. Then we’re asserting that the behavior is explained by the relationship between D and the true label y.

A couple of important things to notice:

- We will often rewrite the computational graph of the original model implementation into a more convenient form (for instance splitting up a sum into terms, or grouping together several computations into one).
- You can think of I as a heuristic^{[8]} that the hypothesis claims the model uses to achieve the behavior. It’s possible that the heuristic is imperfect and will sometimes disagree with the label y. In that case our hypothesis would claim that the model should be incorrect on these inputs.
- Note that the mapping c doesn’t tell you how to translate a value of I into an activation, only which nodes correspond.
- We will call the nodes in the image of c the “important nodes” of G.^{[9]}
  - Let n_I, n_G be nodes in I and G respectively such that c(n_I) = n_G.
  - Intuitively this is a claim that when we evaluate both G and I on the same input, then the value of n_G (usually an activation of the model) ‘represents’ the value of n_I (usually a simple feature of the input).
  - The causal scrubbing algorithm will test a weaker claim: that the equivalence classes on inputs induced by n_G are the same as the equivalence classes induced by n_I. We think this is sufficient to meaningfully test the mechanistic interpretability hypotheses we are interested in, although it is not strong enough to eliminate all incorrect hypotheses.
- Among other things, the hypothesis claims that nodes of G that are not mapped to by c are unimportant for the behavior under investigation.^{[10]}

Hypotheses are covered in more detail in the appendix.
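The structural requirements on a hypothesis can be checked mechanically. Here is a minimal sketch (our own toy encoding, not Redwood's library), using small dict-based graphs in which node C is left out of the correspondence and thus claimed unimportant:

```python
# A computational graph as adjacency: node -> list of parent (input) nodes.
G = {"x": [], "A": ["x"], "B": ["x"], "C": ["x"], "D": ["A", "B", "C"], "out": ["D"]}
# An interpretation graph (C has no counterpart: it is claimed unimportant).
I = {"x'": [], "A'": ["x'"], "B'": ["x'"], "D'": ["A'", "B'"], "out'": ["D'"]}
# The correspondence c, mapping nodes of I to nodes of G.
c = {"x'": "x", "A'": "A", "B'": "B", "D'": "D", "out'": "out"}

def is_valid_correspondence(I, G, c):
    """Check that c is an injective graph homomorphism from I into G."""
    if len(set(c.values())) != len(c):
        return False  # not injective
    for child, parents in I.items():
        for parent in parents:
            # every edge of I must map onto an edge of G
            if c[parent] not in G[c[child]]:
                return False
    return True

assert is_valid_correspondence(I, G, c)
```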

In this section we provide two different explanations of causal scrubbing:

- __An informal description__ of the activation replacements that a hypothesis implies are valid. We try to provide a helpful introduction to the core idea of causal scrubbing via many diagrams; and
- __The causal scrubbing algorithm__ and pseudocode.

Different readers of this document have found different explanations to be helpful, so we encourage you to skip around or skim some sections.

Our goal will be to define a metric E_scrub[f] by recursively sampling activations that should be equivalent according to each node of the interpretation I. We then compare this value to E_{x ~ D}[f(x)]. If a hypothesis is (reasonably) accurate, then the activation replacements we perform should not alter the loss and so we’d have E_scrub[f] ≈ E_{x ~ D}[f(x)]. Overall, we think that this difference will be a reasonable proxy for the *faithfulness* of the hypothesis—that is, how accurately the hypothesis corresponds to the “real reasons” behind the model behavior.

Consider a hypothesis on the graphs below, where c maps nodes of I to the corresponding nodes of G highlighted in green:

This hypothesis claims that the activations A and B respectively represent checking whether the first and second components of the input are greater than 3. Then the activation D represents checking whether either of these conditions was true. Both the third component of the input and the activation of C are unimportant (at least for the behavior we are explaining, the log loss with respect to the label y).

If this hypothesis is true, we should be able to perform two types of ‘resampling ablations’:

- replacing the activations of A, B, and D with the activations on other inputs that are “equivalent” under the interpretation I; and
- replacing the activations that are claimed to be unimportant for a particular path (such as C, or the unimportant inputs into B) with their activation on any other input.

To illustrate these interventions, we will depict a “treeified” version of G where every path from the input to the output of G is represented by a different copy of the input. Replacing an activation with one from a different input is equivalent to replacing all inputs in the subtree upstream of that activation.

Consider running the model on two inputs x = (5, 6, 7, True) and x′ = (8, 0, 4, True). The value of A′ is the same on both x and x′. Thus, if the hypothesis depicted above is correct, the output of A on both of these is equivalent. This means when evaluating on x we can replace the activation of A with its value on x′, as depicted here:

To perform the replacement, we replaced all of the inputs upstream of A in our treeified model. (We could have performed this replacement with any other input that agrees with x on A’.)

Our hypothesis permits many other activation replacements. For example, we can perform this replacement for D instead:

The other class of intervention permitted by h is replacement of any inputs to nodes in G that I suggests aren’t semantically important. For example, I says that the only important input for A is the first component of the input. So the model’s behavior should be preserved if we replace the activations for the second and third components (or, equivalently, change the input that feeds into these activations). The same applies for the first and third components into B. Additionally, I says that D isn’t influenced by C, so arbitrarily resampling all the inputs to C shouldn’t impact the model’s behavior.

Pictorially, this looks like this:

Notice that we are making 3 different replacements with 3 different inputs simultaneously. Still, if h is accurate, we will have preserved the important information and the output of G should be similar.

The causal scrubbing algorithm involves performing both of these types of intervention many times. In fact, we want to maximize the number of such interventions we perform on every run of G – to the extent permitted by h.

We define an algorithm for evaluating hypotheses. This algorithm uses the intuition, illustrated in the previous section, of what activation replacements are permitted by a hypothesis.

The core idea is that hypotheses can be interpreted as an “intervention blacklist”. We like to think of this as the hypothesis sticking its neck out and challenging us to swap around activations in any way that it hasn’t specifically ruled out.

In a single sentence, the algorithm is: Whenever we need to compute an activation, we ask “What are all the other activations that, according to h, we could replace this activation with and still preserve the model’s behavior?”, then make the replacement by choosing uniformly at random from that subset of the dataset, and do this recursively.

In this algorithm we don’t explicitly treeify G; but we traverse it one path at a time in a tree-like fashion.

We define the *scrubbed expectation*, E_scrubbed(h, D), as the expectation of the behavior over samples from this algorithm.

*(This is mostly redundant with the pseudocode below. Read in your preferred order.)*

The algorithm is defined in pseudocode below. Intuitively we:

- Sample a random reference input x from D.
- Traverse all paths through I from the output towards the input by calling `run_scrub` on nodes of I recursively. For every node n_I we consider the subgraph of I that contains everything ‘upstream’ of n_I (used to calculate its value from the input). Each of these corresponds to a subgraph of the image of c in G.
- The return value of `run_scrub(c, D, n_I, x)` is an activation from G. Specifically, it is an activation for the corresponding node in G that the **hypothesis claims represents the value of n_I** when I is run on input `x`:
  - Let n_G = c(n_I).
  - If n_G is an input node, we return `x`.
  - Otherwise, we determine the activations of each input from the parents of n_G. For each parent parent_G of n_G:
    - If there exists a parent parent_I of n_I that corresponds to parent_G, then the hypothesis claims that the value of parent_G is important for n_G. In particular, it is important because it represents the value defined by parent_I. Thus we sample a datum `new_x` that agrees with `x` on the value of parent_I. We **recursively call** `run_scrub` on parent_I in order to get an activation for parent_G.
    - For any “unimportant parent” not mapped by the correspondence, we select an input `other_x`. This is a random input from the dataset; however, we enforce that the *same* random input is used by all unimportant parents of a particular node.^{[12]} We record the value of parent_G on `other_x`.
  - We now have the activations of all the parents of n_G – these are exactly the inputs to the function defined for the node n_G. We return the output of this function.

```
def estim(h, D):
    """Estimate E_scrubbed(h, D)."""
    _G, I, c = h
    outs = []
    for _ in range(NUM_SAMPLES):
        x = random.choice(D)
        outs.append(run_scrub(c, D, output_node_of(I), x))
    return mean(outs)

def run_scrub(
    c,             # correspondence I -> G
    D: Set[Datum],
    n_I,           # node of I
    ref_x: Datum,
):
    """Returns an activation of n_G which h claims represents n_I(ref_x)."""
    n_G = c(n_I)
    if n_G is an input node:
        return ref_x
    inputs_G = {}
    # pick a random datum to use for all "unimportant parents" of this node
    random_x = random.choice(D)
    # get the scrubbed activations of the inputs to n_G
    for parent_G in n_G.parents():
        # "important" parents
        if parent_G in map(c, n_I.parents()):
            parent_I = c.inverse(parent_G)
            # sample a new datum that agrees on the interpretation node
            new_x = sample_agreeing_x(D, parent_I, ref_x)
            # and get its scrubbed activations recursively
            inputs_G[parent_G] = run_scrub(c, D, parent_I, new_x)
        # "unimportant" parents
        else:
            # get the activations on the random input value chosen above
            inputs_G[parent_G] = parent_G.value_on(random_x)
    # now run n_G given the computed input activations
    return n_G.value_from_inputs(inputs_G)

def sample_agreeing_x(D, n_I, ref_x):
    """Returns a random element of D that agrees with ref_x on the value of n_I."""
    D_agree = [x for x in D if n_I.value_on(ref_x) == n_I.value_on(x)]
    return random.choice(D_agree)
```
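To make this concrete, here is a self-contained, runnable miniature of the pseudocode, specialized to the toy A/B/C/D example from earlier. The node names, dataset, and dictionary-based graph encoding are all illustrative choices, not from the post (labels are dropped for brevity):

```python
import random

# Model graph G: node -> (parents, fn of parent activations); "input" returns x.
G = {
    "input": ([], None),
    "A": (["input"], lambda x: x[0] > 3),
    "B": (["input"], lambda x: x[1] > 3),
    "C": (["input"], lambda x: x[2] % 2),
    "D": (["A", "B", "C"], lambda a, b, c: a or b),
}

# Interpretation I: node -> (parents, the feature it claims to represent, as a
# function of the raw input). Two inputs "agree" on a node iff the feature
# values are exactly equal.
I = {
    "input'": ([], lambda x: x),
    "A'": (["input'"], lambda x: x[0] > 3),
    "B'": (["input'"], lambda x: x[1] > 3),
    "D'": (["A'", "B'"], lambda x: x[0] > 3 or x[1] > 3),
}
corr = {"input'": "input", "A'": "A", "B'": "B", "D'": "D"}  # c: I -> G

def value_on(node, x):
    """Ordinary (unscrubbed) activation of a G-node on input x."""
    parents, f = G[node]
    return x if not parents else f(*(value_on(p, x) for p in parents))

def run_scrub(n_I, ref_x, D):
    """Activation of corr[n_I] that the hypothesis claims represents n_I(ref_x)."""
    n_G = corr[n_I]
    parents_G, f = G[n_G]
    if not parents_G:                            # input node: return ref_x
        return ref_x
    important = {corr[p]: p for p in I[n_I][0]}  # parent_G -> parent_I
    random_x = random.choice(D)                  # shared by all unimportant parents
    acts = []
    for p_G in parents_G:
        if p_G in important:                     # important: resample an agreeing input
            p_I = important[p_G]
            feat = I[p_I][1]
            new_x = random.choice([x for x in D if feat(x) == feat(ref_x)])
            acts.append(run_scrub(p_I, new_x, D))
        else:                                    # unimportant: random input's activation
            acts.append(value_on(p_G, random_x))
    return f(*acts)

D = [(5, 6, 7), (8, 0, 4), (1, 2, 3), (0, 9, 2), (4, 4, 0)]
```

Because this toy hypothesis is exactly correct, `run_scrub("D'", x, D)` reproduces `value_on("D", x)` for every `x` in `D`, no matter which random swaps are chosen; a hypothesis that wrongly marked B as unimportant would fail this check on some inputs.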

Suppose a hypothesis claims that some module in the model isn’t important for a given behavior. There are a variety of different interventions that people do to test this. For example:

- Zero ablation: setting the activations of that module to 0
- Mean ablation: replacing the activations of that module with their empirical mean on D
- Resampling ablation: patching in the activation of that module on a random different input
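The three interventions differ only in what value is patched in. A toy sketch with made-up scalar activations (the data here is illustrative, not from any real model):

```python
import random
from statistics import mean

random.seed(0)

# Toy stand-in: one module's scalar activation cached on each input in D.
acts = [random.gauss(0.0, 1.0) for _ in range(100)]

zero_patch = 0.0                        # zero ablation
mean_patch = mean(acts)                 # mean ablation
resample_patch = random.choice(acts)    # resampling ablation: another input's activation

# Only the resampled patch is a value the module actually produces on some
# input; the zero and mean patches are generally off-distribution.
```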

In order to decide between these, we should think about the precise claim we’re trying to test by ablating the module.

If the claim is “this module’s activations are literally unused”, then we could try replacing them with huge numbers or even NaN. But in actual cases, this would destroy the model behavior, and so this isn’t the claim we’re trying to test.

We think a better type of claim is: “The behavior might depend on various properties of the activations of this module, but those activations aren’t encoding any information that’s relevant to this subtask.” Phrased differently: The distribution of activations of this module is (maybe) important for the behavior. But we don’t depend on any properties of this distribution that are conditional on *which* particular input the model receives.

This is why, in our opinion, the most direct way to translate this hypothesis into an intervention experiment is to patch in the module’s activation on a randomly sampled different input–this distribution will have all the properties that the module’s activations usually have, but any connection between those properties and the correct prediction will have been scrubbed away.

Despite their prevalence in prior work, zero and mean ablations do not faithfully translate the claims we’d like to make.

As noted above, the claim we’re trying to evaluate is that the information in the output of this component doesn’t matter for our current model, not the claim that deleting the component would have no effect on behavior. We care about evaluating the claim as faithfully as possible on our current model and not replacing it with a slightly different model, which zero or mean ablation of a component does. This core problem can manifest in three ways:

- *Zero and mean ablations take your model off distribution in an unprincipled manner.*
- *Zero and mean ablations can have unpredictable effects on measured performance.*
- *Zero and mean ablations remove variation and thus present an inaccurate view of what’s happening.*

For more detail on these specific issues, we refer readers to the __appendix post.__

To show the value of this approach, we apply the causal scrubbing algorithm to two tasks: 1) verifying hypotheses about an algorithmic model we found previously through ad-hoc interpretability, and 2) testing and incrementally improving hypotheses about how induction heads work in a 2-layer attention-only model. Here we summarize the results of those applications to illustrate causal scrubbing in action; detailed results can be found in the respective auxiliary posts.

We apply the causal scrubbing algorithm to a small transformer which classifies sequences of parentheses as balanced or unbalanced; see the __results post__ for more information. In particular, we test three claims about the mechanisms this model uses.

**Claim 1: **There are three heads that directly pass important information to output:^{[13]}

- Heads 1.0 and 2.0 test the conjunction of two checks: that there are an equal number of open and close parentheses in the entire sequence, and that the sequence starts open.
- Head 2.1 checks that the nesting depth is never negative at any point in the sequence.

Claim 1 is represented by the following hypothesis:^{[14]}

**Claim 2: **Heads 1.0 and 2.0 depend only on their input at position 1, and this input indirectly depends on:

- The output of 0.0 at position 1, which computes the overall proportion of parentheses which are open. This is written into a particular direction of the residual stream in a linear fashion.
- The embedding at position 1, which indicates if the sequence starts with `(`.

**Claim 3: **Head 2.1 depends on the input at all positions, specifically on whether the nesting depth (when reading right to left!) is negative at each position.^{[15]}
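Taken together, these checks suffice: a sequence is balanced iff the counts of open and close parentheses are equal and the nesting depth never goes negative. A quick reference sketch (reading left to right, which is equivalent to the model’s mirrored right-to-left check):

```python
def is_balanced(seq):
    """Balanced iff: equal number of '(' and ')' AND the running nesting
    depth never goes negative. (Given equal counts, depth staying
    non-negative also implies the sequence starts with '('.)"""
    depth = 0
    for ch in seq:
        depth += 1 if ch == "(" else -1
        if depth < 0:
            return False
    return depth == 0
```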

Here is a visual representation of the combination of all three claims:

Testing these claims with causal scrubbing, we find that they are reasonably, but not completely, accurate:

| Claim(s) tested | Performance recovered^{[16]} |
| --- | --- |
| 1 | 93% |
| 1 + 2 | 88% |
| 1 + 3 | 84% |
| 1 + 2 + 3 | 72% |

As expected, performance drops as we are more specific about how exactly the high level features are computed. This is because as the hypotheses get more specific, they induce more activation replacements, often stacked several layers deep.^{[17]}
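Using the normalization constants given in the footnotes (normal-model loss 0.0003, random-label loss 4.3), “performance recovered” can be computed as follows; the example scrubbed loss is illustrative, not a measured value:

```python
def performance_recovered(scrubbed_loss, model_loss=0.0003, random_loss=4.3):
    """Fraction of the loss gap (random baseline -> normal model) that the
    scrubbed model preserves: 1.0 if scrubbing doesn't hurt at all,
    0.0 if it degrades the model to the random-label baseline."""
    return (random_loss - scrubbed_loss) / (random_loss - model_loss)

# An illustrative scrubbed loss of ~0.30 would correspond to ~93% recovered.
```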

This indicates our hypothesis is subtly incorrect in several ways, either by missing pathways along which information travels or imperfectly identifying the features that the model uses in practice.

We explain these results in more detail in __this appendix post__.

We investigated ‘induction’ heads in a 2 layer attention only model. We were able to easily test out and incrementally improve hypotheses about which computations in the model were important for the behavior of the heads.

We first tested a naive induction hypothesis, which separates out the input to an induction head in layer 1 into three separate paths – the value, the key, and the query – and specified where the important information in each path comes from. We hypothesized that both the values and queries are formed based on only the input directly from the token embeddings via the residual stream and have no dependence on attention layer 0. The keys, however, are produced only by the input from attention layer 0; in particular, they depend on the part of the output of attention layer 0 that corresponds to attention on the previous token position.^{[18]}
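As a cartoon of what the naive hypothesis claims at the token level (a hypothetical sketch of the [A][B] … [A] → [B] behavior, not the model’s actual computation): the head attends from the current token back to positions whose *previous* token matches it, and copies the token found there.

```python
def naive_induction_prediction(tokens):
    """Cartoon of [A][B] ... [A] -> [B] induction on a list of token ids.
    Scans backwards for a position whose *previous* token matches the
    current token, and predicts the token at that position.
    Returns None if no match is found."""
    cur = tokens[-1]  # the query: the current token
    for pos in range(len(tokens) - 2, 0, -1):
        if tokens[pos - 1] == cur:   # key built from the previous token
            return tokens[pos]       # value: copy what followed last time
    return None
```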

We test these hypotheses on a subset of openwebtext where induction is likely (but not guaranteed) to be helpful.^{[19]} Evaluated on this dataset, this naive hypothesis only recovers 35% of the performance. In order to improve this we made various edits which allow the information to flow through additional pathways:

- First, we allow the attention pattern of the induction head to compare a set of three consecutive tokens (instead of just a single token) to determine when to induct.
- Next, we also allow the query and value to depend on the part of the output of layer 0 that corresponds to the current position.
- We also special-case three layer 0 heads which attend to repeated occurrences of the current token. In particular, we assume that the important part of the output of these heads is what their output would be *if* their attention were just an identity matrix.^{[20]}

With these adjustments, our hypothesis recovers 86% of the performance.

We believe it would have been significantly harder to develop and have confidence in a hypothesis this precise only using ad-hoc methods to verify the correctness of a hypothesis.

We explain these results in more detail in __this appendix post__.

The most obvious application of causal scrubbing to alignment is using it to evaluate mechanistic interpretations. In particular, we can imagine several specific use cases that are relevant to alignment:

- *Checking interpretations of model behaviors produced by human researchers.* Having a standardized, reliable, and convenient set of tests would make it much easier to scale up mechanistic interpretability efforts; this might be particularly important if there are big interpretability projects right before the deployment of transformative AI.
- *Automated algorithmic searches for explanations.* In some cases, researchers might be able to specify a space of hypotheses and then use optimization algorithms to find the most predictive ones. We’ve done some work like this and we hope to do much more in the future.
- *AI-assisted explanations.* We might be able to train models to produce highly rated and human-understandable explanations.

All three of these applications require that researchers understand the explanations verified by causal scrubbing. Unfortunately, it might be the case that the behaviors we want to interpret in large neural networks won’t have *any* understandable interpretations at all if most of the cognition performed inside powerful AI systems is in some sense irreducibly complex. It also seems plausible that even if these human-understandable interpretations exist, it might be intractable or impractical to find them.

A lot of our interest in causal scrubbing (and mechanistic interpretability more generally) comes from applications which require interpretability-like techniques which rely on formally manipulating explanation-like objects but *don’t* require that these objects be understood by anyone (human or AI):

- *Automated strategies for solving ELK.* __ARC__ is optimistic about __some strategies__ for solving __ELK__ that involve searching for objects similar to causal scrubbing explanations and then using properties of these explanations as part of the training procedure of the model, in ways that don’t require humans to understand the explanations.
- *Detecting deceptive alignment.* Suppose you have a weak trusted model and a strong untrusted model. You might be able to search for explanations of why these models take similar actions which allow you to distinguish whether the untrusted model is deceptively aligned just based on the structure of the explanation, rather than via having to understand its content.
- __Relaxed adversarial training__ requires some way of adjudicating arguments about whether the internals of models imply they’ll behave badly in ways that are hard to find with random sampling (because the failures only occur off the training distribution, or they’re very rare). This doesn’t require that any human is able to understand these arguments; it just requires we have a mechanical argument evaluation procedure. Improved versions of the causal scrubbing algorithm might be able to fill this gap.

Unfortunately, causal scrubbing may not be able to express all the tests of interpretability hypotheses we might want to express:

- Causal scrubbing only allows activation replacements that are *perfectly permissible* by the hypothesis: that is, the respective inputs must have exactly equal values under the correspondence.
- Despite being maximally strict in what replacements to allow, we are in practice willing to accept hypotheses that fail to perfectly preserve performance. We think this is an inconsistency in our current approach.
- As a concrete example, if you think a component of your model encodes a continuous feature, you might want to test this by replacing the activation of this component with the activation on an input that is *approximately* equal on this feature – causal scrubbing will refuse to do this swap.
- You can solve this problem by considering a generalized form of causal scrubbing, where hypotheses specify a non-uniform distribution over swaps. We’ve worked with this “generalized causal scrubbing” algorithm a bit. The space of hypotheses is continuous, which is nice for a lot of reasons (e.g. you can search over the hypothesis space with SGD). However, there are a variety of conceptual problems that still need to be resolved (e.g. there are a few different options for defining the union of two hypotheses, and it’s not obvious which is most principled).
- Causal scrubbing can only propose tests that can be constructed using the data provided to it. If your hypothesis predicts that model performance will be preserved if you swap the input to any other input which has a particular property, but no other inputs in the dataset have that property, causal scrubbing can’t test your hypothesis. This happens in practice–there is probably only one sequence in webtext with a particular first name at token positions 12, 45, and 317, and a particular last name at 13, 46, 234.
- This problem is addressed if you are able to produce samples that match properties by some mechanism other than rejection sampling.
- Causal scrubbing doesn’t allow us to distinguish between two features that are perfectly correlated on our dataset, since they would induce the same equivalence classes. In fact, to the extent that two features A and B are highly correlated, causal scrubbing will not complain if you misidentify an A-detector as a B-detector.^{[21]}

Another limitation is that causal scrubbing does not guarantee that it will reject a hypothesis that is importantly false or incomplete. Here are two concrete cases where this happens:

- When a model uses some heuristic that isn’t *always* applicable, it might use other circuits to inhibit the heuristic (for example, the negative name mover heads in the __Indirect Object Identification paper__). However, these inhibitory circuits are purely harmful for inputs where the heuristic *is* applicable. In these cases, if you ignore the inhibitory circuits, you might overestimate the contribution of the heuristic to performance, leading you to falsely believe that your incomplete interpretation fully explains the behavior (and therefore fail to notice other components of the network that contribute to performance).
- If two terms are correlated, sampling them independently (by two different random activation swaps) reduces the variance of the sum. Sometimes, this variance can be harmful for model performance – for instance, if it represents __interference from polysemanticity__. This can cause a hypothesis that scrubs out correlations present in the model’s activations to appear ‘more accurate’ under causal scrubbing.^{[22]}

These examples are both due to the hypotheses not being specific *enough* and neglecting to include some correlation in the model (either between input-feature and activation or between two activations) that would hurt the performance of the scrubbed model.

We don’t think that this is a problem with causal scrubbing in particular; rather, it reflects the fact that interpretability explanations should be regarded as examples of __defeasible reasoning__, where it is possible for an argument to be overturned by further arguments.

We think these problems are fairly likely to be solvable using an adversarial process where hypotheses are tested by allowing an adversary to modify the hypothesis to make it more specific in whatever ways affect the scrubbed behavior the most. Intuitively, this adversarial process requires that proposed hypotheses “point out all the mechanisms that are going on that matter for the behavior”, because if the proposed hypothesis doesn’t point something important out, the adversary can point it out. More details on this approach are included in the __appendix post__.

Despite these limitations, we are still excited about causal scrubbing. We’ve been able to directly apply it to understanding the behaviors of simple models and are optimistic about it being scalable to larger models and more complex behaviors (insofar as mechanistic interpretability can be applied to such problems at all). We currently expect causal scrubbing to be a big part of the methodology we use when doing mechanistic interpretability work in the future.

*This work was done by the Redwood Research interpretability team. We’re especially thankful to Tao Lin for writing the software that we used for this research and to Kshitij Sachan for contributing to early versions of causal scrubbing. Causal scrubbing was strongly inspired by Kevin Wang, Arthur Conmy, and Alexandre Variengien’s **work on how GPT-2 Implements Indirect Object Identification**. We’d also like to thank Paul Christiano and Mark Xu for their insights on heuristic arguments on neural networks. Finally, thanks to Ben Toner, Oliver Habryka, Ajeya Cotra, Vladimir Mikulik, Tristan Hume, Jacob Steinhardt, Neel Nanda, Stephen Casper, and many others for their feedback on this work and prior drafts of this sequence.*

^{^}For example, in __the causal tracing paper__ (Meng et al 2022), to evaluate whether their hypothesis correctly identified the location of facts in GPT-2, the authors replaced the activations of the involved neurons and observed that the model behaved as though it believed the edited fact, and not the original fact. In __the Induction Heads paper__ (Olsson et al 2022) the authors provide six different lines of evidence, from macroscopic co-occurrence to mechanistic plausibility.

^{^}Causal scrubbing is technically formulated in terms of general computational graphs, but we’re primarily interested in using causal scrubbing on computational graphs that implement neural networks.

^{^}See the discussion in the “An alternative formalism: constructing a distribution on treeified inputs” section of __the appendix post__.

^{^}Most commonly, the behavior we attempt to explain is why a model achieves low loss on a particular set of examples, and thus we measure the loss directly. However, the method can explain any expected quality of the model’s output.

^{^}We expect the results posts will be especially useful for people who wish to apply causal scrubbing in their own research.

^{^}Note that we can use causal scrubbing to ablate a particular module, by using a hypothesis where that specific module’s outputs do not matter for the model’s performance.

^{^}A computational graph is a graph where the nodes represent computations and the edges specify the inputs to the computations.

^{^}In the normal sense of the word, not ARC’s __Heuristic Arguments__ __approach__.

^{^}Since c is required to be an injective graph homomorphism, it immediately follows that the image of c is a subgraph of G which is isomorphic to I. This subgraph will be a union of paths from the input to the output.

^{^}In the appendix we’ll discuss that it is __possible to modify__ the correspondence to include these unimportant nodes, and that doing so removes some __ambiguity__ about when to sample unimportant nodes together or separately.

^{^}We have no guarantee, however, that any hypothesis that passes the causal scrubbing test is desirable. See more discussion of counterexamples in the limitations section.

^{^}This is because otherwise our algorithm would crucially depend on the exact representation of the causal graph: e.g. on whether the output of a particular attention layer was represented as a single input or as one input per attention head. There are several other approaches that can be taken to addressing this ambiguity; see the __appendix__.

^{^}That is, we consider the contribution of these heads through the residual stream into the final layer norm, excluding influence they may have through intermediate layers.

^{^}Note that as part of this hypothesis we have aggressively simplified the original model into a computational graph with only 5 separate computations. In particular, we relied on the fact that the residual stream just before the classifier head can be written as a sum of terms, including a term for each attention head (see the “__Attention Heads are Independent and Additive__” section of Anthropic’s “Mathematical Framework for Transformer Circuits” paper). Since we claim only three of these terms are important, we clump all other terms together into one node. Additionally, note this means that the ‘Head 2.0’ node in G includes *all* of the computations from layers 0 and 1, as these are required to compute the output of head 2.0 from the input.

^{^}The claim we test is __somewhat more subtle__, involving a weighted average between the proportion of the open-parentheses in the prefix and suffix of the string when split at every position. This is equivalent for the final computation of balancedness, but more closely matches the model’s internal computation.

^{^}As measured by normalizing the loss so 100% is the loss of the normal model (0.0003) and 0% is the loss when randomly interchanging the labels (4.3). For the reasoning behind this metric, see the appendix.

^{^}Our final hypothesis combines up to 51 different inputs: 4 inputs feeding into each of 1.0 and 2.0, 42 feeding into 2.1 (one for each sequence position), and 1 for the ‘other terms’.

^{^}The output of an attention layer can be written as a sum of terms, one for each previous sequence position. We can thus claim that only one of these terms is important for forming the queries.

^{^}In particular we create a whitelist of tokens on which exact 2-token induction is often a helpful heuristic (over and above bigram-heuristics). We then filter openwebtext (prompt, next-token) pairs for prompts that end in tokens on our whitelist. We evaluate loss on the actual next token from the dataset, however, which may not be what induction expects. More details here. We do this as we want to understand not just how our model implements induction but also how it decides *when* to use induction.

^{^}And thus the residual of (actual output - estimated output) is unimportant and can be interchanged with the residual on any other input.

^{^}This is a common way for interpretability hypotheses to be ‘partially correct.’ Depending on the type of reliability needed, this can be more or less problematic.

^{^}Another real-world example of this is __this experiment__ on the paren balance checker.


*As a writing exercise, I'm writing an AI Alignment Hot Take Advent Calendar - one new hot take, written every day for 25 days. Or until I run out of hot takes, which seems likely.*

This was waiting around in the middle of my hot-takes.txt file, but it's gotten bumped up because of Rob and Eliezer - I've gotta blurt it out now or I'll probably be even more out of date.

The idea of using AI research to help us be better at building AI is not a new or rare idea. It dates back to prehistory, but some more recent proponents include OpenAI members (e.g. Jan Leike) and the Accelerating Alignment group. We've got a tag for it. Heck, this even got mentioned yesterday!

So a lot of this hot take is really about my own psychology. For a long time, I felt that sure, building tools to help you build friendly AI was possible in principle, but it wouldn't *really* help. Surely it would be faster just to cut out the middleman and understand what we want from AI using our own brains.

If I'd turned on my imagination, rather than reacting to specific impractical proposals that were around at the time, I could have figured out how augmenting alignment research is a genuine possibility a lot sooner, and started considering the strategic implications.

Part of the issue is that plausible research-amplifiers don't really look like the picture I have in my head of AGI - they're not goal-directed agents who want to help us solve alignment. If we could build those and trust them, we really *should* just cut out the middleman. Instead, they can look like babble generators, souped-up autocomplete, smart literature search, code assistants, and similar. Despite either being simulators or making plans only in a toy model of the world, such AI really does have the potential to transform intellectual work, and I think it makes a lot of sense for there to be some people doing work to make these tools differentially get applied to alignment research.

Which brings us to the dual-use problem.

It turns out that other people would *also* like to use souped-up autocomplete, smart literature search, code assistants, and similar. They have the potential to transform intellectual work! Pushing forward the state of the art on these tools lets you get them earlier, yet it also helps other people get them earlier too, even if you don't share your weights.

Now, maybe the most popular tools will help people make philosophical progress, and accelerating development of research-amplifying tools will usher in a brief pre-singularity era of enlightenment. But - lukewarm take - that seems way less likely than such tools differentially favoring engineering over philosophy on a society-wide scale, making everything happen faster and be harder to react to.

So best of luck to those trying to accelerate alignment research, and fingers crossed for getting the differential progress right, rather than oops capabilities.


*This is an entry in the 'Dungeons & Data Science' series, a set of puzzles where players are given a dataset to analyze and an objective to pursue using information from that dataset. *

You were saddled with debt, and despair, and regret;

But you left it behind to embark,

With a visiting ship who were planning a trip,

Hunting some strange sea-beasts they call . . . “Snark”?

(After climbing aboard and departing the shore,

Your life is if anything worse.

The grog makes you groggy; the sea makes you soggy;

The songs leave you thinking in verse.)

Snark-hunting, you find, is a peaceful pastime.

By now, every crew knows the way,

To - with ease! - guarantee their success and safety,

As they seek, and they lure, and they slay.

A single exception proves the above rule:

While with *most *Snarks, Snark-hunting’s a breeze,

If your Snark is a Boojum, it hunts you right back,

And slays you with similar ease.

When you learn the detail that a Snark-hunt can fail,

You speak with the crew and insist,

On wielding Science skills to discern which Snarks kill,

And becoming the ship’s Analyst.

(The Butcher and Baker and Barnacle-Scraper,

All tell you that simply won’t do.

“You’ll be telling us what is a Boojum, what’s not;

Thus, you’re our **Boojumologist**; true?”

When you meekly protest you’ve not earned that address,

They rebut: “As you clearly perceive,

The Bellman can’t ring, the Belter can’t sing,

And the Beaver knows not how to Beave.”)

Notwithstanding contention *viz* naming conventions,

Your task is as clear as could be:

Using data from trips made by similar ships,

Choose which Snarks to fight, which to flee.

- You have a dataset of Snark sightings. This tells you whether a sighted Snark was hunted and whether – if hunted – it turned out to be a Boojum. (To clarify, the order of operations is “someone sights a Snark” → “they take notes on its behavior and relay this information to the nearest Snark-hunting ship” → “the ship’s crew decide whether to hunt it”).
- You have a list of Snarks your ship can choose to go after. The ship must hunt six of them for the venture to be profitable; the crew won’t accept a plan targeting fewer than six. If your ship hunts a Boojum, everyone aboard vanishes, including you.
- Your objective can be to choose the six Snarks which maximize probability of survival, maximize the EV of the number of Snark bodies in your possession at the end of the trip (i.e. max([Snarks hunted] * P(Survival))), or do anything in between.
- (You could also optimize for killing the crew as quickly and reliably as possible, if you like. Or you could ignore the 'challenge' part of the challenge entirely, and just try to characterize the dataset. I’m not the boss of you.)
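
If it helps to see the trade-off concretely, here's a minimal Python sketch of the survival/EV objective. The Boojum probabilities and the independence assumption are purely illustrative (you'd estimate the real probabilities from the sighting dataset), not facts about the scenario:

```python
# Hypothetical sketch of the survival/EV trade-off. Assumes each candidate
# Snark has an independently estimated probability of being a Boojum --
# an illustrative assumption, not something the scenario guarantees.

def survival_probability(p_boojum):
    """P(no hunted Snark is a Boojum), assuming independence."""
    prob = 1.0
    for p in p_boojum:
        prob *= 1.0 - p
    return prob

def expected_snarks(p_boojum):
    """EV of Snark bodies: the hunted Snarks count only if the ship survives."""
    return len(p_boojum) * survival_probability(p_boojum)

# Hypothetical Boojum probabilities for one candidate selection of six Snarks.
candidates = [0.05, 0.10, 0.02, 0.08, 0.01, 0.04]
print(survival_probability(candidates))  # roughly 0.73 for these numbers
print(expected_snarks(candidates))
```

Comparing these two numbers across candidate selections is exactly the efficient-frontier question the Bonus Task asks about.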

If you post a selection of Snarks before the deadline, and your selection is on the efficient frontier – that is, if you’re the first to submit that selection, and if no other submission has both better EV *and* better survival probability – you will be entitled to receive [UNSPECIFIED BENEFIT (UNDERWHELMING)] at [UNSPECIFIED TIME (DISTANT)]. Every player can make one submission; if you post multiple plans, please specify clearly which of these is for the Bonus Task.

I’ll post an interactive you can use to test your choices, along with an explanation of how I generated the dataset, sometime on Monday the 12th. I’m giving you nine days, but the task shouldn’t take more than an evening or two; use Excel, R, Python, recollections from a previous timeloop, or whatever other tools you think are appropriate. Let me know in the comments if you have any questions about the scenario.

If you want to investigate collaboratively and/or call your decisions in advance, feel free to do so in the comments; however, please use spoiler tags or rot13 when sharing inferences/strategies/decisions, so people intending to fly solo can look for clarifications without being spoiled.


Interpretability techniques often need to throw away some information about a neural network's computations: the entirety of the computational graph might just be too big to understand, which is part of why we need interpretability in the first place. In this post, I want to talk about two different ways of simplifying a network's computational graph:

- Fully explaining parts of the computations the network performs (e.g. identifying a subcircuit that fully explains a specific behavior we observed)
- Approximately describing how the entire network works (e.g. finding meaningful modules in the network, whose internals we still don't understand, but that interact in simple ways)

These correspond to the idea of subsets and quotients in math, as well as many other instances of this duality in other areas. I think lots of interpretability at the moment is 1., and I'd be excited to see more of 2. as well, especially because I think there are synergies between the two.

The entire post is rather hand-wavy; I'm hoping to point at an intuition rather than formalize anything (that's planned for future posts). Note that a distinction like the one I'm making isn't new (e.g. it's intuitively clear that circuits-style research is quite different from neural network clusterability). But I haven't seen it described this explicitly before, and I think it's a useful framing to keep in mind, especially when thinking about how different interpretability techniques might combine to yield an overall understanding.

ETA: An important clarification is that for both 1. and 2., I'm only discussing interpretability techniques that try to understand the *internal structure* of a network. In particular, 2. talks about approximate descriptions of the algorithm the network is *actually* using, not just approximate descriptions of the function that's being implemented. This excludes large parts of interpretability outside AI existential safety (e.g. any method that treats the network as a black box and just fits a simpler function to the network).

In math, if you have a set $X$, there are two "dual" ways to turn this into a smaller set:

- You can take a *subset* $S \subseteq X$.
- You can take a *quotient* $X/\sim$ of $X$ by some equivalence relation $\sim$.

(The quotient $X/\sim$ is the set of all equivalence classes under $\sim$.)

I think this is also a good framing to distinguish interpretability techniques, but before elaborating on that, I want to build intuitions for subsets and quotients in contexts other than interpretability.

- Maps $f: Y \to X$ *into* $X$ induce subsets of $X$ (namely their image $f(Y)$). For this subset, it doesn't matter how many elements in $Y$ were mapped to a given element in $X$, so we can assume $f$ is injective without loss of generality. Thus, subsets are related to *injective maps*. Dually, maps $g: X \to Z$ *out of* $X$ induce quotients of $X$: we can define two elements in $X$ to be equivalent if they map to the same element in $Z$. The quotient is then the set of all preimages $g^{-1}(z)$ for $z \in Z$. Again, we can assume that $g$ is surjective if we only care about the quotient itself, so quotients correspond to *surjective maps*.
- A chapter of a book is a subset of its text. A summary of the book is like a quotient. Note that both throw away information, and that for both, there will be many different possible books we can't distinguish with only the subset/quotient. But they're very different in terms of *which* information they throw away and which books become indistinguishable. Knowing only the first chapter leaves the rest of the book entirely unspecified, but that one chapter is nailed down exactly. Knowing a summary restricts choices for the entire book somewhat, but leaves local freedom about word choice etc. everywhere.
- If I have a dataset, then samples from that dataset are like subsets, while summary statistics are like quotients. Again, both throw away information, but in very different ways.
- If I want to communicate to you what some word means, say "plant", then I can either give examples of plants (subset), or I can describe properties that plants have (quotient).
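
The duality can also be illustrated on a tiny finite set. This is a toy sketch, not tied to any particular network: a subset keeps a few elements exactly, while a quotient keeps all elements but only up to an equivalence relation (here: congruence mod 3):

```python
# Toy illustration of the subset/quotient duality on a finite set.

X = set(range(12))

# Subset: a smaller set included exactly in X.
S = {x for x in X if x < 4}

# Quotient: the set of equivalence classes of X under x ~ y iff x % 3 == y % 3.
classes = {}
for x in X:
    classes.setdefault(x % 3, set()).add(x)
quotient = list(classes.values())

print(S)         # four elements, each known exactly
print(quotient)  # three classes, each collapsing four elements into one
```

Both objects are "smaller" than `X`, but they discard information in entirely different directions, just like the chapter vs. summary example above.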

The subset/quotient framework can be applied to mechanistic interpretability as follows: fully explaining part of a network is analogous to subsets, abstracting the entire network and describing how it works at a high level is analogous to quotients. Both of these are ways of simplifying a network that would otherwise be unwieldy to work with, but again, they simplify in quite different ways.

These subsets/quotients of the *mechanism/computation* of the network seem to somewhat correspond to subsets/quotients of the *behavior* of the network:

- If we interpret a subset of the neurons/weights/... of the network in detail, that subset is often chosen to explain the network's behavior on a subset of inputs very well (while we won't get much insight into what happens on other inputs).
- A rough high-level description of the network could plausibly be similarly useful to predict behavior on many different inputs. But it won't let us predict behavior *exactly* on any of them—we can only predict certain properties the outputs are going to have. So this leads to a quotient on outputs.

To make this a bit more formal: if we have a network that implements some function $f: X \to Y$, then simplifying that network using interpretability tools might give us two different types of simpler functions:

- A restriction $f|_{X'}$ of $f$ to some subset $X' \subseteq X$
- A composition of $f$ with a quotient map $q: Y \to Y/\sim$ on $Y$, i.e. a function $q \circ f: X \to Y/\sim$

Interpreting part of the network seems related to the first of these, while abstracting the network to a high-level description seems related to the second one. For now, this is mostly a vague intuition, rather than a formal claim (and there are probably exceptions; for example, looking at some random subset of neurons might just give us no predictive power at all).

I'll go through some examples of interpretability research and describe how I think they fit into the subset/quotient framework:

- The work on indirect object identification in GPT-2 small is a typical example of the subset approach: it explains GPT-2's behavior on a very specific subset of inputs, by analyzing a subset of its circuits in a lot of detail.
- Induction heads are similar in that they still focus on precisely understanding a small part of the network. However, they help understand behavior on a somewhat broader range of inputs, and they aren't specific to one model in particular (which is a dimension I've ignored in this post).
- The analysis of early vision in InceptionV1 has some aspects that feel quotient-y (namely grouping neurons by functionality), but it focuses entirely on the subset of early layers and mostly explains what individual neurons do. Overall, I'd put this mostly in the subset camp.
- The general idea that early layers of a CNN tend to detect low-level features like curves, which are then used to compute more complicated features, which are finally turned into an output label, is a clear example of a quotient explanation of how these models work. This is also a good example of how the approaches can interact: studying individual neurons can give strong evidence that this quotient explanation is correct.
- Clusterability in neural networks and other work on modularity are other typical examples of quotient approaches to interpretability.
- Acquisition of chess knowledge in AlphaZero combines elements of a subset and a quotient approach. Figuring out that AlphaZero represents lots of human chess concepts is part of a quotient explanation: it lets us explain at a very high level of abstraction how AlphaZero evaluates positions (presumably by using those concepts, e.g. recognizing that a position where you have an unsafe king is bad). On the other hand, the paper certainly doesn't provide a *complete* picture of how AlphaZero thinks, not even at such a high level of abstraction (e.g. it's unclear how these concepts are actually being used; we can only make reasonable guesses as to what a full explanation at this level would look like).
- The reverse-engineered algorithm for modular addition seems to me to be an example of a subset-based approach (i.e. my impression is that the algorithm was discovered by looking at various parts of the network and piecing together what was happening). The unusual thing about it is that the "subset" being explained is ~everything the network does. So you could just as well think of the end product as a quotient explanation (at a rather fine-grained level of abstraction). This is an example of how both approaches converge as the subset increases in size and the abstraction level becomes more and more fine-grained.
- The polytope lens itself feels like a quotient technique (reframing the computations a network is doing at a specific level of abstraction, talking about subunits as groups of polytopes with similar spline codes). However, given that it abstracts a network at a very fine-grained level, I'd expect it to be combined with subset approaches in practice. Similar things apply to the mathematical transformer circuits framework.
- Causal scrubbing focuses on testing subset explanations: it assumes that a hypothesis is an embedding of a smaller computational graph into the larger one.^{[1]}

I've already mentioned two examples of how both types of techniques can work in tandem:

- A subset analysis can be used to test a quotient explanation (e.g. if I conjecture that early CNN layers detect low-level features like curves, that are then combined to compute increasingly high-level concepts like dog ears, I can test that by looking at a bunch of example neurons)
- Good fine-grained quotients can make it easier to explain subsets of a network (e.g. the polytope lens, the mathematical transformers framework, or other abstractions that are easier to work with than thinking about literal multiplications and additions of weights and activations).

Some more hypothetical examples:

- Understanding a network in terms of submodules might point us to interesting subsets to study in detail. For example, a submodule that reasons about human psychology might be more important to study than one that does simple perception tasks.
- A high-level understanding of a network should make it easier to understand low-level details in subsets. E.g. if I suspect that the neurons I'm looking at are part of a submodule that somehow implements a learned tree search, it will be much easier to figure out *how* the implementation works than if I'm going in blind.
- Conversely, subset-based techniques might be helpful for identifying submodules and their functions. If I figure out what a specific neuron or small group of neurons is doing, that puts restrictions on what the high-level structure of the network can be.
- We can try to first divide a network into submodules and then understand each of them using a circuits-style approach. Combining the abstraction step with the low-level interpretation lets us *parallelize* understanding the network. Without the initial step of finding submodules, it might be very difficult to split up the work of understanding the network between lots of researchers.

I'm pretty convinced that combining these approaches is more fruitful than either one on its own, and my guess is that this isn't a particularly controversial take. At the same time, my sense is that most interpretability research at the moment is closer to the "subset" camp, except for frameworks like transformer circuits that are about very *fine-grained* quotients (and thus mainly tools to enable better subset-based research). The only work I'm aware of that I would consider clear examples of quotient research at a high level of abstraction are Daniel Filan's Clusterability in neural networks line of research and some work on modularity by John Wentworth's SERI MATS cohort.

Some guesses as to what's going on:

- I missed a bunch of work in the quotient approach.
- People think the subset approach is more promising/we don't need more research on submodules/...
- Subset-style research is currently quite tractable using empirical approaches and easier to scale, whereas quotient-style research needs more insights that are hard to find.
- Maybe the framework I'm using here is just confused? But even then, I'd still think that "finding high-level structure in neural networks" is clearly a sensible distinct category, and neglected compared to circuits-style work.

I'd be very curious to hear your thoughts (especially from people working on interpretability: why did you pick the specific approach you're using?)

^{^}A way to fit in quotient explanations would be to make the larger graph itself a quotient of the neural network, i.e. have its nodes perform complex computations. But causal scrubbing doesn't really discuss what makes such a quotient explanation a good one (except for extensional equality).


This summer I learned about the concept of *Audience Capture* from the case of Nicholas Perry. Through pure force of social validation, he experienced a shift from an idealistic but obscure young man to a grotesque but popular caricature of a medical train wreck.

The change happened through social reward signals. Originally Nicholas the principled vegan made videos of himself playing the violin, much to no one's interest. The earnest young man then learned he had to give up his vegan diet for health reasons, and thought to give the occasion a positive twist by inviting his viewers to share the first meal of his new lifestyle.

It was an innocuous step. He gained viewers. They cheered him on to eat more. And he did.

Gradually but steadily he ate and ate, to the cheers of a swelling crowd of online followers. And like the Gandhi Murder Pill, the choice of sacrificing a sliver of his values for substantial reward was worth it for each individual video he made. His popularity expanded with his waistline as he inched up the social incentive slope. And at the end of that slope Nicholas didn't care about health, veganism, or playing the violin anymore. Instead his brain was so steeped in social reward signals that they had rewired his values on a fundamental level. Essentially, Nicholas had become a different person.

Now I realize I am unlikely to gain 300 pounds from success on AI alignment articles on LessWrong, but *audience capture* does point to a worry that has been on my mind. How do new researchers in the field keep themselves from following social incentive gradients? Especially considering how hard it is to notice such gradients in the first place!

Luckily the author of the above article suggests a method to ward against *audience capture*: Define your ideal self up front and commit to aligning your future behavior with her. So this is what I am doing here -- I want to precommit to three *Research Principles* for the next 6 months of my AI alignment studies:

- **Transparency** - I commit to exposing my work in progress, unflattering confusions of thinking, and potentially controversial epistemics.
- **Exploration** - I commit to exploring new paths for researching the alignment problem and documenting my progress along the way.
- **Paradigmicity** - I commit to working toward a coherent paradigm of AI alignment in which I can situate my work, explain how it contributes to solving alignment, and measure my progress toward this goal.

Let's take a closer look at each research principle.

First things first.

*I've received a research grant to study AI alignment and I don't know if I'm the right person for it.*

This admission is not lack of confidence or motivation. I feel highly driven to work on the problem, and I know what skills I have that qualify me for the job. However, AI alignment is a new field and it's unclear what properties breakthrough researchers have. So naturally, I can't assess if I have these unknown properties until I actually make progress on the problem.

Still the admission feels highly uncomfortable -- like I'm breaking a rule. It feels like the type of thing where common wisdom would tell me to power pose myself out of this frame of mind. I think that wisdom is wrong. What I want is for alignment to get solved, which means I want the right people working on it.

Is the "right people" me?

I don't know. *But I also think most people can't know*. I think self-assessment is a trap due to motivated reasoning and other biases I'm still learning about. Instead, I believe it's better to commit to transcribing one's work. It can speak for itself. Thus I won't hype myself or sell myself or only show the smartest things I came up with. In short, I want to commit to a form of transparency based on epistemic humility.

This approach will obviously lead to more clutter compared to filtering one's output on quality. Still, I'd argue the trade-off is worth it because it allows evaluation of ongoing work instead of an implicit competition of self-assessment smothered in social skills and perception management.

*Thus I commit to exposing my work in progress, unflattering confusions of thinking, and potentially controversial epistemics.*

Retrospectives don't capture the actual experience of going through a research process. Our memories are selective and our narratives are biased. Journaling our progress as we go avoids these failings, but such journals will be rife with dead ends and deprived of hindsight wisdom. On the other hand, if the useful information density is too low, people can simply opt out of reading them.

Win-win.

So what outcomes should we expect? I think there are four possible research results of the next few months: A huge breakthrough no one had considered, a useful research direction that more people are already working on, a useless path no one has explored before, and a useless approach that was predictably useless. Thus we have a 2x2 grid of outcomes across the Useful-Useless axis and the Known-Unknown axis:

| | Useful | Useless |
|---|---|---|
| Known | Converge to existing path | Converge to existing dead ends |
| Unknown | Discover a new path | Discover new dead ends |

I'd argue that in the current pre-paradigmatic phase, we should value exploration of Unknown-Useless paths as highly as exploration of Known-Useful Paths. This is especially true because it is unclear if Known-Useful paths are *actually* Useful! Thus, my focus will be on the bottom row - the Unknowns. But what does it matter if we aim for Known or Unknown paths and how should we evaluate the value of the two strategies?

My intuition is that aiming for Unknown paths, my probability of ending up in each cell is something like:

| | Useful | Useless |
|---|---|---|
| Known | 0.1 | 0.2 |
| Unknown | 0.1 | 0.6 |

So I expect about a 10% success rate for the ideal outcome, about an equal chance to end up on what most people following a set study path would end up on, and then a six times greater chance than that to go down a dead end path that was legitimately underexplored, which is also a good thing! My greatest worry is that any given dead end I explore will turn out to have been an obvious dead end to my peers in about a quarter of the cases, and that this outcome feels as likely to me as doing something Useful at all. However, I think focusing on the Unknowns is still worth it for the increased chance of finding Unknown-Useful outcomes.

In contrast, if we compare to aiming for Known paths I think I'd end up with the following probabilities:

| | Useful | Useless |
|---|---|---|
| Known | 0.8 | 0.1 |
| Unknown | 0.01 | 0.09 |

Cause it's hard to miss the target when you are on rails, but also nearly impossible to explore!

Now these probabilities say more about my brain, my self-assessment and my model of how minds work, than about the actual shape of reality. It's a way to convey intuitions on why I'm approaching alignment studies the way I am. Maybe I'm wrong and people explore just fine after focusing on existing methods, and then we can just reframe the above thinking as one of the paths I'm exploring -- Namely, the path of explicit exploration.
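
To make the comparison between the two strategies explicit, here is a few-line Python sketch using the purely intuitive probabilities from the two tables above (these numbers are the author's gut feelings, not data):

```python
# Compare the two strategies' marginal probabilities, using the intuitive
# numbers from the tables above (illustrative, not empirical).

explore = {("Known", "Useful"): 0.1, ("Known", "Useless"): 0.2,
           ("Unknown", "Useful"): 0.1, ("Unknown", "Useless"): 0.6}
follow  = {("Known", "Useful"): 0.8, ("Known", "Useless"): 0.1,
           ("Unknown", "Useful"): 0.01, ("Unknown", "Useless"): 0.09}

def p_useful(table):
    """Marginal probability of landing on a Useful path."""
    return sum(p for (row, axis), p in table.items() if axis == "Useful")

def p_unknown(table):
    """Marginal probability of exploring somewhere Unknown."""
    return sum(p for (row, axis), p in table.items() if row == "Unknown")

print(p_useful(explore), p_unknown(explore))
print(p_useful(follow), p_unknown(follow))
```

The explicit-exploration strategy trades a large drop in P(Useful) for a large gain in P(Unknown), which is the crux of the argument above.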

Either way, using this framework of Known-Unknown and Useful-Useless paths highlights that marking the paths you take is a key item. It's an exploration of solution space, and we want to track the dead ends as much as the promising new avenues, or else we'll be duplicating work within the community. Thus, by writing down my research path, others may retroactively trace back what *definitely didn't work* (if I end up on a Useless path) or *how breakthroughs are made* (if I end up on the Unknown-Useful path).

*Thus I commit to exploring new paths for researching the alignment problem and documenting my progress along the way.*

One of the errors I dread the most is to get sucked into one research path with one specific problem and lose track of the greater problem landscape. Instead, I want to be sure I have an overview, a narrative, a map -- an overarching *paradigm* that I am working with. It should show how each problem I'm studying fits into an overall model of solving alignment. Honestly, of the three research principles, this is the only one I'd strongly argue for general adoption by all new alignment researchers:

*Prioritize getting a complete view of the problem landscape and how your work actually solves alignment.*

This is important for two reasons:

First, by keeping a bird's eye view of the interrelation of the major subproblems of alignment, your mind is more likely to synthesize solutions that shift the entire frame. There is a form of information integration that a brain can do that involves intuitive leaps between reasoning steps. Internally it feels like your brain has pattern matched into the expectation of a connection between A and B, but when you actually look, there are no obvious steps connecting the two. This in turn sparks exploration of possible paths that might connect A and B. Sometimes you find them and sometimes you don't, but either way, I suspect this type of high-level integrative cognition is key to solving alignment. As such, a bird's eye view of the problem should be at the front of one's mind every step of the way.

Secondly, with a map in hand between us and the destination point of solving alignment, you will be able to measure your progress so far. By having a coherent model of how each of your actions plays a role and can matter to the eventual outcome, you won't get lost in the weeds staring at the pretty colors of high dopamine-dispensing subproblems. Therefore, if someone asks me, "Shoshannah, why did you spend the last month studying method X?", then I should be able to coherently and promptly answer how and why X may matter to solving alignment from start to finish.

*Thus I commit to working toward a coherent paradigm of AI alignment in which I can situate my work, explain how it contributes to solving alignment, and measure my progress toward this goal.*

For my 6 months of AI alignment studies, I will aim to be transparent and explorative in my work while constructing and situating my actions in a coherent paradigm of the alignment problem. With this approach the journal entries of this sequence will be an exercise in epistemic humility.

Wish me luck.


*Epistemic status: Whimsical*

*Major spoilers for Madoka Magica, a show where spoilers matter!*

Meet Kyubey. Kyubey is a Longtermist.

In the Madokaverse, changes in human emotion are, somehow, net-positive in the creation of energy from out of nothing. The Incubators (of which Kyubey is one, pictured above) are an alien species who've discovered a way to farm human emotions for energy.

Most of the Incubators don't feel emotion, and the few that do are considered to be mentally ill. But humans are constantly leaking our juicy, negentropy-positive feelings all over the place. With human angst as a power source, it's possible to prevent the heat death of the universe!

Do the math, people. The suffering of a few teenage girls is nothing compared to pushing back the heat death of the universe.^{[1]}

And this isn't just some Omelas situation where the girls get nothing out of it. They get *wishes*! Who could object to a cause this noble?

*If you want to see Homura kicking ass, you could watch up to 2:22 before reading on.*

There's something subtle here—something to notice confusion about, even—where is she getting all these guns from?

Remember: *Homura's power is time manipulation.* As one commenter puts it:

This is hauntingly sobering when you consider that Homura's magical ability has nothing to do with guns, only with time manipulation. That means all those tens of thousands, hundreds of thousands of pounds of explosive material and weapons arms weren't just made from nothing like Mami's guns were- they were individually tracked down and gathered, one after the other, by one little girl.

How many hundreds of repetitions did it take to find them all, every time making a new doomed timeline? How many thousands of hours did she spend looking for where to get them from, and how many failed attempts finding the most effective way to arrange them?

Rationalists have a name for this kind of determination: Having something to protect.

When Kyubey creates a magical girl, he offers them an atomic contract: they gain a sparkly transformation and fight witches for the rest of their life, and in exchange, they're granted a wish.

There's a minor risk here: Kyubey can't actually stop this process; the wish will be granted whether he likes it or not.

(I sure hope the mesa-objective pursued by human girls is the same as the outer objective (negentropy) pursued by Kyubey. There are no possible ways this could go wrong)

The Incubators were reckless. I'm glad humans would never apply large amounts of optimization power without guarantees for how it's aimed.

Hopefully you've already seen the anime (otherwise, sorry for all the spoilers you just read!) but if you haven't, go watch it now. It's great, and incidentally chock-full of fables like these. (For a bonus fable on Kyoko and the complexity of wishes, see *Ep7, 8:05 - 12:14*.)

If you *have* already seen the anime and want to read something with similar themes, I would recommend Qualia The Purple.

^{^}Though it isn't spelled out in the show, humans appear to be the only species that has feelings, so depending on whether you're a positive or negative utilitarian, a universe full of emotionless beings may or may not be a compelling vision for you.


Brun's theorem is a relatively famous result in analytic number theory that says the sum of the reciprocals of the twin primes converges to a finite value. In other words, we have

$$\sum_{p \,:\, p+2 \text{ prime}} \left( \frac{1}{p} + \frac{1}{p+2} \right) = B$$

for some finite constant $B$ (known as Brun's constant). This is in contrast to the same sum taken over *all primes*, which is divergent:

$$\sum_{p} \frac{1}{p} = \infty$$

In this post, I'll use Brun's theorem as an illustration of sieve theoretic arguments in analytic number theory. I'll try to explain relevant results as I go along to minimize the background necessary to understand the arguments, but some background in real analysis and number theory is needed to understand the post. If you don't have such a background, most of the post will probably be gibberish.

I'm writing this post mostly because I think there's some lack of good explanations of sieve theory in general and the Brun sieve in particular. Hopefully this post will be helpful to a handful of people who are interested in or trying to understand such matters.

Note that in the post I'll not always mention explicitly that a sum or a product runs over the prime numbers. If the sum or product is indexed by the letter $p$, you should assume that it runs over the primes and not e.g. over the natural numbers. Sometimes $p$ will run only over *odd* primes, because there is a degenerate case with the prime $2$ when we work with twin primes, coming from the fact that $p$ and $p+2$ are in the same residue class modulo $2$. This will often be obvious from the surrounding context.

First, let's discuss some background results that will be useful to know throughout the post.

The prime number theorem says that the number of primes less than $x$, denoted $\pi(x)$, is well approximated by $x/\log x$. Concretely, it says that

$$\lim_{x \to \infty} \frac{\pi(x)}{x/\log x} = 1$$

Roughly, the prime number theorem says that the density of the prime numbers "around $x$" is roughly $1/\log x$ when $x$ is large. This is a rather difficult theorem to prove and we won't actually need it to prove Brun's theorem. However, the result will be useful to know in heuristic arguments for why we might expect the theorem to be true, and to motivate our method of proof.
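
As a quick numerical sanity check (not needed for any of the proofs below), we can compare $\pi(x)$ with $x/\log x$ using a simple sieve:

```python
# Sanity check of the prime number theorem: the ratio pi(x) / (x / log x)
# slowly decreases toward 1 as x grows.

import math

def primes_up_to(n):
    """Sieve of Eratosthenes: return the list of primes <= n."""
    is_prime = [True] * (n + 1)
    is_prime[0] = is_prime[1] = False
    for p in range(2, int(n ** 0.5) + 1):
        if is_prime[p]:
            for m in range(p * p, n + 1, p):
                is_prime[m] = False
    return [p for p, b in enumerate(is_prime) if b]

for x in (10**3, 10**4, 10**5):
    pi_x = len(primes_up_to(x))
    print(x, pi_x, pi_x / (x / math.log(x)))  # third column decreases toward 1
```

The convergence is very slow (the ratio is still around 1.1 at $x = 10^5$), which is one reason the theorem is hard to guess from small data.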

We will also need the Mertens theorems. These theorems are "weaker" versions of the prime number theorem. Specifically, we will need to know the fact that

$$\sum_{p \leq x} \frac{1}{p} = \log \log x + M + O\!\left(\frac{1}{\log x}\right)$$

where $M$ is the Meissel-Mertens constant. This is a stronger, quantitative version of the result that the sum of the reciprocals of the prime numbers is divergent. This theorem is not as difficult to prove. Here is a possible proof strategy: we know that each prime $p$ occurs in the prime factorization of $n!$ roughly $n/p$ times. So we have a rough approximation

$$\log(n!) \approx \sum_{p \leq n} \frac{n}{p} \log p$$

where the sum over $p$ runs over prime numbers - I'll stick to this convention for the rest of the post for the sake of brevity. On the other hand, Stirling's theorem gives the approximation $\log(n!) \approx n \log n - n$. Combining these immediately gives

$$\sum_{p \leq n} \frac{\log p}{p} \approx \log n$$

once we take account of the error term in this approximation.

To pass from this to the sum we care about, we employ partial summation. This is the discrete analog of integration by parts. The idea is to write the sum we care about as

$$\sum_{p \leq x} \frac{1}{p} = \sum_{n \leq x} \frac{1_P(n) \log n}{n} \cdot \frac{1}{\log n}$$

where $1_P$ is the indicator function of the primes, and then use summation by parts, which allows us to shift the "differencing operator" from acting on $\frac{1_P(n) \log n}{n}$ to acting on $\frac{1}{\log n}$. Approximating the first difference of $\frac{1}{\log n}$ by its derivative gives us the main term

$$\sum_{n \leq x} \left( \sum_{p \leq n} \frac{\log p}{p} \right) \frac{1}{n \log^2 n}$$

which we can now approximate using the result from above as

$$\sum_{n \leq x} \frac{\log n}{n \log^2 n} = \sum_{n \leq x} \frac{1}{n \log n} \approx \log \log x$$

where the last approximation follows from replacing the sum by an integral and noting that $\int \frac{dt}{t \log t} = \log \log t + C$. If we carefully keep track of all the error terms in the approximations, we again recover the precise result that $\sum_{p \leq x} \frac{1}{p} = \log \log x + M + O(1/\log x)$.
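
We can also just check Mertens' second theorem numerically; the Meissel-Mertens constant is $M \approx 0.2615$:

```python
# Numerical check of Mertens' second theorem: sum_{p <= x} 1/p is close to
# log log x + M, where M ~ 0.2615 is the Meissel-Mertens constant.

import math

def primes_up_to(n):
    """Sieve of Eratosthenes: return the list of primes <= n."""
    is_prime = [True] * (n + 1)
    is_prime[0] = is_prime[1] = False
    for p in range(2, int(n ** 0.5) + 1):
        if is_prime[p]:
            for m in range(p * p, n + 1, p):
                is_prime[m] = False
    return [p for p, b in enumerate(is_prime) if b]

M = 0.26149721  # Meissel-Mertens constant (truncated)
for x in (10**3, 10**5):
    s = sum(1 / p for p in primes_up_to(x))
    print(x, s, math.log(math.log(x)) + M)  # the two columns nearly agree
```

The agreement is already good at $x = 10^3$, consistent with the $O(1/\log x)$ error term.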

That's all the background knowledge we'll need for the post. Moving on:

Let's think about why we might *expect* Brun's theorem to be true.

We know from the prime number theorem that the density of the primes around $n$ is $1/\log n$, so if we cheat and assume that the events of $n$ being prime and $n+2$ being prime are "independent", the density of the twin primes should be roughly $1/\log^2 n$. So we should expect

$$\pi_2(x) \approx \frac{x}{\log^2 x}$$

where $\pi_2(x)$ counts the twin primes less than or equal to $x$. If so, we can simply turn the crank of partial summation as we did before: letting $P_2$ be the set of twin primes and $1_{P_2}$ its indicator function, we compute

$$\sum_{\substack{p \leq x \\ p \in P_2}} \frac{1}{p} = \sum_{n \leq x} \frac{1_{P_2}(n)}{n} \approx \sum_{3 \leq n \leq x} \frac{1}{n \log^2 n}$$

This final sum is convergent. One way to see it is by an integral test: if we consider the integral

$$\int_2^\infty \frac{dt}{t \log^2 t}$$

then the substitution $u = \log t$ turns the integral into

$$\int_{\log 2}^\infty \frac{du}{u^2}$$

which is obviously finite. So the initial sum is convergent as well.

This heuristic suggests that Brun's theorem should be true if there is "nothing wrong": if there's no unexpected correlation between the events of $n$ being prime and $n + 2$ being prime. The task of the proof is therefore to show that this correlation can't get bad enough to make this sum divergent.

One key insight is that we actually have a substantial amount of room to work with in this argument: we don't actually need to get anywhere close to $\frac{n}{\log^2 n}$. If we could show, for instance, that

$$\pi_2(n) \le \frac{n}{\log^{1.5} n}$$

for all sufficiently large $n$, which is a much weaker bound than what the heuristic suggests (any exponent strictly greater than $1$ would do), we would still prove that the sum we care about is convergent. This fact is what makes Brun's theorem a relatively "easy" result: we don't actually need to show that there's strictly *no* correlation between $n$ being prime and $n + 2$ being prime, just that the correlation doesn't get too bad.

Now that we understand what we have to do, it's time to think about concrete proof strategies. Our goal is to prove a "nontrivial" upper bound on the twin prime counting function $\pi_2(n)$. We know that the key objective is to bound the dependence between the events of $n$ being prime and $n + 2$ being prime. We need to make something work with these ingredients.

Given that our goal is to prove an *upper bound*, we'll still have succeeded if we find a set $S$ that *contains* the twin primes and for which we can prove a similar result, as an upper bound on the counting function of $S$ is obviously an upper bound on $\pi_2(n)$ as well. This suggests one immediate way to attempt a proof: the sieve of Eratosthenes.

The sieve of Eratosthenes characterizes the primes by the following: a prime is a number that, for each prime $p$, is either not divisible by $p$ or gives $1$ when divided by $p$ (that is, equals $p$). In other words, we obtain a sequence of sets $S_k$ consisting of the integers not divisible by any of the first $k$ primes (except the primes themselves), and as $k$ grows this gives us an increasingly refined superset of the prime numbers. It's called a sieve because we're gradually "sieving out" composite numbers from the integers.

To apply this to our problem, we need to modify the sieve a little bit, as we care about twin primes. Instead of sieving out integers that are $0 \bmod p$ for each prime $p$, we'll sieve out integers that are *either* $0$ *or* $2 \bmod p$. This is because the bigger prime in a tuple of twin primes can't belong to either of these congruence classes when $p$ is chosen to be smaller than the numbers in the tuple.

Now, we know that for a prime $p$, the number of natural numbers less than $n$ in *any* residue class mod $p$ is equal to $\frac{n}{p} + O(1)$. This approximation is quite good when $p$ is much less than $n$. So we can roughly think that applying the sieve for a prime $p$ removes roughly a $\frac{2}{p}$ fraction of the remaining integers for odd primes $p$, and a $\frac{1}{2}$ fraction for the prime $2$. For this to work nicely it's also more convenient to sieve out the entire congruence class, so we don't make an exception for the case when we're sieving out a twin prime as we would with the traditional sieve of Eratosthenes.
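The modified sieve is easy to play with directly. Below is an illustrative sketch (the range $n = 10^5$ and threshold $z = 30$ are arbitrary choices): we keep the odd numbers and remove the classes $0$ and $2$ modulo each odd prime up to $z$. Every larger member $m$ of a twin prime pair with $m - 2 > z$ is guaranteed to survive.

```python
def primes_up_to(n):
    """Sieve of Eratosthenes: return all primes <= n."""
    is_prime = [True] * (n + 1)
    is_prime[0] = is_prime[1] = False
    for p in range(2, int(n**0.5) + 1):
        if is_prime[p]:
            for m in range(p * p, n + 1, p):
                is_prime[m] = False
    return [p for p, b in enumerate(is_prime) if b]

n, z = 100_000, 30                      # illustrative sizes
survivors = set(range(3, n + 1, 2))     # sieving by 2 keeps the odd numbers
for p in primes_up_to(z):
    if p == 2:
        continue
    survivors -= set(range(p, n + 1, p))   # remove m ≡ 0 (mod p)
    survivors -= set(range(2, n + 1, p))   # remove m ≡ 2 (mod p)

# Larger members of twin prime pairs lying above the sieving range survive.
prime_set = set(primes_up_to(n))
twins_upper = [m for m in prime_set if m - 2 in prime_set and m - 2 > z]
print(len(survivors), len(twins_upper))
```

Already with this tiny threshold, the survivor set is far smaller than the $n/2$ odd numbers we started with, in line with the $\prod \left(1 - \frac{2}{p}\right)$ heuristic below.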

Given that we sieve out entire congruence classes, we can't sieve by all primes up to $n$, as obviously this would leave us with nothing left, but if we sieve by the primes up to $\sqrt{n}$, we'll only have removed at most $\sqrt{n}$ twin primes, which is not going to be important, as the upper bound we want to prove is much bigger than $\sqrt{n}$. So ideally we might imagine getting a result like

$$\pi_2(n) \approx \frac{n}{2} \prod_{2 < p \le \sqrt{n}} \left(1 - \frac{2}{p}\right) \approx \frac{n}{2} \exp\left(- \sum_{2 < p \le \sqrt{n}} \frac{2}{p}\right) = O\left(\frac{n}{\log^2 n}\right),$$

where the approximation uses the fact that $1 - x \approx e^{-x}$ when $x$ is small and the last equality uses the Mertens theorems. Note that here $p$ runs over the odd primes. If we could prove this, we would be very happy.
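We can check the Mertens-theorem step here numerically: the product $\prod_{2 < p \le z} \left(1 - \frac{2}{p}\right)$ should decay like a constant times $\frac{1}{\log^2 z}$, so $\log^2 z$ times the product should be roughly stable across scales. The two thresholds below are illustrative.

```python
import math

def primes_up_to(n):
    """Sieve of Eratosthenes: return all primes <= n."""
    is_prime = [True] * (n + 1)
    is_prime[0] = is_prime[1] = False
    for p in range(2, int(n**0.5) + 1):
        if is_prime[p]:
            for m in range(p * p, n + 1, p):
                is_prime[m] = False
    return [p for p, b in enumerate(is_prime) if b]

def scaled_product(z):
    """log(z)^2 times the product of (1 - 2/p) over odd primes p <= z."""
    prod = 1.0
    for p in primes_up_to(z):
        if p > 2:
            prod *= 1 - 2 / p
    return math.log(z) ** 2 * prod

r1, r2 = scaled_product(1_000), scaled_product(100_000)
print(f"{r1:.4f} {r2:.4f}")  # roughly equal despite the 100x change in z
```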

However, in fact we can't prove this, and the reason is that we're again assuming that the events we're dealing with are independent. If we could *prove* the independence in some formal sense, the argument would be fine, but this turns out to be very difficult to do when we're taking all primes up to $\sqrt{n}$ in the product expansion.

The reason is that the fundamental result about patching together different congruence classes modulo different primes, the Chinese remainder theorem, only allows us to control the number of solutions to a system of congruences over different prime moduli when the product of those primes is substantially less than $n$: concretely, it means the exact asymptotic we can give on the above expression is only

$$\pi_2(n) \le \frac{n}{2} \prod_{2 < p \le \sqrt{n}} \left(1 - \frac{2}{p}\right) + O\left(\prod_{p \le \sqrt{n}} p\right).$$
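To see the content of the Chinese remainder theorem step concretely: the number of $m \le n$ lying in prescribed residue classes modulo a few distinct primes equals $n$ divided by the product of the primes, up to an error of at most $1$. A brute-force check, with arbitrary illustrative residues:

```python
from math import prod

n = 100_000
moduli, residues = [3, 7, 11], [0, 2, 5]   # illustrative choices
# Count m <= n lying in the prescribed residue class mod each prime.
count = sum(1 for m in range(1, n + 1)
            if all(m % q == r for q, r in zip(moduli, residues)))
print(count, n / prod(moduli))  # differ by less than 1
```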
It turns out that another implication of the prime number theorem is that

$$\sum_{p \le x} \log p = (1 + o(1)) \, x,$$

so we expect the product of all primes up to $x$ to behave like the exponential of $x$. When we apply this to the primes less than $\sqrt{n}$, we get that the product is $e^{(1 + o(1))\sqrt{n}}$, which is much larger than $n$ asymptotically. So the error term becomes much larger than the main term and the approximation fails.
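The estimate $\sum_{p \le x} \log p = (1 + o(1)) x$ (this is Chebyshev's function $\vartheta(x)$) is also easy to probe numerically; the tolerance in the check is an illustrative choice.

```python
import math

def primes_up_to(n):
    """Sieve of Eratosthenes: return all primes <= n."""
    is_prime = [True] * (n + 1)
    is_prime[0] = is_prime[1] = False
    for p in range(2, int(n**0.5) + 1):
        if is_prime[p]:
            for m in range(p * p, n + 1, p):
                is_prime[m] = False
    return [p for p, b in enumerate(is_prime) if b]

x = 100_000
theta = sum(math.log(p) for p in primes_up_to(x))
print(f"theta(x)/x = {theta / x:.4f}")  # tends to 1 as x grows
```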

The best we can hope for out of this is to take the sieving threshold up to $\log n$ or so. This is sometimes called the Legendre sieve, and gives us the bound

$$\pi_2(n) = O\left(\frac{n}{\log^2 \log n}\right),$$

which turns out to be too weak to do what we want. Indeed, this is even weaker than the bound $\pi_2(n) \le \pi(n) = O\left(\frac{n}{\log n}\right)$ from the prime number theorem, and we know the reciprocals of the prime numbers have a divergent sum, so the Legendre sieve gets us nothing in this situation: it's far too weak.

We obviously need to be more clever. However, this shouldn't be *too* hard: intuitively, it should be possible to go past the $\log n$ threshold, because it's not that we don't understand *anything* about primes that are bigger than $\log n$: the only information we don't have is about the joint correlations that involve a very large number of primes, and we understand the low order correlations very well even at thresholds much bigger than $\log n$, because those only involve the product of a small number of primes. For instance, up to the threshold $n^{1/10}$ we understand all correlations that involve $10$ or fewer primes very well from the same Chinese remainder theorem bound.

This suggests the following proof strategy: since higher order correlations are causing all the problems, we should try to find some way to decompose the quantity we want to calculate into "lower order" and "higher order" terms, and then see if we can bound the "higher order" part well enough that we can get some nontrivial result.

The Brun sieve is essentially an implementation of the vague idea discussed above.

Let's say that the event $A_p$ for a prime $p$ is defined as above: it's the event of an integer being either $0$ or $2$ modulo $p$. Then, we can see our above computation as noting that, ignoring the prime $2$ and the twin primes below the sieving threshold (which pose no problems, as discussed above), we have

$$\pi_2(n) \approx n \cdot \mathbb{P}\left(\bigcap_{2 < p \le \sqrt{n}} A_p^c\right),$$

where the exponent $c$ denotes set complements. If we assume the events $A_p$ are actually independent, we can do a computation that looks like

$$\mathbb{P}\left(\bigcap_{2 < p \le \sqrt{n}} A_p^c\right) = \prod_{2 < p \le \sqrt{n}} \mathbb{P}(A_p^c) = \prod_{2 < p \le \sqrt{n}} \left(1 - \frac{2}{p}\right),$$

and since we know how to estimate this product by the Mertens theorems, we would be able to do the same proof we did above. However, as we know, the independence assumption is actually wrong, so this doesn't really work.

However, there is an *unconditional* way to simplify expressions of this kind that doesn't assume independence: to use inclusion-exclusion. This looks like

$$\mathbb{P}\left(\bigcap_{p} A_p^c\right) = 1 - \sum_{p_1} \mathbb{P}(A_{p_1}) + \sum_{p_1 < p_2} \mathbb{P}(A_{p_1} \cap A_{p_2}) - \sum_{p_1 < p_2 < p_3} \mathbb{P}(A_{p_1} \cap A_{p_2} \cap A_{p_3}) + \ldots$$

Overall it's a big mess, but the hope is that this is exactly what we need to use our better understanding of the low order correlations between the congruence classes of different primes: inclusion-exclusion gives us a series expansion for the quantity we want that's ordered by the complexity of the correlations we want to understand!

So our hope might be the following: we truncate this sum at some threshold $k$, as in, we truncate at the term that involves the correlations of $k$ primes. We pick $k$ to be even, as inclusion-exclusion has the property that the partial sums are always upper bounds for even $k$ and lower bounds for odd $k$. We then try to understand this truncated sum better.
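These alternating bounds are the Bonferroni inequalities, and they're easy to verify numerically. The sketch below, with illustrative primes and the events $A_p = \{m : m \equiv 0 \text{ or } 2 \pmod p\}$ on $\{1, \ldots, N\}$, compares the odd and even truncations against the exact avoidance probability.

```python
from itertools import combinations

N = 10_000
ps = [3, 5, 7, 11, 13]   # illustrative sieving primes

def in_event(m, p):
    """Membership in A_p: m ≡ 0 or 2 (mod p)."""
    return m % p in (0, 2)

# Exact probability of avoiding every A_p.
exact = sum(1 for m in range(1, N + 1)
            if not any(in_event(m, p) for p in ps)) / N

def truncated(k):
    """Inclusion-exclusion for the avoidance probability, truncated at size-k subsets."""
    total = 0.0
    for j in range(k + 1):
        for S in combinations(ps, j):
            count = sum(1 for m in range(1, N + 1)
                        if all(in_event(m, p) for p in S))
            total += (-1) ** j * count / N
    return total

print(truncated(1), exact, truncated(2))  # lower bound, exact, upper bound
```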

The square root is not ideal for this, as we already fail to understand three-prime correlations when we work with the square root. So let's also take the threshold at which we cut off the primes as a free parameter $z$. In the above, we were using $z = \sqrt{n}$, but we want to generalize this now.

One immediate observation is that if we cut the primes off at $z$, we only have a good understanding of correlations that involve at most $\frac{\log n}{\log z}$ primes, as beyond that point the product of the moduli can exceed $n$. To be safe, let's truncate the sum at (or choose $k$ to be) the even number closest to $\varepsilon \frac{\log n}{\log z}$, where $\varepsilon$ is a small number that we'll fix later. In this case, the Chinese remainder approximations will be very good, and we don't have to worry about the failure of independence anymore. (I'll equivocate between the even integer $k$ and the real number $\varepsilon \frac{\log n}{\log z}$ sometimes - rest assured that this causes no problems for the argument and everything goes into the error terms.) So we'll get the sum

$$\sum_{j=0}^{k} (-1)^j \sum_{2 < p_1 < \ldots < p_j \le z} \mathbb{P}\left(A_{p_1} \cap \ldots \cap A_{p_j}\right) \approx \sum_{j=0}^{k} (-1)^j \sum_{2 < p_1 < \ldots < p_j \le z} \frac{2}{p_1} \cdots \frac{2}{p_j}$$

after truncation, where the inner sums run over odd primes. It's important to keep track of the error term here properly, so we can note that the error we make whenever we approximate the probability of an intersection with an independence assumption is $O(1/n)$: this is because each correlation involves at most $k$ primes and each prime is at most $z$, so the product of the moduli is bounded from above by $z^k$, which is equal to $n^{\varepsilon}$ given our choice of $k$, and then the Chinese remainder theorem counts the solutions among $n$ integers up to an error of $O(1)$. Since the number of such approximations we have to make is also bounded from above by $z^k = n^{\varepsilon}$, we're safe as long as $\varepsilon$ is small: the total error in the count is $O(n^{\varepsilon})$, negligible compared to a main term of order $\frac{n}{\log^2 z}$. For instance, $\varepsilon = 1/2$ will be enough for the error to be negligible relative to the main term here.

Now that we know the error term is negligible, so long as $\varepsilon$ is small, let's think about approximating this sum. Bounding it from above with the triangle inequality gives very poor results here because the sum has a lot of cancellation. However, we can use a trick: we know that the *untruncated sum*, i.e. the sum with $k = \infty$, is equal to the product

$$\prod_{2 < p \le z} \left(1 - \frac{2}{p}\right).$$

This is the expression we were working with earlier when assuming independence, so it makes sense to bring it back here. Instead of bounding the sum above directly, we'll write it as this "main term" plus an error term that comes from the "high order correlations" where the number of primes exceeds $k$. This fits in naturally with our proof strategy: try to bound the contribution coming from higher order correlations.

This will give us

$$\sum_{j=0}^{k} (-1)^j \sum_{2 < p_1 < \ldots < p_j \le z} \frac{2}{p_1} \cdots \frac{2}{p_j} = \prod_{2 < p \le z} \left(1 - \frac{2}{p}\right) - \sum_{j > k} (-1)^j \sum_{2 < p_1 < \ldots < p_j \le z} \frac{2}{p_1} \cdots \frac{2}{p_j}.$$

The triangle inequality works much better on the error term now. Indeed, it's fairly straightforward to see that if we recall the definition

$$S = \sum_{2 < p \le z} \frac{2}{p},$$

then the "error term", as in, all terms aside from the product term here, can be bounded in absolute value by

$$\sum_{j > k} \frac{S^j}{j!}.$$
This is an elementary exercise and I leave it to the reader to see why it's true. Taking this for granted, though, we can approximate this sum up to a constant by just its first term if