I'm very excited about research that tries to deeply understand how neural networks are thinking, and especially to understand tiny parts of neural networks without too much concern for scalability, as described in OpenPhil's recent RFP or the Circuits thread on Distill.
I want to give some detail on the kind of research I'm most excited about and why.
(Disclaimer: These are off-the-cuff thoughts. I'm not at all an expert in this area, just someone who cares a lot about alignment and has opinions.)
Why I'm excited about interpretability
I'm very scared of powerful AI doing something deliberately destructive that no one wanted it to do. If we had a reasonably complete story about how a neural network was thinking, then I doubt this failure mode would take us by surprise---we'd be able to see that the model was only behaving well on the training distribution because we were watching, and understand that the model was capable and inclined to cause harm when deployed. I think that even a partial understanding could provide a huge amount of clarity on open disputes about how models will generalize and when they may pose risks.
So the simplest positive case for interpretability is that it gives us a much clearer warning about an impending problem. That may or may not be enough for us to get our act together and fix the problem, but it's at least a start.
Beyond that, interpretability plays a central role in many proposals for training safe AI. For example, we might use interpretability to help identify cases where the model would behave badly, to help ensure that the model never develops deceptive intent, or to help understand what the model knows about the world so that we can better evaluate its outputs. We don't really know which of these approaches will pan out, but the fact that interpretability keeps coming up makes it seem important beyond just a warning shot.
Some of these approaches would put more "strain" on an interpretability method. If we optimize models so that they don't appear to think certain kinds of thoughts, then gradient descent will find a way to hide those thoughts. I don't know if interpretability techniques would ever be strong enough to stand up to this kind of optimization pressure, but I think interpretability is incredibly useful anyway---it still gives you a key warning shot, and many of the potential applications in the training process wouldn't create as much adversarial pressure.
It's possible for models to eventually hide their cognition even if we don't apply gradient descent. But I think this may be extremely hard; that the process of doing so is likely to itself be visible if we are careful; and that even if transformative AI is smart enough to obfuscate its thoughts we would still learn a huge amount from studying dumber models.
I'm significantly more interested in and optimistic about mechanistic interpretability than I was 4 years ago. So from my perspective Chris Olah, who has championed this perspective during the interim, gets a lot of epistemic credit here.
Why I'm not worried about scalability
I think I basically totally agree with the OpenPhil RFP, and in particular think the most exciting work is getting an extremely detailed mechanical understanding of a tiny part of a neural network. If anything I'd lean slightly further in that direction than Chris does in the RFP.
Expanding on that: it seems to me that the current bottleneck for interpretability is that we almost never have a good understanding of what's going on, not that we have good methods that scale poorly. I think people should mostly not worry about scalability while we try to get really good at understanding what's up with small pieces of neural networks.
That's partly an aesthetic judgment about what the conceptual bottlenecks are, but I think there are also some concrete reasons not to worry too much about scalability:
- Worst case, in the long run I think there's a very good chance that it would be feasible to automate the literal steps a human takes in order to understand a neural network, and that doing so will be cheap enough to apply to even the largest models.
- Each time we do this exercise I think we will learn much more about how to do it well and about the general structure of neural networks. I think there's a good chance that if we understand how to do this task well we'll also understand how to write a scalable algorithm.
- Maybe if we had many examples of understanding small circuits, and it wasn't getting any easier and we weren't learning anything general, then I'd suggest that we should be focusing on scalability. But instead I feel like this work is at an extremely preliminary stage, where we have very few examples of deeply understanding anything other than low-level vision and could go a lot further even on the examples we do have.
- I have some general intuition that it's very unlikely for this kind of task to be bottlenecked on computational issues---if we're in that world, then in some sense interpretability would be one of the easiest parts of ML to work on, and so I would expect extremely rapid progress.
I think "fully understand a neural network" is a good aspirational goal, which I think is also mostly bottlenecked on "deeply understand small pieces effectively."
Comparison to Circuits
The Circuits thread on Distill is probably the best example of work I find exciting.
I think the most exciting aspect of this thread is the "Artificial Artificial Neural Network" described in Curve Circuits, and in particular the preliminary results on replacing the "natural" curve circuit with their hand-coded version (see "Finally, a preliminary experiment..."). They consider a set of neurons that appear to be doing curve detection, find that zeroing those activations reduces accuracy by 3.3%, and find that replacing them with a hand-written curve detection algorithm recovers half of that lost accuracy.
I feel like "understand pieces of cognition well enough that you can replace them without degrading performance" is the right game to be playing, and that the evidence provided by successfully doing the replacement would be much stronger than anything else that we think we know about neural networks.
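To make the "replace it and see how much performance comes back" game concrete, here is a minimal sketch on a toy network. Everything in it is an assumption for illustration: the tiny random ReLU network, the names `forward`, `fraction_recovered`, `zeroed` and `crude`, and the use of squared error against the intact model's output as the performance measure. It is not the procedure used in Curve Circuits, only the shape of the comparison.

```python
import numpy as np

def forward(x, W, v, override=None):
    """Toy one-hidden-layer ReLU network: y = relu(x W^T) v.

    `override` maps a hidden-unit index to a function of the input that
    supplies that unit's activation in place of the learned one.
    """
    h = np.maximum(x @ W.T, 0.0)            # shape (n_samples, n_hidden)
    for unit, fn in (override or {}).items():
        h[:, unit] = fn(x)
    return h @ v

def fraction_recovered(loss_original, loss_ablated, loss_replaced):
    """Share of the ablation damage undone by the replacement.

    0 means no better than zeroing, 1 means as good as the original,
    and values above 1 mean the replacement beats the original.
    """
    damage = loss_ablated - loss_original
    if damage <= 0:
        raise ValueError("ablation did not hurt performance; ratio undefined")
    return (loss_ablated - loss_replaced) / damage

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 3))                 # hypothetical "learned" weights
v = rng.normal(size=4)
x = rng.normal(size=(1000, 3))
target = forward(x, W, v)                   # intact model's output as reference

def mse(pred):
    return float(np.mean((pred - target) ** 2))

def zeroed(x):
    # Ablation: the unit always outputs zero.
    return np.zeros(len(x))

def crude(x):
    # Hand-written stand-in that captures only half of unit 0's response.
    return 0.5 * np.maximum(x @ W[0], 0.0)

loss_original = mse(forward(x, W, v))
loss_ablated = mse(forward(x, W, v, {0: zeroed}))
loss_replaced = mse(forward(x, W, v, {0: crude}))
recovered = fraction_recovered(loss_original, loss_ablated, loss_replaced)
print(f"fraction of ablation damage recovered: {recovered:.2f}")
```

In this toy, a stand-in that reproduces only half of the unit's activation still removes three quarters of the squared-error damage. That shows how a partial recovery figure can look better than the underlying understanding warrants.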
My biggest reservations with that work are:
- The experimental evidence that the artificial circuit works well is preliminary. The sample size is small enough that it could just be noise; I doubt the whole effect is noise, but the real effect size could at least be very different.
- I think that having the artificial circuit work at least as well as the original would be much more informative than being 50% as good; this more ambitious goal also seems tractable with a good enough understanding. It could be the case that zeroing out a neuron is extremely harmful for a network (much worse than retraining from scratch without those neurons), such that even approximating some crude high-level statistics would be 50% as good.
- It seems plausible that "curve detector" is one of the easiest circuits, and that replacing even a single random neuron in an image model would be much harder. I wouldn't want to address that by trying to jump straight to a random neuron, but it's something you'd learn about as you did more examples.
Most of all, I think that they could predictably get much cleaner results by continuing to spend time on the problem---they didn't put much effort into this experiment, and they mention several obvious ways to do much better (e.g. their version of the circuit is in grayscale instead of color).
There were several ways I think that the authors might have been making life unnecessarily hard on themselves:
- They implemented their algorithm using exactly the same architecture as the original, setting individual weights by hand. But I think it would be nearly as good to replace a neuron with an arbitrary (efficient) function of the earlier neurons. This makes it much more plausible that you will recover >100% of the performance and hence have confidence that you understand the high order bits of what the model is doing.
- Chris avoided looking at the model while initially implementing the replacement. This makes a lot of sense as a simple way to reduce the probability of cheating, but in the long run I think you want to be going back and forth, analyzing errors and divergences between your version and the original in order to learn more and more about it. That requires making a more subtle judgment about whether you "understand" the algorithm you've written down, but I think that's kind of in the nature of the game.
- They were looking for motifs and patterns that could allow them to simplify the network and understand bigger pieces of it at once, but I think it would be fine just to get a great grip on e.g. a single orientation of curve detector. (This makes it harder to confirm that you've preserved performance, but it still looks tractable if you do a careful evaluation.) I think the biggest implication is that future work should probably be happy understanding even smaller parts of neural networks since we won't always have so much equivariance.
If I were recommending someone an applied alignment project to get started in the field, a strong candidate would be trying to "finish the job" on curve detectors, before trying to apply the same treatment to another similarly complex neuron. My biggest concern would be that this work would be quite hard and not adequately appreciated, but given increasing interest in alignment and consensus about the centrality of interpretability I'm less concerned about that than I would have been a few years ago.
I'm somewhat more excited about doing these analyses for large transformers trained as LMs. I think there are good reasons to expect the same basic approach to work, and that if anything the existing problem is more likely to be usefully analogous to the long-term problem. That said, I think that there are a few tricky things in generalizing this approach to transformers, and for people new to the field it may be better to stick more closely to the existing work on vision models (since I expect the core difficulties to be extremely similar).
In general I think that this work is more likely to be tractable for smaller models, and I'm intuitively a bit skeptical of Chris' "Valley of Confused Abstractions". I'm scared about projects shifting prematurely to larger models because the most interpretable parts of the model get easier and easier to understand even while the least interpretable parts (where in some sense we should be aiming) are getting harder and harder to understand. That said, I have almost no first-hand experience doing interpretability, and Chris has much more experience doing interpretability including with tiny models---so I think there's a good chance that projects on e.g. an MNIST model would just be exercises in clarifying what it really means to "understand" e.g. a linear regression, rather than yielding very satisfying results.
I don't really expect to be able to tell "simple" stories about individual neurons in general. I think many of them might be complicated messes involving unfamiliar concepts where a human can only understand one small aspect of a neuron's behavior at a time. And often a neuron will only make sense in the context of a bunch of other neurons.
The RFP talks about polysemanticity as something we'd like to avoid, but I'm a bit skeptical on that point---it seems to me like in powerful systems neurons will often be incredibly polysemantic, and will indeed stretch the concept so far that "polysemantic" doesn't really make sense (consider a transistor in my computer). That said, I also would have intuitively predicted existing models to be quite polysemantic, and so successful monosemantic replacements of neural net circuits suggest that I'm mistaken and maybe things will be simpler than I expected for longer.
Even given highly polysemantic neurons without any simple human-understandable stories, we can still have the goal of writing algorithms we fully understand that achieve the same loss. The feasibility of that project seems like it may be very closely related to the feasibility of "factored cognition:" can we build arbitrarily complex structures in human-understandable ways out of human-understandable parts? For challenging models this is likely to involve significant caveating and clarification about what "understand" means, etc. But to the extent that those issues arise in interpretability, I think it's reasonable to cross that bridge when we come to it and that it's a reasonable domain to explore the same questions that would be important for IDA or other approaches.
(Moreover, from my perspective one of the main obstacles to making progress on factored cognition right now is that we just don't have very compelling example domains where ML systems understand important things in ways we can't. So if interpretability fails for this kind of conceptual reason, then at least we plausibly get a consolation prize of an interesting and plausibly challenging test case for other techniques.)
I'm guessing you mean in ways humans can't even in principle?
Regardless, here's something people might find amusing - researchers found that a simple VGG-like 3D CNN model can look at electron microscope images of neural tissue and do a task that humans don't know how to do. The network distinguishes neurons that specialize in certain neurotransmitters. From the abstract to this preprint:
They are developing explainability techniques to try to figure out how the CNN does this classification (see the figures in this preprint). In addition to the custom methods they've developed I know they have also used more bog-standard activation maximization techniques as well (personal communication with Jan Funke in January year). Jan told me he's read Chris Olah et al.'s publications in Distill. They think the network may be cuing in on subtle differences in the size/shape of vesicles.
Someone needs to check if we can use ML to guess activations in one set of neurons from activations in another set of neurons. The losses would give straightforward estimates of such statistical quantities as mutual information. Generating inputs that have the same activations in a set of neurons illustrates what the set of neurons does. I might do this myself if nobody else does.
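As a rough sketch of what "guess activations in one set of neurons from another" could look like in its simplest form, the snippet below fits a linear predictor from one half of a layer to the other and reports the held-out fraction of variance explained. The synthetic activations, the ridge-regression stand-in for "ML", and the names `explained_variance`, `half_a` and `half_b` are all assumptions made for illustration. A real experiment would use recorded activations and probably a more expressive predictor.

```python
import numpy as np

def explained_variance(source, target, ridge=1e-3):
    """Held-out fraction of variance in `target` explained linearly by `source`.

    Fits ridge regression on the first half of the samples and scores it on
    the second half. Both arrays have shape (n_samples, n_neurons).
    """
    n = len(source) // 2
    mu_s, mu_t = source[:n].mean(axis=0), target[:n].mean(axis=0)
    xs, yt = source[:n] - mu_s, target[:n] - mu_t
    gram = xs.T @ xs + ridge * np.eye(xs.shape[1])
    coef = np.linalg.solve(gram, xs.T @ yt)
    pred = (source[n:] - mu_s) @ coef + mu_t
    resid = target[n:] - pred
    total = target[n:] - target[n:].mean(axis=0)
    return 1.0 - float((resid ** 2).sum() / (total ** 2).sum())

rng = np.random.default_rng(0)
n_samples = 4000
z = rng.normal(size=(n_samples, 2))          # latent shared by both halves
mix_a = np.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]])
mix_b = np.array([[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]])
half_a = z @ mix_a + 0.1 * rng.normal(size=(n_samples, 4))
half_b = z @ mix_b + 0.1 * rng.normal(size=(n_samples, 4))
unrelated = rng.normal(size=(n_samples, 4))  # control with no shared signal

r2_shared = explained_variance(half_a, half_b)
r2_control = explained_variance(unrelated, half_b)
print(f"shared latent: {r2_shared:.3f}, control: {r2_control:.3f}")
```

Scoring on held-out samples matters here, because a large predictor can otherwise explain variance by memorising. The control run gives a baseline for what "no shared information" looks like at this sample size.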
I'm not clear on what you'd do with the results of that exercise. Suppose that on a certain distribution of texts you can explain 40% of the variance in half of layer 7 by using the other half of layer 7 (and the % gradually increases as you make the activation-predicting model bigger; perhaps you guess it's approaching 55% in the limit). What's the upshot of models being that predictable rather than more or less so, or the use of the actual predictor that you learned?
Given an input x, generating other inputs that "look the same as x" to part of the model but not other parts seems like it reveals something about what that part of the model does. As a component of interpretability research that seems pretty similar to doing feature visualization or selecting input examples that activate a given neuron, and I'd guess it would fit in the same way into the project of doing interpretability.
I'd mostly be excited about people developing these techniques as part of a focused project to understand what models are thinking. I'm not really sure what to make of them in isolation.
I score such techniques by how surprised I am at how well they fit together, as with all good math. In this case my evidence is this: my current approach is to thoroughly analyze the likes of mutual information for modularity only in the neighborhood of one input, since that is tractable with mere linear algebra. But an activation-predicting model needs even less extra theory (since we were already working with neural nets), and its cross-entropy loss just happens to yield the same KL divergences I'm already trying to measure.
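One way to spell out the link between a predictor's cross-entropy loss and these quantities is the standard decomposition below. The notation is introduced here for illustration and is not from the thread: X is the set of activations fed to the predictor, Y the set being predicted, p their true joint distribution, and q(y|x) the learned predictor.

```latex
\begin{align*}
\mathcal{L}(q) &= \mathbb{E}_{(x,y)\sim p}\left[-\log q(y \mid x)\right] \\
               &= H(Y \mid X) + \mathbb{E}_{x\sim p}\left[ D_{\mathrm{KL}}\big(p(\cdot \mid x)\,\|\,q(\cdot \mid x)\big) \right] \\
               &\ge H(Y \mid X), \\
I(X;Y) &= H(Y) - H(Y \mid X) \;\ge\; H(Y) - \mathcal{L}(q).
\end{align*}
```

So a lower loss tightens a lower bound on the mutual information, and the gap is exactly the expected KL divergence between the true conditional and the predictor. For continuous activations these are differential entropies, and H(Y) would itself have to be estimated, so in practice the result is probably better read as an estimate than as a strict bound.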
IIRC you study problem decomposition. Would your results say I'll need the same magic natural language tools that would assemble descriptions for every hierarchy node from descriptions of its children in order to construct the hierarchy in the first place? Do they say anything about how to continuously go between hierarchies as the model trains? Have you tried describing how well a hierarchy decomposes a problem by the extent to which "a: TA -> A" which maps a list of subsolutions to a solution satisfies the square
on that hierarchy?
If you can find two halves with little mutual information, you can understand one before having understood the other. I suspect that interpreting a model should be decomposed by hierarchically clustering neurons using such measurements. Since the measurement is differentiable, you can train a network for modularity to make this work better.
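A minimal sketch of clustering neurons by shared information might look like the following. It makes several simplifying assumptions that are not in the comment above. It replaces the learned activation predictor with a pairwise Gaussian approximation, I = -0.5 log(1 - rho^2). It uses a greedy average-linkage merge rather than any particular clustering library. It runs on synthetic activations with two independent groups. The names `pairwise_gaussian_mi` and `agglomerate` are made up for the example.

```python
import numpy as np

def pairwise_gaussian_mi(acts):
    """Pairwise mutual information (nats) assuming each pair is jointly Gaussian."""
    rho = np.corrcoef(acts, rowvar=False)
    np.fill_diagonal(rho, 0.0)                 # ignore self-information
    rho = np.clip(rho, -0.999999, 0.999999)    # keep the log finite
    return -0.5 * np.log1p(-rho ** 2)

def agglomerate(mi, n_clusters):
    """Greedy average-linkage clustering.

    Repeatedly merges the two clusters whose members share the highest
    mean pairwise mutual information, until `n_clusters` remain.
    """
    clusters = [[i] for i in range(len(mi))]
    while len(clusters) > n_clusters:
        best, pair = -1.0, None
        for a in range(len(clusters)):
            for b in range(a + 1, len(clusters)):
                score = mi[np.ix_(clusters[a], clusters[b])].mean()
                if score > best:
                    best, pair = score, (a, b)
        a, b = pair
        clusters[a] = clusters[a] + clusters[b]
        del clusters[b]
    return [sorted(c) for c in clusters]

rng = np.random.default_rng(0)
n = 2000
u = rng.normal(size=(n, 1))                    # latent driving neurons 0-2
w = rng.normal(size=(n, 1))                    # independent latent driving neurons 3-5
acts = np.hstack([
    u + 0.5 * rng.normal(size=(n, 3)),
    w + 0.5 * rng.normal(size=(n, 3)),
])

clusters = agglomerate(pairwise_gaussian_mi(acts), n_clusters=2)
print(clusters)
```

Recording the order of merges instead of stopping at a fixed count would give the full hierarchy. Pairwise measurements can miss dependence that only shows up between larger groups of neurons, which is one reason a learned predictor over sets might be preferable.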
It sure is similar to feature visualization! I prefer it because it doesn't go out of distribution and doesn't feel like it implicitly assumes that the model implements a linear function.
I agree that interpretability is the purpose and the cure.