Redwood Research did very similar experiments in 2022, but didn't publish about them. They are briefly mentioned in this podcast: https://blog.redwoodresearch.org/p/the-inaugural-redwood-research-podcast.
Your bilinear attention layer is a bilinear function, but where the input (x) is copied and used as both inputs. Using those matrices Dec, L_Enc and R_enc, the way you show, is one particular way to parametrize the bilinear function. There are many other ways, the simplest of which would be to just use one tensor of shape [size-of-y, size-of-x, size-of-x]. I'm curious, why did you choose that particular parametrization?
Also, how did you initialize the model's weights? How do you initialize to prevent exploding gradients and similar problems?
I am curious about all this because my master's thesis was about a tensor network based alternative to CNNs.
A full 3rd order tensor is much larger, whereas this parametrization is the CP-decomposition form. This is the "official reason" when I'm really just building off Dooms et al. (I've never actually tried training the full tensor though!)
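Concretely, here's a quick sketch of that relationship (toy sizes and variable names are mine, just for illustration, not the model code): the D/L/R parametrization is the factored form of a single [d_out, d_in, d_in] tensor, and you can recover the full tensor by contracting the three factors.

```python
import torch

# Toy sizes for illustration only (not the sizes used in the post).
d_in, d_hidden, d_out = 16, 32, 16

D = torch.randn(d_out, d_hidden)   # decoder
L = torch.randn(d_hidden, d_in)    # left encoder
R = torch.randn(d_hidden, d_in)    # right encoder
x = torch.randn(d_in)

# Bilinear layer as parametrized: D(Lx ⊙ Rx)
y_factored = D @ ((L @ x) * (R @ x))

# The same map written as one full third-order tensor T[k, i, j]
# (i.e. the tensor whose CP-style factorization is D, L, R).
T = torch.einsum('kh,hi,hj->kij', D, L, R)      # shape [d_out, d_in, d_in]
y_full = torch.einsum('kij,i,j->k', T, x, x)

assert torch.allclose(y_factored, y_full, rtol=1e-4, atol=1e-4)

# Parameter counts: the factored form is far smaller once d_in is model-sized.
print(D.numel() + L.numel() + R.numel(), T.numel())   # 1536 vs 4096 here
```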
Re init: the init for modded-nanoGPT at that fork was kind of weird, but I'm pretty sure most standard inits prevent that. I am using RMSNorm, which can be treated as a tensor network as well (I could maybe DM an explanation; it's a forthcoming resource from Thomas). I'm also normalizing Q & K, which isn't a tensor network, BUT compositionality is on a spectrum (maybe I am too). So this does mean a small portion of the model isn't a tensor network.
Ideally we can work around this!
The fact that tensor network architectures are scale-invariant seems underappreciated for useful steering. If my understanding is correct, it would mean that scaling the steering vector should cause the same pathways through the model to be activated, whereas without this we could be activating a totally different pathway, and get much less predictable behaviour.
Correction Below
With more thinking, I was broadly wrong here:
- If you add a steering vector, it's not just scaling, so scale invariance doesn't make a difference.
- If you scale an existing activation vector which makes up the entirety of one of the layers, the only effect would be to change the absolute magnitudes going into the softmax (since scale invariance means the relative magnitude at each position is the same). That could have some minor effect -- changing the probability distribution to be sharper or flatter, but that's all.
- If you scale some existing activation which is not an entire layer, then it's no longer scale invariant anymore either, it's kind of like adding a steering vector with zero magnitude in the other dimensions.
There is still a weak advantage for steering vectors in a tensor network because the change is going to be smooth, rather than discrete (since we're not flipping gates on and off), but basically I was just confused here, sorry about that.
I'm confused about what you're referring to. Bilinear layers are scale-equivariant by (bi)linearity: scaling the input just rescales the output.
So x could be the input token, a vector d (from the previous bilinear layer), or a steering vector added in, and scaling x will still produce the same output direction (and affect the same hidden dims of the bilinear layer in the same proportions).
Another way to say this is that for
$$y = \text{Bilinear}(\alpha x) = \alpha^2\,\text{Bilinear}(x),$$
the percentage of attribution of each weight in the bilinear layer w/ respect to y is the same regardless of α: to compute the percentage you'd divide by the total, so the scaling by α² cancels out.
This also means that, solely from the weights, you can trace the computation done by injecting this steering vector.
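A toy numeric check of what I mean (shapes and the helper are made up for illustration, not the trained models):

```python
import torch

# Toy bilinear layer: Bilinear(x) = D(Lx ⊙ Rx); shapes are illustrative only.
d_in, d_hidden, d_out = 8, 16, 8
D = torch.randn(d_out, d_hidden)
L = torch.randn(d_hidden, d_in)
R = torch.randn(d_hidden, d_in)
bilinear = lambda x: D @ ((L @ x) * (R @ x))

x = torch.randn(d_in)      # could be a residual-stream direction or a steering vector
alpha = 3.7                # arbitrary scale

y, y_scaled = bilinear(x), bilinear(alpha * x)

# Scaling the input rescales the whole output by alpha^2 ...
assert torch.allclose(y_scaled, alpha**2 * y, rtol=1e-4, atol=1e-4)
# ... so the *relative* pattern over output dims (what attribution cares about) is unchanged.
assert torch.allclose(y_scaled / y_scaled.norm(), y / y.norm(), rtol=1e-4, atol=1e-5)
```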
[*Caveat: a bilinear layer computes interactions between two things. So you can compute the interaction between BOTH (1) the steering vector and itself and (2) the steering vector w/ the other directions d from previous layers. You CAN'T compute how it interacts w/ the input-token solely from the weights, because the weights don't include the input token. This is a bit of a trivial statement, but I don't want to overstate what you can get]
Overall, my main confusion w/ what you wrote is what an activation that is an entire layer or not an entire layer means.
You're correct, sorry for being confusing. Tracing through:
That's pretty much all I was trying to correct in my response. When I was talking about entire layer / not entire layer, I was just trying to say you can't pretend that adding a steering vector is actually just scaling the activation vector even if it is parallel in some dimensions. It's a trivial point I was just thinking through aloud. Like:
So basically you can ignore that, I was just slowly thinking through the maths to come to trivial conclusions.
Your claim here is different and good, and points to another useful thing about bilinear layers. As far as I can tell — you are saying you can decompose the effect of the steering vector into separable terms purely from the weights, whereas with ReLU you can't do this because you don't know which gates will flip. Neat!
Softmax attention ran much faster because I was using F.scaled_dot_product_attention(), which uses a CUDA kernel under the hood. How to adjust? I don't want to write my own custom CUDA kernel, so I adjusted by saying they ran just as fast per step. This isn't quite true, for the reasons below.
Claude Opus 4.5 can make decent triton kernels these days; I'd recommend using that if attention is a bottleneck.
Tensor networks might actually be a viable alternative to typical NNs! However, if scaling is way worse (say 50%), then I highly doubt they'll be deployed as a frontier model.
But suppose we can solve ambitious mech interp w/ tensor networks (debatable but I lean yes), then there are two regimes:
1. Low Reliability
2. High Reliability
Excited by the ambitious effort. Re: the impact argument above, I'm understanding your logic to be:
To completely replace NNs in frontier applications, scaling of TNs needs to be on par. Therefore, if we needed to replace them for TNs to be useful, we should test the scaling laws first. However, in a world where scaling is worse, TNs can still be useful by allowing for "ambitious mech interp", which would result in a High Reliability model. These two regimes aren't mutually exclusive.
Am I following the argument correctly?
Yep! But I do think the highest priority thing would be actually doing ambitious interp w/ this, although, if we had 100 people working on this (instead of ~4-5 full time?), a few working on the scaling laws would be good.
TNs are more amenable to optimizing exactly what we want in a mathematically precise way, so optimizing for this (to achieve ambitious mech interp) would incur an additional cost in capabilities, just fyi.
Disclaimer: I'm not very familiar with the ins and outs of tensor networks (so thanks for the reading list :D).
I would think that a composable dual encoder architecture would be able to hold more information than an MLP one, so it seems counterintuitive that the dual encoder requires more steps to achieve the same cross-entropy. I'm sure this is in part due to the more complex loss function, so maybe there is some threshold on dataset size or model size above which the tensor variant achieves lower CE?
Just looking at Shazeer's paper (Appendix A)
All of the GLU models performed better (lower is better) and the GLU models have a bilinear encoder (just w/ & w/o a sigmoid/GeLU/Swish/ReLU function). So in fact it does better (if this is what you meant by a dual encoder).
HOWEVER, we could have 3 encoders, or 100! This should store even more information, and would probably perform better per step, but would take up more GPU VRAM and/or take longer to compute each step.
In this post, though, I used wall clock time as a measure of training efficiency. Hand-wavy:
loss/step * time/step
(maybe it should be divided to make it loss/time?)
Ah, that makes more sense, thanks!
Also I agree with using loss/time as the measure of performance, since it's fairly straightforward to interpret (loss recovered per unit time). If I were reviewing this, I'd look for that.
For efficiency in practice, I think most ML papers look at FLOP/s since it is hardware agnostic. Maybe a good measure of efficiency here would be loss per FLOP per second? I haven't seen that used, but it might reflect how performance scales with computational speed.
Edit: Actually thinking about it, the test-time efficiency might be a better comparison, assuming the two scale within roughly the same complexity class. I think from a product perspective, speed for users is super (maybe the most) valuable.
At the CE threshold chosen, capacity wouldn't be the bottleneck.
Besides, MLP networks can store information as efficiently as the number of parameters they have.
I've been researching tensor networks as a more interpretable architecture, but whenever I tell people this, they always ask "But is it any good?"
So I trained multiple 500M parameter LLMs on fineweb, showing the tensor variant needed ~4% more batches of data to match CE-loss.
There are a few caveats, so my personal estimate is that it's somewhere between 15% worse and 10% better. Details below.
The Architecture
Replacing MLP w/ a Bilinear Layer
An MLP is a linear encoder, ReLU, then linear decoder.
$$\text{MLP}(x) = D(\text{ReLU}(E(x)))$$
A bilinear layer asks "what's better than one encoder? Two!"
$$\text{Bilinear}(x) = D(Lx \odot Rx)$$
where ⊙ means "element-wise multiply", e.g.
$$[1,2,3] \odot [1,2,3] = [1,4,9]$$
A SwiGLU layer (Swish Gated Linear Unit) says "Let's add in nonlinearities":
$$\text{SwiGLU}(x) = D(\text{swish}(Lx) \odot Rx)$$
SwiGLU is a SOTA architecture & Bilinear is a tensor network.
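As a minimal PyTorch sketch of the three layers (module names, sizes, and the bias-free choice are illustrative here, not the exact modded-nanoGPT code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sketches of the three layer types; not the exact training code.
class MLP(nn.Module):
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.enc = nn.Linear(d_model, d_hidden, bias=False)
        self.dec = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x):
        return self.dec(F.relu(self.enc(x)))                    # D(ReLU(E(x)))

class Bilinear(nn.Module):
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.left = nn.Linear(d_model, d_hidden, bias=False)
        self.right = nn.Linear(d_model, d_hidden, bias=False)
        self.dec = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x):
        return self.dec(self.left(x) * self.right(x))           # D(Lx ⊙ Rx)

class SwiGLU(nn.Module):
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.left = nn.Linear(d_model, d_hidden, bias=False)
        self.right = nn.Linear(d_model, d_hidden, bias=False)
        self.dec = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x):
        return self.dec(F.silu(self.left(x)) * self.right(x))   # D(swish(Lx) ⊙ Rx)

# Usage with made-up sizes: swapping the MLP block is a one-line change.
x = torch.randn(2, 64, 1152)                     # [batch, seq, d_model]
print(Bilinear(1152, 4 * 1152)(x).shape)         # torch.Size([2, 64, 1152])
```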
Replacing Softmax Attn w/ Bilinear Attn
For a tensor network, we are only allowed polynomial nonlinearities. For attention, this means we need to replace softmax w/ a polynomial. The simplest is an element-wise squaring of the attention pattern with itself.
$$\text{Attn}^2 = (xQK^\top x^\top)^{\odot 2} = (xQK^\top x^\top) \odot (xQK^\top x^\top)$$
Notice how this is like computing the same attention pattern twice, then element-wise multiplying.
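Here's a minimal single-head sketch of that squared-attention variant (no causal mask or normalization, names and sizes are mine, purely illustrative):

```python
import torch

def squared_attention(x, Q, K, V):
    # Naive single-head sketch: the softmax is replaced by an element-wise
    # square of the score matrix. No causal mask or scaling, for clarity.
    scores = (x @ Q) @ (x @ K).transpose(-2, -1)   # xQ K^T x^T, shape [seq, seq]
    return (scores ** 2) @ (x @ V)                 # (xQK^T x^T)^{⊙2} · v

# Illustrative shapes
seq, d_model, d_head = 10, 32, 8
x = torch.randn(seq, d_model)
Q, K, V = (torch.randn(d_model, d_head) for _ in range(3))
out = squared_attention(x, Q, K, V)                # [seq, d_head]
```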
We can instead be slightly more expressive by training two sets of Q & K:
$$\text{Bilinear\_Attn} = (xQ_1K_1^\top x^\top) \odot (xQ_2K_2^\top x^\top)$$

Experiment & Results
I forked an older commit of modded-nanoGPT and trained four ~500M parameter LLMs (~GPT-medium size) on fineweb, switching out Bilinear & SwiGLU for the MLP-component and Softmax Attn & Bilinear attn for Attention. I trained on CE loss, comparing when they reached a loss of 3.0.
Here, using the Bilinear or SwiGLU layer was basically the same, but switching to bilinear attn came at a cost (though a very small one; note also the 10% more parameters for Bi_attn).
I initially ran experiments in the GPT-2 small range:
Which were much worse for the tensor-variants.
One might think "Oh, the tensor-variants actually scale better?", but I think it's because I forked from modded-nanogpt, whose hyperparams & architecture are overfitted to the GPT-2-size model with softmax attention. You'd need to do a larger hyperparam sweep (& various architecture changes, yikes) to get a better idea!
Caveats:
(1) Softmax attention ran faster because it has a CUDA kernel
Softmax attention ran much faster because I was using F.scaled_dot_product_attention(), which uses a CUDA kernel under the hood. How to adjust? I don't want to write my own custom CUDA kernel, so I adjusted by saying they ran just as fast per step. This isn't quite true, for the reasons below.
(2) Bilinear Attention can run much faster than Softmax Attn
Bilinear attention is $O(seq \cdot d_h^3)$ vs $O(seq^2 \cdot d_h)$ for normal softmax, where $d_h$ := the head dimension (usually d_model / num of heads).
So Bilinear attention is more efficient when:
$$seq > d_h^2$$
For a seq length of 1M & $d_h$ of 128 (from DeepSeek-V3):
$$10^6 > 1.6 \times 10^4$$
In other words, bilinear attention is more efficient computationally in this case when the sequence length is greater than ~16,000.
[More details & proof in appendix B]
As a quick aside, we're gaining in computational efficiency, but this does come at a cost of less expressivity (see Appendix C).
But what about Flash Attention?
See Appendix D.
(3) Bilinear Attention has more Parameters
It's not fair that bilinear attention has twice the number of Q & K matrices, so I tried a baseline of differential attention from the literature, which claims to need ~38% fewer parameters or ~36% fewer tokens. But it performed very poorly in my case (i.e. 100% worse)! It could've been a scaling issue, hyperparameters, or a coding bug (but the implementation is simple, see my code here).
(4) This was the 2nd-Dumbest Tensor-Attn Variant
There are likely way more efficient tensor-attn variants out there. The 2nd dumbest is this bilinear attention, where the dumbest is just the .square() (which is like bilinear attention, but w/ tied Q & tied K weights).
Overall, I'm thinking this tensor attention variant is 15% worse to 10% better than softmax attention.
Replication & Trained Models
Code here; the majority of the code is just train_gpt2.py
Trained models here.[1]
Future Work
There's many more experiments one could run (eg scaling laws), but I'm currently focusing on actually doing interp w/ these models.
(Also, I've already spent ~$500 on 8 H100s for all the experiments. Speaking of which, special thanks to Principles of Intelligence (formerly PIBBSS) for funding me at the time of this research!)
Path to Impact
Tensor networks might actually be a viable alternative to typical NNs! However, if scaling is way worse (say 50%), then I highly doubt they'll be deployed as a frontier model.
But suppose we can solve ambitious mech interp w/ tensor networks (debatable but I lean yes), then there are two regimes:
1. Low Reliability
2. High Reliability
For writing math proofs we can verify, it's fine to have low reliability because failure is cheap. For self-driving cars though, you really want high reliability!
So we can do distillation or just train smaller tensor networks that aren't as generally capable, but are task specific.
Having highly robust, task-specific AI sounds great. So great, it might actually make a lot of money for various tasks.
This could change the financial incentives away from investing in low reliability AGI and towards highly reliable task-AI.
Interp w/ Tensor Networks
The second question I get when I bring up Tensor Networks is how they're actually more interpretable.
Most work on tensor networks isn't concerned with both ambitious interp and viability; however, Thomas Dooms, Ward Gauderis et al have been cooking this year!
Bilinear Autoencoders - they find structure in models using bilinear AEs. See example manifolds here.
Compositionality Unlocks Deep Interpretable Models - a stack of bilinear layers is performant, while enabling us to view the global structure across the whole tensor network (since you can compose them together), though be warned, lots of math and kind of hard to understand.
Bilinear MLPs Enable Weight-Based Mech Interp - "Bilinear MLPs can be fully expressed in terms of linear operations using a third-order tensor, allowing flexible analysis of the weights. Analyzing the spectra of bilinear MLP weights using eigendecomposition reveals interpretable low-rank structure across toy tasks, image classification, and language modeling. We use this understanding to craft adversarial examples, uncover overfitting, and identify small language model circuits directly from the weights alone."
[And if I understand correctly, Transluce linearized their LLM per datapoint (making it a tensor network, but, again, only for that datapoint) to improve attribution.]
But I really believe this is just the tip of the iceberg. Tensor networks have a lot of useful properties that make them amenable to analytic tools. In short:
As well as some forthcoming work from them that I'm (very) excited to see released!
As for me, I'm currently working on circuits applied to tensor networks. Do feel free to reach out to me here or on discord ( # loganriggs) if you're interested in this research direction!
Special thanks to Thomas Dooms for reviewing an earlier draft of this post.
Appendix A: Noam Shazeer's 2020 paper:
In Noam Shazeer's 2020 paper, he trains these different architectures on a span filling task on the C4 dataset, showing their log-perplexity loss.
Where the Bilinear layer does quite well! Even beating GLU (which should be called SiGLU for sigmoid GLU).
Appendix B: Scaling of Bilinear Attention
Proof from Claude Opus 4.5, edited & reviewed by me, w/ code verification here.
Our bilinear form is:
$$\underbrace{(xQ_1K_1^\top x^\top)}_{seq \times seq} \odot \underbrace{(xQ_2K_2^\top x^\top)}_{seq \times seq} \cdot v$$
where ⊙ is the elementwise product. Naively this requires materializing a seq × seq matrix, which is $O(seq^2)$.
Defining:
$$q_1 = xQ_1 \in \mathbb{R}^{seq \times d_h}, \quad k_1 = xK_1 \in \mathbb{R}^{seq \times d_h}$$
$$q_2 = xQ_2 \in \mathbb{R}^{seq \times d_h}, \quad k_2 = xK_2 \in \mathbb{R}^{seq \times d_h}$$
where $d_h$ is the hidden dimension of the attention head. So the attention patterns are:
$$A_1 = q_1 k_1^\top \in \mathbb{R}^{seq \times seq}, \quad A_2 = q_2 k_2^\top \in \mathbb{R}^{seq \times seq}$$
The row-wise Khatri-Rao product, ⊛, is defined row-by-row as the Kronecker (outer) product of each row:
$$(A \circledast B)_{i,:} = A_{i,:} \otimes B_{i,:}$$
For example, if $A_{i,:} = [a_1, a_2]$ and $B_{i,:} = [b_1, b_2, b_3]$, then:
$$(A \circledast B)_{i,:} = [a_1 b_1, a_1 b_2, a_1 b_3, a_2 b_1, a_2 b_2, a_2 b_3]$$
A key identity[2] is
$$(AB^\top) \odot (CD^\top) = (A \circledast C)(B \circledast D)^\top$$
Using this identity:
$$A_1 \odot A_2 = (q_1 k_1^\top) \odot (q_2 k_2^\top) = (q_1 \circledast q_2)(k_1 \circledast k_2)^\top$$
Define:
$$\tilde{Q} = q_1 \circledast q_2 \in \mathbb{R}^{seq \times d_h^2}, \quad \tilde{K} = k_1 \circledast k_2 \in \mathbb{R}^{seq \times d_h^2}$$
So to compute $\tilde{Q}$ we're taking the row-wise outer product of the hidden dimensions of $q_1$ & $q_2$, repeating this for every sequence position (later, we'll mix across sequence positions when combining $\tilde{Q}$ & $\tilde{K}$).
Now the output becomes:
$$(A_1 \odot A_2) \cdot v = \tilde{Q}\tilde{K}^\top v$$
Everything here is just matrices/vectors being matrix-multiplied, so we have two choices on how to combine them:
1. Left-to-right (inefficient): $(\tilde{Q}\tilde{K}^\top)v$
So we're scaling quadratically in both seq-size & $d_h$: $O(seq^2 \cdot d_h^2)$.
2. Right-to-left (efficient): $\tilde{Q}(\tilde{K}^\top v)$
So we're scaling linearly in seq-size & cubically in $d_h$: $O(seq \cdot d_h^3)$.
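Putting the derivation together, here's a small sketch (mine, not the training code) checking that the right-to-left order gives the same output without ever materializing a seq × seq matrix:

```python
import torch

seq, d_h = 64, 8                                   # illustrative sizes
q1, k1 = torch.randn(seq, d_h), torch.randn(seq, d_h)
q2, k2 = torch.randn(seq, d_h), torch.randn(seq, d_h)
v = torch.randn(seq, d_h)

# Naive: materialize both seq x seq patterns, elementwise-multiply, then apply to v.
naive = ((q1 @ k1.T) * (q2 @ k2.T)) @ v            # forms seq x seq matrices

# Row-wise Khatri-Rao products: Q~ and K~ are seq x d_h^2.
Q_tilde = torch.einsum('si,sj->sij', q1, q2).reshape(seq, d_h * d_h)
K_tilde = torch.einsum('si,sj->sij', k1, k2).reshape(seq, d_h * d_h)

# Right-to-left: K~^T v is d_h^2 x d_h, then Q~ @ (K~^T v) is seq x d_h,
# so the cost is O(seq * d_h^3) and no seq x seq matrix is ever formed.
efficient = Q_tilde @ (K_tilde.T @ v)

assert torch.allclose(naive, efficient, rtol=1e-3, atol=1e-3)
```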
Appendix C: Bilinear Attention Expressivity
Softmax attention is more expressive. The max rank of the attention matrix for softmax is seq (full rank), whereas for bilinear attention we're stuck at $d_h^2$. There's no free lunch, so while we're gaining a lot in computational efficiency, we're losing in expressivity.
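A quick numeric illustration of that rank cap (toy sizes, my sketch, not anything from the trained models):

```python
import torch

seq, d_h = 256, 4                                   # so d_h^2 = 16 << seq
q1, k1 = torch.randn(seq, d_h), torch.randn(seq, d_h)
q2, k2 = torch.randn(seq, d_h), torch.randn(seq, d_h)

# Bilinear attention pattern: elementwise product of two rank-d_h score matrices.
A = (q1 @ k1.T) * (q2 @ k2.T)                       # shape [seq, seq]

# Its rank is capped at d_h^2, far below full rank (which would be seq).
print(torch.linalg.matrix_rank(A))                  # prints 16 here, not 256
```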
In the kernel view, softmax approximates a Gaussian kernel, meaning it can approximate any continuous function, while bilinear attention is just a degree-2 polynomial kernel.
Appendix D: But what about Flash Attention?
Flash attention is about memory, not computation. Its purpose is to avoid materializing the full seq × seq matrix by splitting it up into tiles. You can't compute softmax over tiles since it's a property of entire rows, so they compute statistics on each tile which can be combined to ~compute softmax.
Tensor-attention variants don't use softmax, so you don't need any clever tricks there (although the efficient bilinear attention method probably requires more memory? I'll leave this as an exercise to the reader).
The naming scheme goes as follows:
eg Elriggs/gpt2-swiglu-18l-9h-1152embd
is read sensibly as gpt2 w/ swiglu, 18 layers, 9 attn heads, & 1152 embedding dim. If attention isn't mentioned, it's assumed to be softmax.
The rest of my naming scheme ended up being mixed up & non-ideal. I recommend going to each one's config.json if you're looking for a particular run.
Section 3, fourth equation down, citation is [16]