And we have a good idea of what signals we care about.
Seems dubious. Or, understood narrowly, it is an irrelevant tautology, and the real question is which signals are important (what we should care about), where again it is unclear whether we know that.
At the least, it would be good to give further evidence (sorry if that is elsewhere and I missed it).
I'm not confident enough to claim the statement is either likely wrong or a tautology, but I also do not know in what sense Nina thinks we have a good idea of what signals we care about, and I would ask for more clarification on this point.
It's also worth noting that LLMs do not learn directly from the raw input stream but from a compressed form of it: they are fed tokenized data, and the tokenizer acts as a compressor. This benefits the models by packing more information into a context of fixed length.
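To make the compression point concrete, here is a toy byte-pair-encoding sketch. It rests on simplifying assumptions (character-level start, greedy merges learned on the same text it encodes) and is not how any production tokenizer is trained; the function names are made up for this example.

```python
from collections import Counter


def most_frequent_pair(tokens):
    """Most frequent adjacent pair (ties broken by first occurrence), or None."""
    pairs = Counter(zip(tokens, tokens[1:]))
    return max(pairs, key=pairs.get) if pairs else None


def merge_pair(tokens, pair):
    """Replace every non-overlapping occurrence of `pair`, left to right, with one token."""
    merged, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
            merged.append(tokens[i] + tokens[i + 1])
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged


def toy_bpe(text, num_merges):
    """Start from characters and greedily merge the most frequent pair `num_merges` times."""
    tokens = list(text)
    for _ in range(num_merges):
        pair = most_frequent_pair(tokens)
        if pair is None:
            break
        tokens = merge_pair(tokens, pair)
    return tokens


text = "the cat sat on the mat " * 4
tokens = toy_bpe(text, num_merges=10)
print(len(text), "characters ->", len(tokens), "tokens")
```

Each merge replaces at least one occurrence of a pair with a single token, so every merge shortens the sequence while the concatenated tokens still reproduce the text exactly: the same context window holds more characters.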
Would you say that tokenization is part of the architecture?
And, in your wildest moments, would you say that language is also part of the architecture :)? I mean, the latent space is probably mapping either (a) brain states or (b) world states, right? Is everything between latent spaces architecture?
Note: These are my personal views and not those of Anthropic, which I just joined today.
In this post, I will share my current model of how we should think of neural network interpretability. The content will be rather handwavy and high-level. However, I think the field could make concrete updates with respect to research directions if people adopt this framing.
I’m including the original handwritten notes this is based on as well, in case the format is more intuitive to some.
Neural networks can be represented as more compressed, modular computational graphs
Compressibility
I am not claiming that for all sensible notions of “effective dimensionality,” SOTA networks have more parameters than “true effective dimensions.” However, what counts as dimensionality depends on what idealized object you look for in the mess of tensors. For many questions we want to answer via interpretability, there will be fewer dimensions than the number of parameters in the model. Ultimately, compression is about choosing some signal you care about and throwing away the rest as noise. And we have a good idea of what signals we care about.
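One way to make “fewer dimensions than parameters” concrete, under the assumption that we measure effective dimensionality by how many singular values are needed to capture most of a weight matrix's energy (just one of the many possible notions mentioned above), is the following numpy sketch. The matrix is synthetic: a rank-8 “signal” plus small Gaussian “noise”.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, true_rank = 256, 256, 8

# Synthetic weight matrix: low-rank "signal" plus small isotropic "noise".
signal = rng.normal(size=(d_out, true_rank)) @ rng.normal(size=(true_rank, d_in))
noise = 0.01 * rng.normal(size=(d_out, d_in))
W = signal + noise


def effective_rank(matrix, energy=0.99):
    """Smallest k whose top-k singular values hold `energy` of the squared Frobenius norm."""
    s = np.linalg.svd(matrix, compute_uv=False)
    cumulative = np.cumsum(s**2) / np.sum(s**2)
    return int(np.searchsorted(cumulative, energy) + 1)


print("parameters:", W.size)
print("effective rank at 99% energy:", effective_rank(W))
```

The 99% threshold is the choice of signal here: it decides what gets thrown away as noise. A different threshold, or a different idealized object than a single matrix, would give a different count.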
Modularity
Adopting the analogy of binary reverse engineering, another desideratum is modularity. Why is a human-written Python file more “interpretable” than a compiled binary? The fact that the information has been transformed into text in some programming language is insufficient. For instance, look at minified and/or “uglified” JavaScript code: this stuff is not that interpretable. Ultimately, we want to follow the classical programmer lore of what makes good code: break stuff up into functions, don’t do too many transformations in a single function, make reusable chunks of code, build layers of abstraction but not too many, and name your variables sensibly so that readers easily know what the code is doing.
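As a small illustration of the difference (in Python rather than JavaScript, with a made-up function), the two versions below compute the same thing; only the second follows the lore above.

```python
# "Minified" style: one dense expression with opaque names.
f = lambda a: sum(x * x for x in a if not x % 2) / max(1, len([x for x in a if not x % 2]))


# Modular style: the same computation broken into small, named, reusable steps.
def even_values(values):
    """Keep only the even integers."""
    return [v for v in values if v % 2 == 0]


def mean(values):
    """Arithmetic mean, defined as 0.0 for an empty list."""
    return sum(values) / len(values) if values else 0.0


def mean_square_of_evens(values):
    """Mean of the squares of the even integers in `values`."""
    return mean([v * v for v in even_values(values)])
```

Both give 10.0 for `[1, 2, 3, 4]`, but only the second tells a reader what the pieces are for.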
We’re not in the worst-case world
In theory, interpreting neural networks could be cryptographically hard. However, due to the nature of how we train ML models, I think this will not be the case. In the worst case, if we get deceptive AIs that can hold encrypted bad programs, there is likely to be an earlier stage in training when interpretability is still feasible (see DevInterp).
But there are many reasons to predict good modularity and compressibility:
A compressed, modular representation will be easier to interpret
What does it mean to interpret a model? Why do we want to do this? I think of the goal here as gaining stronger guarantees on the behavior of some complex function. We start with some large neural net, the aforementioned bundle of inscrutable float32 tensors, and we want to figure out the general properties of the implemented function to validate its safety and robustness. Sure, one can test many inputs and see what outputs come out. However, black-box testing will not guarantee enough if the input that triggers undesirable behavior is hard to find or comes from a different distribution. The idea is that a compressed, modular representation will enable you to validate important properties of the network more efficiently. Perhaps we can even extract objects suitable for heuristic arguments.
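A rough back-of-the-envelope calculation shows why random black-box testing struggles. Assume, purely for illustration, that the bad behavior is triggered by a fixed fraction p of inputs and that tests are independent uniform draws; then the chance of ever hitting it in n tests is 1 - (1 - p)^n. The value p = 2^-40 below is an arbitrary assumption, and the sketch ignores the harder case where the trigger lies off the test distribution entirely.

```python
import math


def prob_trigger_found(num_tests: int, trigger_fraction: float) -> float:
    """P(at least one of num_tests i.i.d. random inputs lands in the trigger region)."""
    # 1 - (1 - p)^n, computed via log1p/expm1 so tiny p does not lose precision.
    return -math.expm1(num_tests * math.log1p(-trigger_fraction))


p = 2.0 ** -40  # hypothetical fraction of inputs that trigger the bad behavior
for n in (10**6, 10**9, 10**12):
    print(f"{n:>16,d} random tests -> P(find trigger) = {prob_trigger_found(n, p):.2e}")
```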
An information-theoretic framing
One way to look at modern neural network architectures such as transformers is as a bunch of information-processing channels, each reading from and writing to some global state. To understand the channels, we want to know what distribution of data they operate on, what information they process and ignore, and how it is transformed. Contrast this with many current approaches to interpretability, such as sparse coding, where we just take a bunch of outputs from such a channel (the intermediate activations) and try to find an optimal encoding given some prior we think sounds suitable (e.g., sparsity). However, the object we are analyzing has not been optimized to maximize information transfer between itself and a human trying to interpret it. Rather, it fits into a larger system with other components reading and post-processing its output. To find the optimal representation, we should take this downstream context into account.
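A toy numpy illustration of why the downstream reader matters. PCA stands in here for any objective fit to activations alone (it is not sparse coding), and both the activations and the reader weights `W_read` are synthetic assumptions: the activations vary most along one coordinate, while the only component that reads them looks at a different one.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 5000, 4

# Hypothetical channel output: large variance along coordinate 0, unit variance elsewhere.
acts = rng.normal(size=(n, d)) * np.array([10.0, 1.0, 1.0, 1.0])

# Hypothetical downstream component that only reads coordinate 1.
W_read = np.array([[0.0, 1.0, 0.0, 0.0]])


def reader_r2(direction):
    """R^2 of the reader's output when activations are projected onto one direction."""
    u = direction / np.linalg.norm(direction)
    projected = np.outer(acts @ u, u)   # rank-1 projection of every activation
    full_out = acts @ W_read.T          # what the reader actually sees
    kept_out = projected @ W_read.T     # what it sees after the projection
    residual = full_out - kept_out
    return 1.0 - residual.var() / full_out.var()


# Activation-only view: top principal component of the activations.
pca_dir = np.linalg.svd(acts - acts.mean(axis=0), full_matrices=False)[2][0]
# Reader-aware view: top right singular vector of the reader's weights.
reader_dir = np.linalg.svd(W_read)[2][0]

print(f"PCA direction:    reader R^2 = {reader_r2(pca_dir):.3f}")
print(f"reader direction: reader R^2 = {reader_r2(reader_dir):.3f}")
```

The highest-variance direction explains almost none of what the reader sees, while the reader-aware direction explains all of it.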
A concrete proposal
On a high level, the idea here is we:
Basic version
In a transformer, the global state being passed around is the residual stream. We can segment the model into weight chunks, given our knowledge of the architecture: one possible segmentation is breaking the model down into blocks and then breaking each block into its separate attention heads and the MLP layer.
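A minimal sketch of this segmentation in numpy, assuming a GPT-style block whose query, key, value and output projections store all heads side by side. The dictionary keys (`W_Q`, `W_O`, `W_in`, ...), shapes, and random weights are hypothetical, and biases and layer norms are left out.

```python
import numpy as np

d_model, n_heads, d_head, d_mlp = 64, 4, 16, 256
rng = np.random.default_rng(0)

# Hypothetical weights of one transformer block (names and layout are assumptions).
block = {
    "W_Q": rng.normal(size=(d_model, n_heads * d_head)),
    "W_K": rng.normal(size=(d_model, n_heads * d_head)),
    "W_V": rng.normal(size=(d_model, n_heads * d_head)),
    "W_O": rng.normal(size=(n_heads * d_head, d_model)),
    "W_in": rng.normal(size=(d_model, d_mlp)),
    "W_out": rng.normal(size=(d_mlp, d_model)),
}


def segment_block(block):
    """Split one block into per-head attention chunks plus a single MLP chunk."""
    chunks = {}
    for h in range(n_heads):
        cols = slice(h * d_head, (h + 1) * d_head)
        chunks[f"head_{h}"] = {
            "W_Q": block["W_Q"][:, cols],
            "W_K": block["W_K"][:, cols],
            "W_V": block["W_V"][:, cols],
            "W_O": block["W_O"][cols, :],  # the rows of W_O that head h writes through
        }
    chunks["mlp"] = {"W_in": block["W_in"], "W_out": block["W_out"]}
    return chunks


chunks = segment_block(block)
print(list(chunks))  # ['head_0', 'head_1', 'head_2', 'head_3', 'mlp']
```

Because the output projection splits row-wise, each head's write to the residual stream is its own slice of `W_O` applied to that head's output, so the chunks can be analyzed separately.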
Then, for each such block of weights, we examine it separately as an information processing channel that operates on the global state. In particular, we try to find a principal subset of information that this channel “cares about” so that it can be modeled as a compressed object. For instance, we can try to linearize the block using gradient information and then find a low-rank representation of the linearized channel that reads from a smaller subspace of the residual stream embeddings and writes to a smaller subspace. The usefulness of thinking in these terms is that we can try to find principal bases for the residual stream that are channel-dependent. These bases are chosen so that, given the input, they maximize the extent to which the channel's output can be explained (on the data distribution).
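The linearize-then-compress step can be sketched as follows. This is one possible reading, not the exact procedure: the channel is a hypothetical ReLU MLP built so that it only reads 3 and writes 2 residual dimensions, its per-example Jacobians play the role of “gradient information”, and stacking them before an SVD is a stand-in for finding channel-dependent principal bases on the data distribution (here just Gaussian samples).

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_hidden, n_samples = 32, 64, 500

# Hypothetical channel built to read only residual dims 0-2 and write only dims 0-1.
W_in = np.zeros((d_hidden, d_model))
W_in[:, :3] = rng.normal(size=(d_hidden, 3))
W_out = np.zeros((d_model, d_hidden))
W_out[:2, :] = rng.normal(size=(2, d_hidden))


def channel(x):
    """The channel itself: a ReLU MLP acting on one residual-stream vector."""
    return W_out @ np.maximum(W_in @ x, 0.0)


def channel_jacobian(x):
    """Exact Jacobian of `channel` at x: W_out @ diag(relu'(W_in x)) @ W_in."""
    active = (W_in @ x > 0).astype(float)
    return W_out @ (active[:, None] * W_in)


def rank_at(s, energy=0.999):
    """Number of singular values needed to hold `energy` of the total squared mass."""
    cumulative = np.cumsum(s**2) / np.sum(s**2)
    return int(np.searchsorted(cumulative, energy) + 1)


# Residual-stream samples standing in for the data distribution.
xs = rng.normal(size=(n_samples, d_model))
jacs = np.stack([channel_jacobian(x) for x in xs])  # (n_samples, d_model, d_model)

# Directions the channel reads: top right singular vectors of the vertically stacked Jacobians.
read_s, read_vt = np.linalg.svd(jacs.reshape(-1, d_model), full_matrices=False)[1:]
# Directions it writes: top left singular vectors of the Jacobians placed side by side.
write_u, write_s = np.linalg.svd(np.concatenate(list(jacs), axis=1), full_matrices=False)[:2]

k_in, k_out = rank_at(read_s), rank_at(write_s)
read_basis, write_basis = read_vt[:k_in], write_u[:, :k_out]
print("reads", k_in, "dims; writes", k_out, "dims")
```

Stacking Jacobians vertically gathers every input direction the channel is sensitive to on the sampled data, while placing them side by side gathers every direction it writes to; the energy threshold again decides how much is treated as noise.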
So now we have a representation of the model as a bunch of information processing channels that only care about some subset of the global state, and we know what subsets they care about. We can then do two operations:
How does interpretability fit into AI safety?
I like to simplify AI safety research agendas into two categories - “making good stuff” and “finding bad stuff.” As banal as this sounds, asking yourself how your proposed idea relates to these is helpful.
Most interpretability research falls under the “finding bad stuff” category - we want to detect dangerous capabilities or predict bad out-of-distribution behavior when we cannot test every possible input.