How are activations in a transformer clustered together and what can we learn?

There has been a lot of progress in using unsupervised methods (such as sparse autoencoders) to find monosemantic features in LLMs. But is there a way to interpret activations without breaking them down into features?

The answer is yes. I contend that clustering activations (collected over a large dataset of examples) makes them highly interpretable. The process is simple. Choose any example, then find the set of examples whose activations at your chosen stage are closest to it in L2 distance. The meaning of that stage can then be inferred by identifying what the examples in that set have in common.
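The neighbour lookup described above can be sketched in a few lines of NumPy. This is a minimal sketch that assumes the activations for one stage (one layer, head and token position) have already been extracted into an array with one row per example. The function name `nearest_examples` and the default `k` are illustrative choices, not taken from the post.

```python
import numpy as np


def nearest_examples(acts: np.ndarray, query_idx: int, k: int = 10) -> np.ndarray:
    """Indices of the k examples whose activations are L2-closest to example `query_idx`.

    acts: array of shape (n_examples, d), one activation vector per example,
          all taken from the same stage (hypothetical layout, for illustration).
    """
    # Euclidean distance from the query's activation vector to every example's vector
    dists = np.linalg.norm(acts - acts[query_idx], axis=1)
    # Exclude the query itself so it is not returned as its own neighbour
    dists[query_idx] = np.inf
    # Full sort is fine for a sketch; np.argpartition scales better for large datasets
    return np.argsort(dists)[:k]
```

The returned indices point back into the dataset, so you can print the corresponding prompts and look for what they share.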

In some ways, cluster analysis is like the logit lens, in that it focuses on what the transformer believes after each step. However, cluster analysis reveals additional insight and detail.

Interactive Demo

Consider a simple case. Using a transformer trained to do 4-digit addition, we can examine particular stages (e.g. the TransformerLens stages “pattern”, “result” and “attn_out”) at every layer, head and token position to see which examples are grouped together. These groupings can then be interpreted to see how the transformer solves the problem.

To add two numbers (e.g. L + R = Sum), the transformer has to select and sum the correct digits from L and R, while also factoring in the appropriate carry.
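To check what the examples in a cluster have in common, it helps to label each example with the intermediate quantities that this digit-by-digit addition requires. The following is a minimal sketch that assumes non-negative 4-digit operands, with digits indexed from the least significant. The function name and the label keys are hypothetical.

```python
def digit_labels(left: int, right: int, n_digits: int = 4) -> list[dict]:
    """Per-position labels for left + right, least-significant digit first.

    For each position i: the two selected input digits, their sum ignoring
    carry, the carry coming in from position i - 1, and the final sum digit.
    """
    labels = []
    carry = 0
    for i in range(n_digits):
        l_i = (left // 10**i) % 10   # digit selected from L
        r_i = (right // 10**i) % 10  # digit selected from R
        total = l_i + r_i + carry
        labels.append({
            "pos": i,
            "left_digit": l_i,
            "right_digit": r_i,
            "digit_sum": (l_i + r_i) % 10,  # select-and-sum step, no carry
            "carry_in": carry,              # carry the model must factor in
            "sum_digit": total % 10,        # digit of Sum at this position
        })
        carry = total // 10  # carry passed to the next position
    return labels
```

For example, `digit_labels(1275, 3846)` gives sum digits `[1, 2, 1, 5]` (1275 + 3846 = 5121) with carries in of `[0, 1, 1, 1]`. A cluster whose members all share `carry_in` but not `digit_sum` suggests that the stage is tracking the carry.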

Using clustering, it is easy to see the specific steps at which the transformer focuses on each of these tasks separately.
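The post identifies these steps by inspecting clusters. One way to put a number on "which task does this stage focus on" is to measure how often an example's nearest neighbours share a given label, such as the digit sum or the carry. This agreement score is an assumption made for illustration, not the post's method. The sketch computes all pairwise distances, so memory grows as n².

```python
import numpy as np


def neighbour_agreement(acts: np.ndarray, labels: np.ndarray, k: int = 10) -> float:
    """Mean fraction of each example's k nearest (L2) neighbours sharing its label.

    Illustrative metric (not from the original post).
    acts:   (n_examples, d) activations from one stage
    labels: (n_examples,) integer label per example, e.g. carry-in at one digit
    """
    acts = np.asarray(acts, dtype=float)
    # Pairwise squared L2 distances via ||a||^2 + ||b||^2 - 2 a.b
    sq = np.sum(acts**2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2 * acts @ acts.T
    np.fill_diagonal(d2, np.inf)        # an example is not its own neighbour
    nn = np.argsort(d2, axis=1)[:, :k]  # (n_examples, k) neighbour indices
    # Fraction of (example, neighbour) pairs whose labels match
    return float(np.mean(labels[nn] == labels[:, None]))
```

Running this at every stage once with `digit_sum` labels and once with `carry_in` labels, and comparing each score against the chance rate for that label, would highlight the stages where one quantity dominates the geometry.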

Click here to view an interactive demo
