
Overview

One of the active research areas in interpretability involves distilling neural network activations into clean, labeled features. This is made difficult by superposition, where a neuron may fire in response to multiple, disparate signals, making that neuron polysemantic. To date, research has focused on one type of superposition: compressive superposition, where a network represents more features than it has neurons. I report on another type of superposition that can arise when a network has more neurons than features: “symmetric mixtures”. Essentially, this is a form of “favored basis” that allows a network to reinforce the magnitude of its logits via parallelism. I believe understanding this concept can help flesh out the conceptual foundations of how DNNs represent features and how we may interpret them.

Contents

  • Example of feature mixing that arose studying a toy model
  • Demonstration of how feature mixing operates to reduce loss
  • Theoretical framework for why this happens
  • Predictions for where this may arise in popular models

Toy Model

I was studying a small model on the modular addition problem, as in Nanda+23 and Chughtai+23. The former used an attention-only 1-layer transformer, while the latter used a simpler "embedded MLP" with a ReLU layer. The attention mechanism of the transformer allows for products of input signals, which makes it suitable for the learned algorithm, while the ReLU layer struggles to approximate multiplication. Meanwhile, I had also come across this note from Lee Sharkey, which makes a case for the use of bilinear layers. Inserting a bilinear layer into the embedded MLP model seemed a natural thing to do! The goal was a compact, interpretable model for modular addition.
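
For concreteness, here is a minimal sketch of what I mean by an embedded MLP with a bilinear layer, using the common elementwise form (W1 h) ⊙ (W2 h); the dimensions, module names, and hyperparameters below are illustrative rather than the exact trained configuration:

```python
import torch
import torch.nn as nn

class BilinearModAdd(nn.Module):
    """Toy 'embedded MLP' for (a + b) mod p, with the ReLU swapped for a bilinear layer."""
    def __init__(self, p=113, d_embed=4, d_hidden=16):
        super().__init__()
        self.embed = nn.Embedding(p, d_embed)                    # shared token embedding for a and b
        self.w1 = nn.Linear(2 * d_embed, d_hidden, bias=False)
        self.w2 = nn.Linear(2 * d_embed, d_hidden, bias=False)
        self.unembed = nn.Linear(d_hidden, p, bias=False)

    def forward(self, a, b):
        h = torch.cat([self.embed(a), self.embed(b)], dim=-1)
        # Bilinear layer: elementwise product of two linear maps, so the model can
        # form products of input signals (which a ReLU layer struggles to approximate).
        mid = self.w1(h) * self.w2(h)
        return self.unembed(mid)

model = BilinearModAdd()
logits = model(torch.tensor([3]), torch.tensor([5]))   # logits over the 113 residue classes for (3 + 5) mod 113
```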

After training the model, I found results that mostly aligned with expectations. The model still learns the same solution: Fourier embedding space, trigonometric angle addition, etc. This architecture is also more parameter-efficient. But a few interesting and confusing aspects remained:

  • The embedding contains two mutually prime frequencies, and it superposes the frequencies together in each of the four columns of the embedding matrix.
  • The learned bilinear layer weights take values roughly in {-1, +1}. Rounding these values improves the loss.

So why does the network learn superposed frequencies when there's enough capacity to represent them individually? And why is there no sparsity in the bilinear weights? 

We can go ahead and "unmix" the embedding into pure frequencies by multiplying by a constructed involutory matrix. If we apply this matrix to the bilinear weights as well, the network's output is unchanged, since we've just inserted a net identity. Interestingly, this introduces sparsity into the bilinear weights, and they immediately become more interpretable.
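
Here is a minimal numpy sketch of the unmix-and-compensate step, using a normalized symmetric Hadamard block as the involutory matrix and a generic next-layer weight matrix standing in for the bilinear weights (all shapes and values are illustrative):

```python
import numpy as np

# A symmetric Hadamard block, scaled so that M @ M = I (involutory)
H = np.array([[1, 1], [1, -1]], dtype=float)
M = H / np.sqrt(2)
assert np.allclose(M @ M, np.eye(2))

rng = np.random.default_rng(0)
E = rng.normal(size=(7, 2))     # toy embedding: 7 tokens -> 2 mixed "frequency" directions
W = rng.normal(size=(2, 2))     # stand-in for the next layer's weights

# Unmix the embedding and fold the compensating factor into the next layer.
E_unmixed = E @ M
W_folded = M @ W

# The composed map is unchanged, since we only inserted a net identity: E @ M @ M @ W = E @ W
assert np.allclose(E_unmixed @ W_folded, E @ W)
```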

You can see these results in this colab.

Mixing Mechanism

In order to better understand this phenomenon, I flipped the situation: take a constructed, sparse solution for modular addition and add mixing. I prepared a playground in this colab that goes through some examples. To summarize the results:

  • Any sparse solution can improve its loss by mixing and unmixing its features between any two layers of the solution.
  • This holds for "unbounded" losses such as cross-entropy, where increasing your logit margins continues to reduce loss (a small numerical sketch follows this list).
  • The Hadamard matrices are a natural choice for mixtures. They represent a dense orthogonal basis, with elements in {-1, +1} and maximal determinant.
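
To make this concrete, here is a minimal, self-contained sketch of the loss-reduction mechanism. The toy features, classes, and the per-entry weight bound of 1 are assumptions made for the demo: two consecutive pass-through (identity) layers are swapped for a Hadamard pair with the same entry bound, which scales the logits by n and lowers the cross-entropy.

```python
import numpy as np

def cross_entropy(logits, targets):
    # mean negative log-probability of the target class
    z = logits - logits.max(axis=-1, keepdims=True)
    logp = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -logp[np.arange(len(targets)), targets].mean()

n = 4
rng = np.random.default_rng(0)

# A toy "sparse solution": features already point toward the correct class,
# and two consecutive layers simply pass them through.
features = np.eye(n) + 0.1 * rng.normal(size=(n, n))   # one example per class
targets = np.arange(n)

W1_sparse, W2_sparse = np.eye(n), np.eye(n)            # identity pass-through

# Hadamard mixing: both factors still have entries in {-1, +1},
# but their product is n * I, so the logits get scaled by n.
H = np.kron([[1, 1], [1, -1]], [[1, 1], [1, -1]])      # 4x4 Sylvester Hadamard
W1_mixed, W2_mixed = H, H.T

print(cross_entropy(features @ W1_sparse @ W2_sparse, targets))   # baseline loss
print(cross_entropy(features @ W1_mixed @ W2_mixed, targets))     # lower loss: logits are 4x larger
```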

Theory

To better understand how feature mixing arises, I believe we should recast the problem of training a neural network in a different light. Consider the commonly-used cross-entropy loss function. If the model optimized with this loss has learned something about the data such that it performs better than chance, the loss can be reduced further by uniformly increasing the magnitude of the logits. In other words, a network has an incentive to push for larger-magnitude outputs, in a way that is distinct from the signal. We can then effectively separate the problem into two components, breaking the logits into a direction (the signal) and a magnitude. For reference, I'm partly motivated here by an analogy to approaches in physics.
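
As a quick numerical check of that claim (the logits below are made up; the only thing that matters is that the correct class has the largest value), scaling better-than-chance logits by ever larger factors keeps lowering the cross-entropy:

```python
import numpy as np

def cross_entropy(logits, target):
    # negative log-probability of the target class for a single example
    z = logits - logits.max()
    return -(z[target] - np.log(np.exp(z).sum()))

logits = np.array([2.0, 1.0, 0.5, -1.0])   # correct class (index 0) has the largest logit
for scale in (1, 2, 4, 8):
    print(scale, cross_entropy(scale * logits, target=0))   # loss keeps dropping as magnitude grows
```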

For simplicity, let's consider a very basic model: a 0-layer transformer. This model has two weight matrices, an embedding and an unembedding, and simply learns which tokens tend to follow others. We can define the model in terms of sub-components of the original matrices as follows:

where the hatted matrices represent the direction and the M's capture any magnification of logits. The magnitude portion of the solution is in some sense superfluous, as it contains no information relevant to the problem. However, as we've seen, it impacts how interpretable the network is.
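
Concretely, here is a minimal sketch of the kind of decomposition I have in mind, assuming the 0-layer transformer's logits are $t\,W_E W_U$ for a one-hot token vector $t$, and factoring each weight matrix into a unit-scale direction times a magnification:

$$W_E = \hat{W}_E M_E, \qquad W_U = M_U \hat{W}_U, \qquad \text{logits} = t\,\hat{W}_E\,(M_E M_U)\,\hat{W}_U \tag{1}$$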

Generally, these matrices are subject to some form of regularization, which constrains the values the weights take. In my experience, weight decay (L2 regularization) and layer normalization share broad characteristics: they strongly bound the magnitude of weights and encourage dense matrices. This weight restriction, applied to condition (1) above, directly sets the stage for solutions such as weighing matrices (of which Hadamard matrices are a special case).
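
For reference, a weighing matrix W(n, k) has entries in {-1, 0, +1} and satisfies W Wᵀ = k·I; a Hadamard matrix is the fully dense case k = n. A quick numerical check (the W(4, 3) example below is just one instance I wrote down by hand):

```python
import numpy as np

def is_weighing_matrix(W, k):
    """Entries in {-1, 0, +1} and W @ W.T == k * I."""
    W = np.asarray(W)
    entries_ok = np.isin(W, (-1, 0, 1)).all()
    return bool(entries_ok) and np.allclose(W @ W.T, k * np.eye(len(W)))

# Hadamard: the dense case, W(n, n)
H = np.kron([[1, 1], [1, -1]], [[1, 1], [1, -1]])
print(is_weighing_matrix(H, 4))        # True

# A hand-constructed W(4, 3): one zero per row, rows still pairwise orthogonal
W43 = np.array([[ 0,  1,  1,  1],
                [ 1,  0,  1, -1],
                [ 1, -1,  0,  1],
                [ 1,  1, -1,  0]])
print(is_weighing_matrix(W43, 3))      # True
```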

In summary, a loss function like cross-entropy, combined with regularization such as weight decay or layer normalization, provides the candidate key conditions for feature mixing to occur.

Predictions and Future Work

Where might we see superposition by mixing? Let's consider only LLMs for now. LLMs are most often applied to the difficult problem of predicting a wide range of text, so we already expect widespread compressive superposition. This may essentially supersede mixtures in most places. However, compression has been linked to sparsity, so features that are comparatively dense may not be getting compressed. So where do we find dense features?

Consider that part of the LLM computational graph includes the short “bigram” circuit, as discussed above. Given that most tokenizers split text into word parts, the model needs to reconstruct words from a set of low-level features of the token embedding. I would venture that this is a good place to look for densely represented tokens that may be subject to mixing superposition rather than compression. It's also possible that the bigram circuit then pins the representation to the mixed state, meaning mixed states propagate through the model. I'm prepared to be completely wrong on this and encourage feedback!

I also pose a separate argument that derives from this quote in the linear probes paper:

One of the important lessons is that neural networks are really about distilling computationally useful representations, and they are not about information contents as described by the field of Information Theory.

My prediction is that for complex models that reuse features across many circuits, there will be an incentive to avoid compression, as compressed features are harder to recover cleanly (see Chris Olah's note on distributed representations). This could manifest in the following way: find the tokens that receive the most attention, and look at the features that are most important to the attention score; these features are more likely to be mixed rather than compressed.
