
TL;DR: Language models sometimes seem to ignore parts of the chain of thought, and larger models appear to do this more often. Shapley value attribution is a possible approach to get a more detailed picture of the information flow within the chain of thought, though it has its limitations.

Project status: The analysis is not as rigorous as I would prefer, but I'm going to be working on other directions for the foreseeable future, so I'm posting what I already have in case it's useful to others. Code for replicating the Shapley value results can be found here.

Thanks to Jacob Hilton, Giambattista Parascandolo, Tamera Lanham, Ethan Perez, and Jason Wei for discussion.

Motivation

Chain of thought (CoT) has been proposed as a method for language model interpretability (see Externalized Reasoning Oversight, Visible Thoughts). One crucial requirement for interpretability methods is that they should accurately reflect the cognition inside the model. However, by default there is nothing forcing the CoT to actually correspond to the model’s cognition, and there may exist theoretical limitations to doing so in general. 

Because it is plausible that the first AGI systems will resemble current LMs augmented with more sophisticated CoT and CoT-like techniques, it is valuable to study the properties of CoT, and to understand and address its limitations.

Shapley values have been used very broadly in ML for feature importance and attribution (Cohen et al., 2007; Štrumbelj and Kononenko, 2014; Owen and Prieur, 2016; Lundberg and Lee, 2017; Sundararajan and Najmi, 2020). Jain and Wallace (2019) argue that attention maps can be misleading as attribution, motivating better attribution for information flow in LMs. Kumar et al. (2020) highlight areas where Shapley-value-based attribution falls short for some interpretability use cases.

Madaan and Yazdanbakhsh (2022) consider a similar approach of selectively ablating tokens to deduce what information the model depends on. Wang et al. (2022) find that prompting with incorrect CoT has a surprisingly minor impact on performance.

Effect of Interventions

We use a method similar to Kojima et al. (2022) on GSM8K (Cobbe et al., 2021) with GPT-4: we first generate a chain of thought and evaluate the answer. Then, for every chain of thought that results in a correct answer, we perform the following intervention: we choose a random numerical value in the CoT, replace it with a random number within ±3 of the original, discard the remainder of the CoT, and regenerate it. If the LM were strictly following the CoT it wrote, this intervention should almost always result in an incorrect answer, just as a human who made a mistake in one calculation and propagated it forward would arrive at a wrong final answer (barring rare cases where the new value happens to also yield the correct answer, which qualitative inspection suggests is very uncommon).
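For concreteness, here is a minimal sketch of the intervention step, with `generate` standing in for the actual LM completion call (a hypothetical helper, not part of any particular API):

```python
import random
import re

def intervene_on_cot(problem: str, cot: str, generate) -> str:
    """Perturb one number in a chain of thought and regenerate the rest.

    `generate` is a hypothetical stand-in for an LM completion call: it takes
    a prompt string and returns the model's continuation as a string.
    """
    # Find all numbers in the chain of thought along with their positions.
    matches = list(re.finditer(r"\d+", cot))
    if not matches:
        return cot

    # Pick one number at random and replace it with a value within +/-3 of
    # the original (excluding the original itself, so it is a real change).
    m = random.choice(matches)
    original = int(m.group())
    perturbed = random.choice([original + d for d in range(-3, 4) if d != 0])

    # Keep the CoT up to and including the perturbed number, discard the rest,
    # and let the model regenerate the remainder of the chain of thought.
    prefix = cot[: m.start()] + str(perturbed)
    continuation = generate(problem + "\n" + prefix)
    return prefix + continuation
```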

Some cherry-picked examples (red = intervention, blue = correct continuations that are seemingly non sequiturs):

We test how frequently the error fails to propagate in several different settings (n=100):

| Setting | Accuracy (w/ CoT) | P(error not propagated \| original correct) |
| --- | --- | --- |
| GPT-4, zero-shot | 0.88 | 0.68 |
| GPT-4 base, 2-shot | 0.73 | 0.63 |
| GPT-3.5, zero-shot | 0.43 | 0.33 |

Interestingly, if we condition on the CoT answer being correct and the single-forward-pass answer being incorrect (i.e., the LM could only solve the problem with the CoT), the accuracy after intervention for GPT-4 is still 0.65.

Shapley Value Attribution

We would like to get more granular information about the causal structure (i.e., which tokens cause which other tokens). One thing we could do is look at how an intervention at each token affects the logprob of each other token. However, one major problem with this, especially in larger models, is that there are many cases where a token depends on multiple previous tokens in some complicated way. In particular, if a model looks at several different places in the context and takes a vote for the most common value, then intervening on any one of them doesn't change the output logprob much, even though there is a lot of information flow there.

To get around this problem, we instead estimate Shapley values, which take all of the interactions into account (in the case where the model takes a vote among three values in the context, each of those values would get ⅓ of the attribution).[1] We also normalize the attributions for each token to sum to 1, clamping negative Shapley values to 0.[2] We do this to make the attributions more comparable across different models.
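A minimal sketch of the exact Shapley computation and the normalization step described above. Here `value_fn` is a hypothetical callback: given a subset of the numeric tokens to leave intact (the rest masked with underscores, per the footnote), it should return the logprob the model assigns to the correct target token; the model-specific masking and scoring code is not shown.

```python
from itertools import combinations
from math import factorial

def exact_shapley(players, value_fn):
    """Exact Shapley values over a small set of `players` (e.g. numeric tokens).

    `value_fn(subset)` returns a scalar payoff for a subset of players, such as
    the logprob of the correct target token when only those numbers are visible.
    """
    n = len(players)
    shapley = {p: 0.0 for p in players}
    for p in players:
        others = [q for q in players if q != p]
        for k in range(len(others) + 1):
            # Standard Shapley weight for coalitions of size k.
            weight = factorial(k) * factorial(n - k - 1) / factorial(n)
            for subset in combinations(others, k):
                marginal = value_fn(set(subset) | {p}) - value_fn(set(subset))
                shapley[p] += weight * marginal
    return shapley

def normalize_attributions(shapley):
    """Clamp negative Shapley values to 0 and normalize to sum to 1."""
    clipped = {p: max(v, 0.0) for p, v in shapley.items()}
    total = sum(clipped.values()) or 1.0
    return {p: v / total for p, v in clipped.items()}
```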

Here's an example chain of thought in GPT-4[3]:

Here, we can see patterns like the 23 and 20 being copied, or the 3 depending heavily on the preceding 23 - 20.[4] We can also look at some other models:

GPT-3.5 (text-davinci-002):

text-davinci-001:

Interestingly, we notice that the more capable the model, the more spread out the attributions become. We can quantify this by taking the mean entropy of the parent attributions across all tokens, at least on this particular data sample:

| Model | Mean entropy of example sentence (nats) |
| --- | --- |
| text-davinci-001 | 0.796 |
| GPT-3.5 (text-davinci-002) | 0.967 |
| GPT-4 | 1.133 |
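Concretely, the entropy measure above can be computed per token from its normalized parent attributions and then averaged; a small sketch (function names here are illustrative, not from the released code):

```python
import math

def attribution_entropy(parent_attribution):
    """Entropy in nats of one token's normalized parent attributions."""
    return -sum(p * math.log(p) for p in parent_attribution if p > 0)

def mean_attribution_entropy(all_attributions):
    """Average the per-token entropies across every attributed token."""
    entropies = [attribution_entropy(a) for a in all_attributions]
    return sum(entropies) / len(entropies)
```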

Limitations and Future Work

  • The cost of computing exact Shapley values scales exponentially with the number of tokens we're attributing.[5] This makes it impractical for many use cases, though there exist efficient Monte Carlo estimators (Castro et al., 2008); a sketch of one such estimator follows this list.
  • Replacing digits with underscores (or incorrect numbers) moves the model out of distribution, and its behaviour there may not be representative of its behaviour on-distribution.
  • The Shapley attributions are not guaranteed to correspond to the actual information flow inside the model either. This methodology would not be sufficient for deceptive/adversarial LMs, or as an optimization target during training. In the language of Lipton (2016), this is a "post-hoc" method.
  • The mechanism behind this effect is still unknown, and would require more experiments and possibly interpretability to better understand. Possible hypotheses include typo correction or subsurface cognition.
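As referenced in the first limitation above, here is a rough sketch of a permutation-sampling Monte Carlo estimator in the spirit of Castro et al. (2008), reusing the same hypothetical `value_fn` callback as in the exact version:

```python
import random

def shapley_monte_carlo(players, value_fn, n_samples=200):
    """Estimate Shapley values by sampling random player orderings.

    Each sample shuffles the players and accumulates every player's marginal
    contribution when it joins the coalition in that order, avoiding the
    exponential enumeration of all subsets.
    """
    estimates = {p: 0.0 for p in players}
    for _ in range(n_samples):
        order = list(players)
        random.shuffle(order)
        coalition = set()
        prev_value = value_fn(coalition)
        for p in order:
            coalition.add(p)
            new_value = value_fn(coalition)
            estimates[p] += new_value - prev_value
            prev_value = new_value
    return {p: v / n_samples for p, v in estimates.items()}
```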

Discussion

  • I think these experiments show that a naive optimistic view of CoT interpretability is incorrect, but do not provide strong evidence that there is definitely something fishy or difficult-to-fix going on.
  • I started out fairly skeptical of CoT as a means of increasing interpretability, and I didn't update very strongly, because I expect deceptive alignment in future models to account for most of the risk in my threat model.
  • However, I did update a little because my previous view would not have ruled out extremely egregious causal dependencies even in current models.
  • I'm generally excited about better understanding what is going on with chain of thought and finding ways to make it more faithful.
  1. ^

    Methodological footnote: the Shapley experiments blank out the numbers with underscores, rather than applying the ±3 perturbation from the previous section.

  2. ^

    Negative Shapley values did not occur very often, but clamping them to 0 is still somewhat unprincipled. This was done primarily to make the entropy calculation work.

  3. ^

    We only look at the Shapley values for the numbers, because the cost of Shapley attribution grows exponentially with the number of tokens under consideration.

  4. ^

    Alternative visualization style:

  5. ^

    When doing Shapley value attributions for every pair of tokens, there is a dynamic programming trick we can use to prevent the cost from becoming n * 2^n: because the model is autoregressive, we can run attributions to only the last token and, if we take care to save the logprobs of the correct number token at each underscore, compute all other attributions for free.

Comments

Speculative hypothesis: Maybe some of the cases in which editing one number doesn't ruin the bottom-line result are not as bad as they sound, for the following reason: the system is surprised to see the wrong number there, mentally writes it off as a typo, and proceeds with the true number that it expected in mind.

(This happens to me a lot. Sometimes I'll be reading someone's argument about how X leads to anti-Y which correlates with Z, and I'll notice an obvious typo like "don't they mean anticorrelates here?" and then I'll just keep reading assuming they meant what I think they meant instead of what they actually said.)

Does the structure of the transformer allow for this sort of cognition?

What we care about is whether the compute being done by the model faithfully factors through token outputs. To the extent that a given token, under the usual human reading, doesn't represent much compute, it doesn't matter much whether the output is sensitively dependent on that token. As Daniel mentions, we should also expect some amount of error correction, and a reasonable (non-steganographic, actually-uses-CoT) model should error-correct mistakes as some monotonic function of how compute-expensive correction is.

For copying errors, the copying operation involves minimal compute, and so insensitivity to previous copy errors isn't all that surprising or concerning. You can see this in the heatmap plots. E.g. the '9' token in 3+6=9 seems to care more about the first '3' token than the immediately preceding summand token, suggesting the copying operation was not really helpful/meaningful compute. Whereas I'd expect the outputs of arithmetic operations to be meaningful. Would be interested to see sensitivities when you aggregate only over outputs of arithmetic / other non-copying operations.

I like the application of Shapley values here, but I think aggregating over all integer tokens is a bit misleading for this reason. When evaluating CoT faithfulness, token-intervention-sensitivity should be weighted by how much compute it costs to reproduce/correct that token in some sense (e.g. perhaps by number of forward passes needed when queried separately). Not sure what the right, generalizable way to do this is, but an interesting comparison point might be if you replaced certain numbers (and all downstream repetitions of that number) with variable tokens like 'X'. This seems more natural than just ablating individual tokens with e.g. '_'.

Are you familiar with the Sparks of AGI folks' work on a practical example of GPT-4 asserting early errors in CoT are typos (timestamped link)? Pretty good prediction if not.

Nope, hadn't seen that before, thanks for the tip.

One takeaway from this would be "CoT is more accurate at the limit of the model's capabilities".  Given this, you could have a policy of only using the least capable model for every task to make CoT more influential. Of course this means you always get barely-passable model performance, which is unfortunate. Also, people often use CoT in cases where it intuitively wouldn't help, such as common sense NLP questions, where I'd expect influential CoT to be pretty unnatural.