Work done @ SERI-MATS.
Evaluating interpretability methods (and so, developing good ones) is really hard because we have no ground truth. Or at least, no ground truth that we can compare our interpretations directly against.
The ground truth of a model's behaviour is provided by that model's architecture and its learned parameters. But puny humans are unable to interpret this: it's precise, in that it accurately explains the model's behaviour, but it's not interpretable. On the other end of the spectrum we have something like "This model classifies cats" – a statement that is really easy to interpret, but lacks something in the way of precision.
Precise <---------------------------------> Interpretable
Imagine two interpretations, each generated by a different method with respect to the same model (say, a cat classifier). Method A indicates that the model has learned to use ears and whiskers to identify cats. Method B indicates that it uses eyes and tails. Assuming both are easy to interpret, can we tell which method is more precise? Which more faithfully represents what the model is truly doing?
If we had a method that reconciled precision and interpretability, how would we know?
Well, we can perform sanity checks on the interpretability methods, and throw away any that fail them. This seems good – it's at least objective – but it only really allows us to throw away obviously bad approaches. It doesn't say anything about what to do when sane interpretability methods disagree.
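One well-known check of this kind is the model-parameter randomisation test from Adebayo et al.'s "Sanity Checks for Saliency Maps": if a saliency method produces nearly the same map after the model's weights are replaced with random ones, it can't be telling us much about the model. Below is a minimal sketch on a toy linear model, assuming NumPy; the model, the 0.9 threshold and the `input_magnitude` "method" are all hypothetical, chosen only for illustration.

```python
import numpy as np

def gradient_times_input(w, x):
    # For the toy model f(x) = w . x the gradient w.r.t. x is w,
    # so "gradient x input" saliency is w * x.
    return w * x

def input_magnitude(w, x):
    # Hypothetical "method" that ignores the model entirely and just
    # highlights large inputs; its maps can still look plausible.
    return np.abs(x)

def randomisation_check(method, w, x, rng, threshold=0.9):
    """Pass if the saliency map changes when the model's weights are randomised."""
    w_random = rng.standard_normal(w.shape)
    original_map = method(w, x)
    random_map = method(w_random, x)
    # Pearson correlation between the two maps; a high value means the
    # method gives nearly the same explanation for a destroyed model.
    similarity = np.corrcoef(original_map, random_map)[0, 1]
    return bool(similarity < threshold)

rng = np.random.default_rng(0)
x = rng.standard_normal(100)
w = rng.standard_normal(100)
print(randomisation_check(gradient_times_input, w, x, rng))  # True: map depends on w
print(randomisation_check(input_magnitude, w, x, rng))       # False: map ignores w
```

Note what the check does and doesn't buy us: it throws out the method that ignores the model, but if two methods both pass, it says nothing about which of them is more precise.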
We could also look at the interpretations and see if they appear sensible to us. This is a widely used approach (Zeiler et al., Petsiuk et al., Fong et al., many many more), and I think it's a terrible idea.
- We've made some new interpretability method that is supposed to help us understand which words are used by a language model to identify hate speech in tweets. To see if it works properly, we compare the words highlighted by this interpretability method to the actual hateful words in the tweet. It gets them right! Our new interpretability method works!
- NO! We have fallen prey to a terrible assumption: that if a model performs well, it has learned to use the same features that a human would use. How we would perform a task is not the ground truth. How the model actually performs a task is, but we don't know that – it's what we're trying to find out!
- We use gradient descent to optimise the input to a model such that it maximises the activation of a particular node, layer, or logit. This works, but results in a really noisy input that doesn't make sense to us – it seems like adversarial noise. So, we regularise the input, perturb it intermittently during optimisation, constrain it to the training data distribution, and voila – we have a nice optimised input that makes sense! We found a fur detector!
- NOOOOOO! We've optimised the input to maximise some output, sure. But we've also optimised it to maximise how much we like it. That's not what we wanted! That has nothing to do with what the model has actually learned, and how sensible an interpretation seems to us has no relation to the ground truth.
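To make the worry in the last example concrete, here is a toy sketch, assuming NumPy, a hypothetical linear unit with weights `w`, and a hypothetical `prior` vector standing in for whatever regulariser or data-distribution constraint we add. The regularised objective has the closed-form optimum `prior + w / (2 * lam)`, so as the regularisation gets stronger the result looks more and more like the prior and less and less like the weights, even though the unit itself never changes.

```python
import numpy as np

def optimise_input(w, prior, lam, steps=5000, lr=0.01):
    """Gradient ascent on x for the objective  w.x - lam * ||x - prior||^2.

    w     : weights of a toy linear unit; its activation is w.x
    prior : hypothetical "natural-looking" input the regulariser pulls towards
    lam   : regularisation strength (this sketch needs 0 < lam * lr < 1 to converge)
    """
    x = prior.copy()
    for _ in range(steps):
        grad = w - 2.0 * lam * (x - prior)  # gradient of the regularised objective
        x = x + lr * grad
    return x  # converges to prior + w / (2 * lam)

rng = np.random.default_rng(0)
w = rng.standard_normal(50)      # what the unit actually responds to
prior = rng.standard_normal(50)  # what we would like the result to look like

for lam in (0.5, 5.0, 50.0):
    x = optimise_input(w, prior, lam)
    corr_w = np.corrcoef(x, w)[0, 1]
    corr_prior = np.corrcoef(x, prior)[0, 1]
    print(f"lam={lam:>4}: corr with weights {corr_w:.2f}, corr with prior {corr_prior:.2f}")
```

Real feature visualisation uses much richer regularisers and perturbations, but the same trade-off plausibly applies: some of the structure in the final image comes from our choices rather than from the model.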
I'm being a bit dramatic. These kinds of approaches can be useful, and god knows I love a good feature visualisation, same as anyone. But I'm worried about using stuff like this to determine how good our interpretability methods are. It's not an objective evaluation.
A small idea: what if we did have access to the ground truth? If we had a small, simple model that we completely understood (I'm looking at you, mechanistic interpretability people), we could use it as a truly objective benchmark for other interpretability methods. (This is super easy for model-agnostic saliency mapping – just use summation in place of the model, and then the ground-truth saliency of each input element is exactly that element itself. If your saliency map isn't exactly equal to the input, your method isn't working perfectly – and moreover, you can see exactly where it's failing.)
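The summation benchmark can be sketched in a few lines, assuming NumPy. Occlusion (replace one element with a zero baseline and measure the drop in output) stands in for "your saliency mapping method"; `blurred_saliency` is a hypothetical flawed method, included only to show how the error vector localises failures.

```python
import numpy as np

def sum_model(x):
    # A model we understand completely: the output is the sum of the inputs.
    return float(np.sum(x))

def occlusion_saliency(model, x, baseline=0.0):
    """Saliency of element i = drop in output when x[i] is replaced by the baseline."""
    full = model(x)
    saliency = np.empty_like(x, dtype=float)
    for i in range(x.size):
        occluded = x.copy()
        occluded.flat[i] = baseline
        saliency.flat[i] = full - model(occluded)
    return saliency

def blurred_saliency(model, x):
    # Hypothetical flawed method: occlusion scores smoothed over neighbours.
    return np.convolve(occlusion_saliency(model, x), np.ones(3) / 3, mode="same")

def ground_truth_error(saliency, x):
    # For sum_model the ground-truth saliency of each element is the element
    # itself, so every nonzero entry marks a place where the method fails.
    return saliency - x

x = np.array([3.0, -1.0, 0.0, 2.0, 5.0])
print(ground_truth_error(occlusion_saliency(sum_model, x), x))  # all zeros
print(ground_truth_error(blurred_saliency(sum_model, x), x))    # nonzero where blur leaks
```

Summation is about the easiest possible ground truth; a method that passes here can still fail on models with interactions between inputs, which is where a fully understood but less trivial model would help.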
Makes me think about what precisely we want from interpretability, since that seems relevant in grounding it.
As a first go at the answer, it seems to me that we want to know what assumptions the model relies on in order to give correct results, as well as what properties are satisfied by intermediate results in the network's calculations. This is useful because we can then look at the conditions under which these assumptions are known to fail or succeed.
For instance with the cat classification, why would we want to learn that it "has learned to use ears and whiskers to identify cats"? Well, we might translate this into the assumption "something is a cat iff it has cat ears and whiskers". There are two ways this could fail: if some non-cat entity has cat ears and whiskers (for instance because someone is trying to trick the network), or if some cat does not exhibit cat ears and whiskers (for instance because they are hidden by the camera angle). These failures tell us important things about what would happen if we used the network for high-stakes real-world decisions.
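The two failure directions can be made explicit with a small sketch. It idealises the interpretation as if it were literally the model's decision rule; the `Image` record and its annotation fields are hypothetical.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Image:
    # Hypothetical ground-truth annotations for one image.
    is_cat: bool
    has_cat_ears: bool
    has_whiskers: bool

def assumed_rule(img: Image) -> bool:
    # The interpretation read as an assumption: "cat iff cat ears and whiskers".
    return img.has_cat_ears and img.has_whiskers

def failure_mode(img: Image) -> Optional[str]:
    predicted = assumed_rule(img)
    if predicted and not img.is_cat:
        return "false positive: non-cat with cat ears and whiskers"
    if not predicted and img.is_cat:
        return "false negative: cat whose ears or whiskers aren't visible"
    return None  # the assumption holds for this image

print(failure_mode(Image(is_cat=False, has_cat_ears=True, has_whiskers=True)))
print(failure_mode(Image(is_cat=True, has_cat_ears=False, has_whiskers=True)))
```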
However, there are a number of problems with my stab at the goal of interpretability. First, there will often be multiple different sets of conditions that can make a model valid, rather than a single unique one. Second, the conditions themselves will in reality be so incredibly complex and plentiful that they need interpretation. Third, there is an actual answer to what assumptions the network makes, and it is a useless answer: it makes the assumptions that the data fed into the network is IID with the same distribution as the training data and that the cost of misclassifications is proportional to the loss.
I think these flaws suggest that when we do interpretability, what we really want is to impose some structure on the network. That is, we want to find some set of conditions that could occur in reality, where we can know that if these conditions occur, the network satisfies some useful property (such as "usually classifies things correctly").
The main difficulty with this is that it seems to require a really good understanding of reality. Humans are sometimes good at figuring this out, but it is hard to build an automated tool for it. In fact it seems like an AGI-complete problem: understand in depth how reality works and distill this into a small set of useful conditions that can be applied to deduce how systems perform across a wide variety of circumstances.
... 😅 which might be indicative that I have made an overly strong requirement somewhere along the line, because surely one can do useful interpretability without an AGI.
There we go!
So, one item on my list of posts to maybe get around to writing at some point is about what's missing from current work on interpretability, what bottlenecks would need to be addressed to get the kind of interpretability we ideally want for application to alignment, and how True Names in general and natural abstraction specifically fit into the picture.
The OP got about half the picture: current methods mostly don't have a good ground truth. People use toy environments to work around that, but then we don't know how well tools will generalize to real-world structures which are certainly more complex and might even be differently complex.
The other half of the picture is: what would a good ground truth for interpretability even look like? And as you say, the answer involves a really good understanding of reality.
Unpacking a bit more: "interpret" is a two-part word. We see a bunch of floating-point numbers in a net, and we interpret them as an inner optimizer, or we interpret them as a representation of a car, or we interpret them as Fourier components of some signal, or .... Claim: the ground truth for an interpretability method is a True Name of whatever we're interpreting the floating-point numbers as. The ground truth for an interpretability method which looks for inner optimizers is, roughly speaking, a True Name of inner optimization. The ground truth for an interpretability method which looks for representations of cars is, roughly speaking, a True Name of cars (which presumably routes through some version of natural abstraction). The reason we have good ground truths for interpretability in various toy problems is that we already know the True Names of all the key things involved in those toy problems - like e.g. modular addition and Fourier components.
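For the modular-addition example, "knowing the True Names" means we can write the target algorithm down exactly. Below is a sketch of an idealised Fourier-based algorithm of the kind mechanistic interpretability work has reported in small transformers trained on this task: encode each input as cos/sin at a few frequencies, combine them with the angle-addition identities, and score every candidate answer. The modulus and the "key frequencies" are assumptions for illustration, not taken from any particular trained network.

```python
import numpy as np

P = 113                       # modulus (assumption; any prime works here)
KEY_FREQS = (14, 35, 41, 52)  # hypothetical key frequencies, each in 1..P-1

def fourier_mod_add(a: int, b: int, p: int = P, freqs=KEY_FREQS) -> int:
    c = np.arange(p)       # every candidate answer
    logits = np.zeros(p)
    for k in freqs:
        w = 2 * np.pi * k / p
        # Represent each input by its cosine and sine at frequency w ...
        cos_a, sin_a = np.cos(w * a), np.sin(w * a)
        cos_b, sin_b = np.cos(w * b), np.sin(w * b)
        # ... combine them with angle-addition identities to get w(a + b) ...
        cos_ab = cos_a * cos_b - sin_a * sin_b
        sin_ab = sin_a * cos_b + cos_a * sin_b
        # ... and score each c by cos(w(a + b - c)), which is 1 only at c = (a + b) mod p.
        logits += cos_ab * np.cos(w * c) + sin_ab * np.sin(w * c)
    return int(np.argmax(logits))

print(fourier_mod_add(100, 50))  # 37, since 150 mod 113 = 37
```

Because each step here has a known meaning, an interpretability tool run on a network trained on this task can be scored on whether it recovers these quantities, which is exactly the kind of ground truth the real-world case lacks.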