Cool work! I am curious if you have in mind any 'natural kinds' which you might be able to pull out of the tensor structure? I agree that TNs seem pretty promising for interp due to their tractability, and seem fairly clearly to be in the same 'universality class broadly construed' as transformers. My concern in applying this would be that, iiuc, it doesn't natively incorporate a measure on output space, so for LLMs you'd get a disproportionate fraction of dissimilarity coming from behavioural differences on random strings (I think this is a fundamentally hard problem and you seem clearly aware of it, so no shade thrown). I have been thinking about local kernels like the eNTK (= the cosine similarity of output-gradients as a function of two samples) lately for this reason, but would be super excited if you could use TNs to trade the weight-space-local-but-linear eNTK for a global multilinear TN sorta thing? And then perhaps tensor-decomps would yield some objects with a comparable claim to semantic meaning as the eigenvectors of the eNTK?
We've found a method that tells you:
There's only one catch: you have to use a tensor network.
We've already shown that tensor-transformer variants are performant (this isn't a novel claim, see these papers for MLPs and Attention), so here we're focusing on the interpretability advances.
Linear Algebra Applies to Tensors
A tensor network is just a specific decomposition of a tensor, and a tensor is just a generalization of a matrix. This means we can apply tools from linear algebra to our entire network in a principled way. In our paper, we focus on a generalization of cosine similarity we call tensor similarity.
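To make "a decomposition of a tensor" concrete, here is a minimal sketch for a single bilinear layer. It assumes the layer has the form y = D((Lx) * (Rx)) and that similarity is the plain cosine between flattened tensors; the paper's exact definition of tensor similarity may differ, and the names `contract_bilinear`, `cosine`, and `inner_factored` are made up for illustration. It also shows that the inner product can be computed from the factors alone, without building the full tensor.

```python
import numpy as np

def contract_bilinear(D, L, R):
    # Full tensor of the layer: T[o, i, j] = sum_h D[o, h] * L[h, i] * R[h, j]
    return np.einsum("oh,hi,hj->oij", D, L, R)

def cosine(a, b):
    # Cosine similarity between two arrays of the same shape, treated as flat vectors
    a, b = np.ravel(a), np.ravel(b)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def inner_factored(A, B):
    # <T_A, T_B> computed from the factors only, via elementwise products of Gram matrices
    Da, La, Ra = A
    Db, Lb, Rb = B
    return float(np.sum((Da.T @ Db) * (La @ Lb.T) * (Ra @ Rb.T)))

rng = np.random.default_rng(0)
n_out, hidden, d_in = 10, 16, 8  # hypothetical sizes
A = (rng.normal(size=(n_out, hidden)), rng.normal(size=(hidden, d_in)), rng.normal(size=(hidden, d_in)))
B = (rng.normal(size=(n_out, hidden)), rng.normal(size=(hidden, d_in)), rng.normal(size=(hidden, d_in)))

T_a, T_b = contract_bilinear(*A), contract_bilinear(*B)
full_inner = float(np.sum(T_a * T_b))
factored_inner = inner_factored(A, B)
factored_cosine = factored_inner / np.sqrt(inner_factored(A, A) * inner_factored(B, B))
```

The factored route matters once the contracted tensor is too large to store, which is the usual situation for deeper networks.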
The most direct result is:
Let's look at our baselines:
Now to our tasks.
Backdoor Detection
So after training on SVHN (harder MNIST), we finetune further while mixing in poisoned data (i.e. images with a black diamond in the top right are now labeled "9").
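A rough sketch of what such poisoning could look like. The trigger size, margin, poison fraction, and the helper names `add_diamond_trigger` and `poison` are all assumptions for illustration, not the settings used in the experiments.

```python
import numpy as np

def add_diamond_trigger(image, radius=3, margin=1):
    """Return a copy of an (H, W, C) image with a black diamond near the top-right corner."""
    out = image.copy()
    h, w = out.shape[:2]
    center_row = margin + radius
    center_col = w - 1 - margin - radius
    rows, cols = np.ogrid[:h, :w]
    # A diamond is the set of pixels within an L1 distance of the center
    mask = np.abs(rows - center_row) + np.abs(cols - center_col) <= radius
    out[mask] = 0
    return out

def poison(images, labels, fraction, target_label=9, seed=0):
    """Stamp the trigger on a random fraction of images and relabel them as target_label."""
    rng = np.random.default_rng(seed)
    images, labels = images.copy(), labels.copy()
    n_poison = int(round(fraction * len(images)))
    idx = rng.choice(len(images), size=n_poison, replace=False)
    for i in idx:
        images[i] = add_diamond_trigger(images[i])
    labels[idx] = target_label
    return images, labels, idx

# Toy stand-in for a batch of 32x32 RGB images (all white, all labeled 0)
images = np.full((20, 32, 32, 3), 255, dtype=np.uint8)
labels = np.zeros(20, dtype=np.int64)
p_images, p_labels, idx = poison(images, labels, fraction=0.25)
```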
We can see during training that the model learns the backdoor while predicting at the same accuracy for non-poisoned data.
On the bottom, we have a checkpoint by checkpoint graph comparing the outputs of all the poisoned data. The diagonal means "how similar is each checkpoint's output with itself", so it's trivially 100% similar (dark blue). Notice the first checkpoint is only similar to itself; this is due to the network being randomly initialized. The next block is after ~learning the task. Then the top right block is after we inserted the backdoor.
[It's important to take the time to understand this similarity heatmap because 90% of our figures are in this format]
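For readers who want to reproduce this kind of plot, here is a minimal sketch of a checkpoint-by-checkpoint heatmap. It assumes you have saved each checkpoint's outputs on a fixed set of inputs; the toy data below (one random checkpoint, then two blocks of three) is synthetic and only mimics the block structure described above.

```python
import numpy as np

def similarity_heatmap(checkpoint_outputs):
    """Pairwise cosine similarity between checkpoints; entry [a, b] compares checkpoints a and b."""
    X = np.stack([np.ravel(o) for o in checkpoint_outputs]).astype(float)
    X /= np.linalg.norm(X, axis=1, keepdims=True)
    return X @ X.T

rng = np.random.default_rng(0)
shape = (50, 10)  # hypothetical: 50 inputs, 10 class logits
block_a = rng.normal(size=shape)  # stand-in for "after learning the task"
block_b = rng.normal(size=shape)  # stand-in for "after inserting the backdoor"
outputs = (
    [rng.normal(size=shape)]  # randomly initialized checkpoint
    + [block_a + 0.05 * rng.normal(size=shape) for _ in range(3)]
    + [block_b + 0.05 * rng.normal(size=shape) for _ in range(3)]
)
heatmap = similarity_heatmap(outputs)  # shape (7, 7), diagonal is 1
```

Passing `heatmap` to something like matplotlib's `imshow` gives the block pattern: a lone first row and column, then two self-similar squares along the diagonal.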
So if we know what the poisoned data is, we can see the difference by looking at the outputs. But what if we don't know the trigger?
(Top) How similar are each checkpoint's matrices (i.e. "local sim") to each other? We see a delta of 0.15 between the blocks of checkpoints before & after finetuning on poisoned data. (Bottom) How similar is each checkpoint's output on clean data to each other? We barely see a difference.[2]
(Top) How similar is each checkpoint's tensor (i.e. "global sim") to each other? By accounting for symmetries (or in this case NOT accounting for anti-symmetric components that cancel out), we see a more visible difference between the poisoned checkpoints and the original. (Bottom) Because we have a tensor, we can find the tensor-slice for class 9 and compare only those with each other. This would be cheating since we wouldn't know "the poisoned data relates to class 9" ahead of time, but we can compute attribution to find the parameters responsible, which would be mostly[3] class 9 in this case.
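The difference between local and global similarity can be illustrated on a toy bilinear layer. This is a sketch under assumptions: the layer form y = D((Lx) * (Rx)), the plain flattened cosine, and all names here are illustrative rather than taken from the paper. Permuting and rescaling hidden units changes the individual matrices but not the contracted tensor, and swapping L and R only changes the anti-symmetric part, which never affects the output because x^T A x depends only on the symmetric part of A.

```python
import numpy as np

def contract(D, L, R):
    # T[o, i, j] = sum_h D[o, h] * L[h, i] * R[h, j]
    return np.einsum("oh,hi,hj->oij", D, L, R)

def symmetrize(T):
    # Keep only the part that is symmetric in the two input indices
    return 0.5 * (T + np.swapaxes(T, 1, 2))

def cosine(a, b):
    a, b = np.ravel(a), np.ravel(b)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

rng = np.random.default_rng(0)
n_classes, hidden, d_in = 10, 16, 8  # hypothetical sizes
D = rng.normal(size=(n_classes, hidden))
L = rng.normal(size=(hidden, d_in))
R = rng.normal(size=(hidden, d_in))

# Same function, different weights: cyclically permute and rescale the hidden units
perm = np.roll(np.arange(hidden), 1)
scale = rng.uniform(0.5, 2.0, size=hidden)
L2 = (scale[:, None] * L)[perm]
R2 = R[perm]
D2 = (D / scale[None, :])[:, perm]

T, T2 = contract(D, L, R), contract(D2, L2, R2)
local_sim = cosine(L, L2)   # matrix-level comparison, far from 1
global_sim = cosine(T, T2)  # tensor-level comparison, equal to 1 up to rounding

# Swapping L and R also leaves the function unchanged, but only the
# symmetrized tensors agree exactly
T_swapped = contract(D, R, L)
raw_swap_sim = cosine(T, T_swapped)
sym_swap_sim = cosine(symmetrize(T), symmetrize(T_swapped))

# A per-class comparison just indexes the output axis first
class9_sim = cosine(symmetrize(T)[9], symmetrize(T2)[9])
```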
Now you have the context to see the full image:
But wait?
Matrix cosine did show a good delta, so it's not a bad baseline. However, in our later cases, it shows ~0 difference. Why is this the case?
Matrix cosine sim is local, so it will be:
When we instead contract the tensor network into its overall tensor, all of the permutations/rescalings/etc will cancel out. That said, the takeaway isn't just "use tensor sim for backdoors", but more broadly we claim:
Catastrophic Forgetting
Let's train on SVHN again, but initially only train on classes 0-4. Then mix in 5. Then 6 and so on until 9. Then we'll remove class 9 from training, getting catastrophic forgetting, and add it back in.
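The curriculum can be written down as a list of phases, each with the set of classes visible during training. This is only a sketch of the schedule as described; the phase names and the helper `filter_by_classes` are made up for illustration, and the control branch is left out.

```python
def forgetting_schedule():
    """Return (phase name, allowed classes) pairs for the add-then-forget curriculum."""
    current = set(range(5))          # start with classes 0-4
    phases = [("base", set(current))]
    for new_class in range(5, 10):   # mix in 5, then 6, ... up to 9
        current = current | {new_class}
        phases.append((f"add {new_class}", set(current)))
    phases.append(("remove 9", current - {9}))  # triggers forgetting of 9
    phases.append(("re-add 9", set(current)))
    return phases

def filter_by_classes(labels, allowed):
    """Indices of the examples whose label is in the allowed set."""
    return [i for i, y in enumerate(labels) if y in allowed]

phases = forgetting_schedule()
by_name = dict(phases)
```

In this framing, "remove 9" trains on exactly the same classes as "add 8", which is one way to read the similarity between those two blocks of checkpoints.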
Do note this is the same type of plot as before where x & y axes are "checkpoint/training step".
What's interesting is that we clearly see the difference between adding 9 and removing it (control just means training more with the same data, you know, as a control). In fact, removing 9 is similar to "add 8", i.e. before you added 9 in the first place! Re-adding 9 is also similar to when you added 9 in the first place.
(Top Left) is the tensor sim image already explained. (Top Middle) is output similarity which only shows a bit of sim. (Top Right) is the local matrix cosine sim which doesn't show the expected structure at all. (Bottom) we do tensor sim but get the slice for each class, comparing that class's tensor with itself across checkpoints. We can clearly identify 9 as the 'forgotten digit'.
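The per-class comparison in the bottom row can be sketched as follows, assuming each checkpoint has already been contracted into a tensor of shape (classes, d, d) and using plain cosine similarity on each class slice. The toy checkpoints are synthetic: only the class-9 slice changes in the second half.

```python
import numpy as np

def per_class_heatmaps(tensors):
    """tensors: list of (classes, d, d) arrays, one per checkpoint.
    Returns an array of shape (classes, K, K) with one checkpoint heatmap per class."""
    X = np.stack(tensors).astype(float)  # (K, classes, d, d)
    K, C = X.shape[:2]
    X = X.reshape(K, C, -1)
    X = X / np.linalg.norm(X, axis=2, keepdims=True)
    return np.einsum("acn,bcn->cab", X, X)

rng = np.random.default_rng(0)
n_classes, d = 10, 6  # hypothetical sizes
base = rng.normal(size=(n_classes, d, d))
forgot = base.copy()
forgot[9] = rng.normal(size=(d, d))  # only the class-9 slice differs

checkpoints = (
    [base + 0.01 * rng.normal(size=base.shape) for _ in range(3)]
    + [forgot + 0.01 * rng.normal(size=base.shape) for _ in range(3)]
)
heat = per_class_heatmaps(checkpoints)
# The class whose heatmap contains the lowest similarity is the one that changed
changed_class = int(np.argmin(heat.min(axis=(1, 2))))
```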
Modular Arithmetic
In modular arithmetic[4], we classically go from "memorization" to "generalization/grokking". "Frequency" here means the frequencies used by the model (which we can compute solely from the weights), specifically frequencies 0-60 with 0 on bottom row and 60 on top.
We do have an older image that gives a different angle:
The bottom is the cosine similarity of the frequencies being used. It's self similar in the first half because there are no frequencies (just memorization). What's interesting is that the tensor sim tracks the continued frequency change throughout training.
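One way to read frequencies off weights is a discrete Fourier transform along the residue axis. The sketch below is an assumption about how such a profile could be computed, not the procedure behind the figures: the modulus `p = 121` is picked only so that `rfft` returns frequencies 0-60, and the weight matrices are synthetic.

```python
import numpy as np

def frequency_profile(W):
    """W: (p, d) weights indexed by residue. Returns the energy at each frequency 0..p//2."""
    spectrum = np.fft.rfft(W, axis=0)
    return np.sum(np.abs(spectrum) ** 2, axis=1)

def cosine(a, b):
    a, b = np.ravel(a), np.ravel(b)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

p, d = 121, 8  # hypothetical modulus and embedding width
rng = np.random.default_rng(0)
t = np.arange(p)[:, None]

def toy_weights(freqs, noise=0.01):
    # Small noise plus one cosine wave per frequency, with a random phase per column
    W = noise * rng.normal(size=(p, d))
    for k in freqs:
        phase = rng.uniform(0, 2 * np.pi, size=(1, d))
        W = W + np.cos(2 * np.pi * k * t / p + phase)
    return W

early = frequency_profile(toy_weights([3, 17]))
late = frequency_profile(toy_weights([3, 17]))
other = frequency_profile(toy_weights([40, 52]))

top_two = set(int(k) for k in np.argsort(early)[-2:])
same_freq_sim = cosine(early, late)   # same frequencies, different phases
diff_freq_sim = cosine(early, other)  # disjoint frequencies
```

Comparing profiles with cosine similarity ignores phase, so two checkpoints using the same frequencies look alike even if the individual weights differ.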
2 Layer Attention
Here the model is learning to predict n-grams better; however, none of the other methods show any change. Tensor similarity does show many changes, but we didn't explore what those differences correspond to. The important takeaway though is:
Conclusion
Tensor sim can tell us where the differences are (and we can even localize with attribution), but we still don't know what the difference is. We're excited about any future work that explores that angle, especially if you're analyzing the weights of the tensor networks to help figure out the functional change.
Overall, Tensor Networks are a solid foundation for rigorous, formal analysis. We can actually use principled techniques like cosine similarity! I highly recommend those working on finding the True Names of concepts to use tensor networks; bilinear networks with 1-4 layers aren't that difficult to work with either and are performant tensor networks.
For example, this method seems perfect for Natural Abstractions/Condensation: we can directly compute if two tensors are functionally equivalent across all inputs.
If you want to learn more about this project, do read our paper, and for tensor networks: this LW post. We hope to release more educational material Soon.
[1] This is a very tight approximation! Read our theory section for details (the relevant equation is eq 5).
[2] There is a stripe though. This corresponds to the drop in clean-accuracy on the first checkpoint of finetuning on poisoned data.
[3] Class 4 was also affected, I assume because it's similar to 9.
[4] The above image may look asymmetric to the untrained eye, but it's just the x & y axes being on different scales.