Your bilinear attention layer is a bilinear function, but with the input (x) copied and used as both arguments. Using those matrices Dec, L_Enc and R_enc the way you show is one particular way to parametrize the bilinear function. There are many other ways, the simplest of which would be to just use one tensor of shape [size-of-y, size-of-x, size-of-x]. I'm curious, why did you choose that particular parametrization?
Also, how did you initialize the model's weights? How do you initialize to prevent exploding gradients and similar problems?
I am curious about all this because my master's thesis was about a tensor-network-based alternative to CNNs.
Disclaimer: I'm not very familiar with the ins and outs of tensor networks (so thanks for the reading list :D).
I would think that a composable dual encoder architecture would be able to hold more information than an MLP one, so it seems counterintuitive that the dual encoder requires more steps to achieve the same cross-entropy. I'm sure this is in part due to the more complex loss function, so maybe there is some threshold on dataset size or model size above which the tensor variant achieves lower CE?
Softmax attention ran much faster because I was using F.scaled_dot_product_attention(), which uses a CUDA kernel under the hood. How to adjust for that? I didn't want to write my own custom CUDA kernel, so I adjusted by treating them as equally fast per step. This isn't quite true, for the reasons below.
Claude Opus 4.5 can make decent triton kernels these days; I'd recommend using that if attention is a bottleneck.
I've been researching tensor networks as a more interpretable architecture, but whenever I tell people this, they always ask "But is it any good?"
So I trained multiple 500M parameter LLMs on fineweb, showing the tensor variant needed ~4% more batches of data to match CE-loss.
There are a few caveats, so my personal estimate is somewhere between 15% worse and 10% better. Details below.
The Architecture
Replacing MLP w/ a Bilinear Layer
An MLP is a linear encoder, ReLU, then linear decoder.
$$\text{MLP}(x) = D(\text{ReLU}(E(x)))$$

A bilinear layer asks "what's better than one encoder? Two!"

$$\text{Bilinear}(x) = D(Lx \odot Rx)$$

where $\odot$ means "element-wise multiply", e.g.

$$[1, 2, 3] \odot [1, 2, 3] = [1, 4, 9]$$

A SwiGLU layer (Swish Gated Linear Unit) says "Let's add in nonlinearities":

$$\text{SwiGLU}(x) = D(\text{swish}(Lx) \odot Rx)$$

SwiGLU is a SOTA architecture & Bilinear is a tensor network.
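For concreteness, here's a minimal PyTorch sketch of the three layer variants (module and dimension names are illustrative, not the exact code from my repo):

```python
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.E = nn.Linear(d_model, d_hidden)  # encoder
        self.D = nn.Linear(d_hidden, d_model)  # decoder

    def forward(self, x):
        return self.D(F.relu(self.E(x)))

class Bilinear(nn.Module):
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.L = nn.Linear(d_model, d_hidden)  # left encoder
        self.R = nn.Linear(d_model, d_hidden)  # right encoder
        self.D = nn.Linear(d_hidden, d_model)

    def forward(self, x):
        return self.D(self.L(x) * self.R(x))   # element-wise product, no activation

class SwiGLU(nn.Module):
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.L = nn.Linear(d_model, d_hidden)
        self.R = nn.Linear(d_model, d_hidden)
        self.D = nn.Linear(d_hidden, d_model)

    def forward(self, x):
        return self.D(F.silu(self.L(x)) * self.R(x))  # swish gate on one branch
```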
Replacing Softmax Attn w/ Bilinear Attn
For a tensor network, we are only allowed polynomial nonlinearities. For attention, this means we need to replace softmax with a polynomial. The simplest choice is an element-wise squaring of the attention pattern.
$$\text{Attn}^2 = (xQK^\top x^\top)^2 = (xQK^\top x^\top) \odot (xQK^\top x^\top)$$

Notice how this is like computing the same attention pattern twice, then element-wise multiplying.
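In code, a naive O(seq²) sketch of this squared attention might look like the following (shapes only; masking and normalization omitted, and my actual implementation differs):

```python
def squared_attention(x, Q, K, V):
    # x: [seq, d_model]; Q, K, V: [d_model, d_h]
    scores = (x @ Q) @ (x @ K).T   # [seq, seq] attention pattern
    attn = scores ** 2             # element-wise square replaces softmax
    # the bilinear variant below uses two independent Q/K pairs and
    # multiplies the two resulting patterns element-wise instead
    return attn @ (x @ V)          # [seq, d_h]
```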
We can instead be slightly more expressive by training two sets of Q & K:
$$\text{Bilinear\_Attn} = (xQ_1K_1^\top x^\top) \odot (xQ_2K_2^\top x^\top)$$

Experiment & Results
I forked an older commit of modded-nanoGPT and trained four ~500M parameter LLMs (~GPT-medium size) on fineweb, swapping in Bilinear or SwiGLU for the MLP component and softmax or bilinear attention for the attention component. I trained on CE loss, comparing how many steps each took to reach a loss of 3.0.
Here, using the Bilinear or SwiGLU layer made basically no difference, but switching to bilinear attention came at a cost (a very small one, though note the ~10% more parameters for bilinear attention).
I initially ran experiments in the GPT-2 small range:
These were much worse for the tensor variants.
One might think "Oh, the tensor variants actually scale better?", but I think it's because I forked from modded-nanogpt, whose hyperparameters & architecture are overfit to a GPT-2-size model with softmax attention. You'd need to do a larger hyperparameter sweep (& various architecture changes, yikes) to get a better idea!
Caveats:
(1) Softmax attention ran faster because it has a CUDA kernel
Softmax attention ran much faster because I was using F.scaled_dot_product_attention(), which uses a CUDA kernel under the hood. How to adjust for that? I didn't want to write my own custom CUDA kernel, so I adjusted by treating them as equally fast per step. This isn't quite true, for the reasons below.
(2) Bilinear Attention can run much faster than Softmax Attn
Bilinear Attention is $O(\text{seq} \cdot d_h^3)$ vs $O(\text{seq}^2 \cdot d_h)$ for normal softmax, where $d_h$ := the head dimension (usually d_model / num of heads).
So Bilinear attention is more efficient when:

$$\text{seq} > d_h^2$$

For a seq length of 1M & $d_h$ of 128 (from deepseek-v3):

$$10^6 > 1.6 \times 10^4$$

In other words, bilinear attention is computationally more efficient in this case once the sequence length exceeds ~16,000 (since $128^2 = 16{,}384$).
[More details & proof in appendix B]
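As a quick sanity check on that arithmetic (constants ignored, purely the asymptotic counts from above):

```python
def softmax_cost(seq, d_h):    # materializes the seq x seq pattern
    return seq**2 * d_h

def bilinear_cost(seq, d_h):   # right-to-left order from Appendix B
    return seq * d_h**3

d_h = 128
for seq in [1_000, 16_384, 1_000_000]:
    print(seq, softmax_cost(seq, d_h) > bilinear_cost(seq, d_h))
# bilinear wins once seq > d_h**2 = 16,384
```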
As a quick aside, we're gaining in computational efficiency, but this does come at the cost of less expressivity (see Appendix C).
But what about Flash Attention?
See appendix D
(3) Bilinear Attention has more Parameters
It's not fair that bilinear attention has twice the number of Q & K matrices, so I tried a baseline of differential attention from the literature, which claims to need ~38% fewer parameters or ~36% fewer tokens. But it performed very poorly in my case (i.e. 100% worse)! It could've been a scaling issue, hyperparameters, or a coding bug (but the implementation is simple, see my code here).
(4) This was the 2nd-Dumbest Tensor-Attn Variant
There are likely far more efficient tensor-attention variants out there. This bilinear attention is the 2nd dumbest; the dumbest is just .square() (which is like bilinear attention, but with tied Q and tied K weights).
Overall, I'm thinking this tensor attention variant is somewhere between 15% worse and 10% better than softmax attention.
Replication & Trained Models
Code here; the majority of the code is just train_gpt2.py.
Trained models here.[1]
Future Work
There are many more experiments one could run (e.g. scaling laws), but I'm currently focusing on actually doing interp w/ these models.
(Also, I've already spent ~$500 on 8 H100's for all the experiments. Speaking of which, special thanks to Principles of Intelligence (formerly PIBBSS) for funding me at the time of this research!)
Path to Impact
Tensor networks might actually be a viable alternative to typical NNs! However, if scaling is way worse (say 50%), then I highly doubt they'll be deployed as a frontier model.
But suppose we can solve ambitious mech interp w/ tensor networks (debatable, but I lean yes); then there are two regimes:
1. Low Reliability
2. High Reliability
For writing math proofs we can verify, it's fine to have low reliability because failure is cheap. For self-driving cars though, you really want high reliability!
So we can do distillation or just train smaller tensor networks that aren't as generally capable, but are task specific.
Having highly robust, task-specific AI sounds great. So great, it might actually make a lot of money for various tasks.
This could change the financial incentives away from investing in low reliability AGI and towards highly reliable task-AI.
Interp w/ Tensor Networks
The second question I get when I bring up Tensor Networks is how they're actually more interpretable.
Most work on tensor networks isn't concerned with both ambitious interp and viability; however, Thomas Dooms, Ward Gauderis et al have been cooking this year!
Bilinear Autoencoders - they find structure in models using bilinear AEs. See example manifolds here.
Compositionality Unlocks Deep Interpretable Models - a stack of bilinear layers is performant while enabling us to view the global structure across the whole tensor network (since you can compose them together). Be warned though: lots of math, and it's kind of hard to understand.
Bilinear MLPs Enable Weight-Based Mech Interp - "Bilinear MLPs can be fully expressed in terms of linear operations using a third-order tensor, allowing flexible analysis of the weights. Analyzing the spectra of bilinear MLP weights using eigendecomposition reveals interpretable low-rank structure across toy tasks, image classification, and language modeling. We use this understanding to craft adversarial examples, uncover overfitting, and identify small language model circuits directly from the weights alone."
[And if I understand correctly, Transluce linearized their LLM per datapoint (making it a tensor network, but, again, only for that datapoint) to improve attribution.]
But I really believe this is just the tip of the iceberg: tensor networks have a lot of useful properties that make them amenable to analytic tools.
There's also some forthcoming work from them that I'm (very) excited to see released!
As for me, I'm currently working on circuits applied to tensor networks. Do feel free to reach out to me here or on discord ( # loganriggs) if you're interested in this research direction!
Special thanks to Thomas Dooms for reviewing an earlier draft of this post.
Appendix A: Noam Shazeer's 2020 paper:
In Noam Shazeer's 2020 paper, he trains these different architectures on a span filling task on the C4 dataset, showing their log-perplexity loss.
The Bilinear layer does quite well, even beating GLU (which should be called SiGLU, for sigmoid GLU).
Appendix B: Scaling of Bilinear Attention
Proof from Claude Opus 4.5, edited & reviewed by me, w/ code verification here.
Our bilinear form is:

$$\underbrace{(xQ_1K_1^\top x^\top)}_{\text{seq}\times\text{seq}} \odot \underbrace{(xQ_2K_2^\top x^\top)}_{\text{seq}\times\text{seq}} \cdot v$$

where $\odot$ is the element-wise product. Naively this requires materializing a seq × seq matrix, which is $O(\text{seq}^2)$.

Defining:

$$q_1 = xQ_1 \in \mathbb{R}^{\text{seq}\times d_h}, \quad k_1 = xK_1 \in \mathbb{R}^{\text{seq}\times d_h}$$
$$q_2 = xQ_2 \in \mathbb{R}^{\text{seq}\times d_h}, \quad k_2 = xK_2 \in \mathbb{R}^{\text{seq}\times d_h}$$

where $d_h$ is the hidden dimension of the attention head. So the attention patterns are:

$$A_1 = q_1 k_1^\top \in \mathbb{R}^{\text{seq}\times\text{seq}}, \quad A_2 = q_2 k_2^\top \in \mathbb{R}^{\text{seq}\times\text{seq}}$$

The row-wise Khatri-Rao product, $\circledast$, is defined row-by-row as the Kronecker (outer) product of each row:

$$(A \circledast B)_{i,:} = A_{i,:} \otimes B_{i,:}$$

For example, if $A_{i,:} = [a_1, a_2]$ and $B_{i,:} = [b_1, b_2, b_3]$, then:

$$(A \circledast B)_{i,:} = [a_1 b_1, a_1 b_2, a_1 b_3, a_2 b_1, a_2 b_2, a_2 b_3]$$

A key identity[2] is

$$(AB^\top) \odot (CD^\top) = (A \circledast C)(B \circledast D)^\top$$

Using this identity:

$$A_1 \odot A_2 = (q_1 k_1^\top) \odot (q_2 k_2^\top) = (q_1 \circledast q_2)(k_1 \circledast k_2)^\top$$

Define:

$$\tilde{Q} = q_1 \circledast q_2 \in \mathbb{R}^{\text{seq}\times d_h^2}, \quad \tilde{K} = k_1 \circledast k_2 \in \mathbb{R}^{\text{seq}\times d_h^2}$$

So to compute $\tilde{Q}$ we're taking the row-wise outer product of the hidden dimensions of $q_1$ & $q_2$, repeating this for every sequence position (later, we'll mix across sequence positions when combining $\tilde{Q}$ & $\tilde{K}$).

Now the output becomes:

$$(A_1 \odot A_2) \cdot v = \tilde{Q}\tilde{K}^\top v$$

Everything here is just matrices/vectors being matrix-multiplied, so we have two choices of how to combine them:
1. Left-to-right (inefficient): $(\tilde{Q}\tilde{K}^\top)v$. Forming $\tilde{Q}\tilde{K}^\top$ costs $O(\text{seq}^2 \cdot d_h^2)$, so we're scaling quadratically in both seq-size & $d_h$.
2. Right-to-left (efficient): $\tilde{Q}(\tilde{K}^\top v)$. Computing $\tilde{K}^\top v$ and then multiplying by $\tilde{Q}$ each cost $O(\text{seq} \cdot d_h^3)$, so we're scaling linearly in seq-size & cubically in $d_h$.
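Here's a small PyTorch sketch of this right-to-left computation, checked against the naive version (random shapes; masking and normalization omitted):

```python
import torch

seq, d_h = 64, 8
q1, k1, q2, k2 = (torch.randn(seq, d_h) for _ in range(4))
v = torch.randn(seq, d_h)

# naive: materialize both [seq, seq] patterns and multiply element-wise
naive = ((q1 @ k1.T) * (q2 @ k2.T)) @ v

# efficient: row-wise Khatri-Rao products, then combine right-to-left
Q_tilde = torch.einsum('sd,se->sde', q1, q2).reshape(seq, d_h * d_h)  # [seq, d_h^2]
K_tilde = torch.einsum('sd,se->sde', k1, k2).reshape(seq, d_h * d_h)  # [seq, d_h^2]
efficient = Q_tilde @ (K_tilde.T @ v)  # never forms a seq x seq matrix

print(torch.allclose(naive, efficient, atol=1e-3))  # True
```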
Appendix C: Bilinear Attention Expressivity
Softmax attention is more expressive. The max rank of the attention matrix for softmax is seq (full rank), whereas for bilinear attention we're stuck at $d_h^2$. There's no free lunch: while we're gaining a lot in computational efficiency, we're losing expressivity.
In the kernel view, softmax approximates a Gaussian kernel, meaning it can approximate any continuous function, while bilinear attention is just a degree-2 polynomial kernel.
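You can see the rank cap numerically (illustrative shapes, random weights):

```python
import torch

seq, d_h = 256, 8
q1, k1, q2, k2 = (torch.randn(seq, d_h) for _ in range(4))

# bilinear attention pattern: a seq x seq matrix whose rank is capped at d_h**2
A = (q1 @ k1.T) * (q2 @ k2.T)
print(torch.linalg.matrix_rank(A))  # 64 (= d_h**2), despite seq = 256
```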
Appendix D: But what about Flash Attention?
Flash attention is about memory, not computation. Its purpose is to avoid materializing the full seq × seq matrix by splitting it up into tiles. You can't compute softmax over tiles directly, since it's a property of entire rows, so they compute statistics on each tile which can be combined to compute softmax.
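The combinable statistics are just a running max and a running sum of exponentials; here's a toy sketch of merging two tiles of a single query row (not the actual flash attention kernel, just the bookkeeping idea):

```python
import torch

def merge_tiles(m1, l1, o1, m2, l2, o2):
    # m: running max of scores; l: running sum of exp(scores - m);
    # o: running exp-weighted sum of values (unnormalized output)
    m = torch.maximum(m1, m2)
    a1, a2 = torch.exp(m1 - m), torch.exp(m2 - m)
    return m, a1 * l1 + a2 * l2, a1 * o1 + a2 * o2

def tile_stats(scores, values):
    m = scores.max()
    w = torch.exp(scores - m)
    return m, w.sum(), w @ values

# check against full softmax for one query row of 8 keys, split into two tiles
scores, values = torch.randn(8), torch.randn(8, 4)
full = torch.softmax(scores, dim=-1) @ values
m, l, o = merge_tiles(*tile_stats(scores[:4], values[:4]),
                      *tile_stats(scores[4:], values[4:]))
print(torch.allclose(full, o / l))  # True
```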
Tensor-attention variants don't use softmax, so you don't need any clever tricks there (although the efficient bilinear attention method probably requires more memory? I'll leave this as an exercise for the reader).
[1] The naming scheme goes as follows:
eg Elriggs/gpt2-swiglu-18l-9h-1152embd
is read sensibly as gpt2 w/ swiglu, 18 layers, 9 attn heads, & 1152 embedding dim. If attention isn't mentioned, it's assumed to be softmax.
The rest of my naming scheme ended up being mixed up & non-ideal. I recommend going to each one's config.json if you're looking for a particular run.
[2] Section 3, fourth equation down; the citation is [16].