I've never really understood why this is hard, at inference time. Attention is similar to vector search (except that the vector norm can vary). In practice, norms mostly don't vary that much, so in any particular attention head's result there is a fairly small set of locations that are attended to strongly, and the vast majority of locations are attended to negligibly and are just noise. As vector search shows, there are smart algorithms that handle this extremely efficiently for unit-norm vectors in high-dimensional spaces, and adapting one to work fairly well with somewhat variable norms just can't be that hard. Basically, put locations with different norms into different indexes, and tweak the algorithm accordingly.
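Here is a minimal sketch of the inference-time idea, using an exact top-k over all scores as a stand-in for the approximate nearest-neighbour index (so this version is still quadratic; the proposal is to replace the full scoring pass with an ANN lookup). All names and shapes are illustrative, not from any particular system:

```python
import numpy as np

def topk_attention(q, K, V, k=32):
    """Approximate attention for a single query: only softmax over (and read
    values for) the k highest-scoring key positions. A real implementation
    would get the candidate set from an ANN index instead of scoring all keys."""
    scores = K @ q / np.sqrt(q.shape[-1])      # (n_keys,) - the part ANN search would replace
    top = np.argpartition(scores, -k)[-k:]     # indices of the k largest scores
    w = np.exp(scores[top] - scores[top].max())
    w /= w.sum()
    return w @ V[top]                          # weighted sum over only k locations

# toy usage
rng = np.random.default_rng(0)
n, d = 4096, 64
q, K, V = rng.normal(size=d), rng.normal(size=(n, d)), rng.normal(size=(n, d))
out = topk_attention(q, K, V, k=32)
```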
Subquadratic attention just at inference time would be a big deal, even if training were still quadratic — a lot of compute is spent on inference. However, while having everything fully differentiable at training time is convenient, it isn't actually necessary. What is necessary is an explore/exploit tradeoff, where you don't just look at the things you already know are important: you also randomly sample things that you currently think aren't important, preferably in an intelligent way biased towards things you think are near the cutoff, and if you used them, you backpropagate through them, so that if they actually should be important, you get a chance to learn this. So basically, do carefully targeted heavy dropout. This explore/exploit tradeoff needs to be handled well: sample too little, and something that should be learned but has only recently become important (because your model recently learnt other things) takes longer to get noticed; explore too much, and your training is slow, wasting time rechecking that unimportant things are still unimportant. But this is a soluble problem: people understand explore/exploit tradeoffs, and Search engineers understand how to sample things you currently think are low relevance so you can efficiently spend expensive resources finding out whether you were right or wrong — it's a standard part of tuning a Search algorithm, and there are entire textbook chapters on it.
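A minimal sketch of the kind of biased exploration described above. This is illustrative only: the function name, the softmax-style sampling rule, and all parameters are my own assumptions, not an existing training scheme.

```python
import numpy as np

def explore_exploit_candidates(scores, k=32, n_explore=8, temperature=1.0, rng=None):
    """Pick the k positions with the highest (possibly stale/approximate) scores,
    plus n_explore extra positions sampled from the rest, biased towards positions
    whose scores are close to the current cutoff. The extras would then be scored
    exactly and backpropagated through, so mistakenly-ignored positions can be
    promoted if they turn out to matter."""
    rng = rng if rng is not None else np.random.default_rng()
    top = np.argpartition(scores, -k)[-k:]
    rest = np.setdiff1d(np.arange(len(scores)), top)
    cutoff = scores[top].min()
    # closer to the cutoff -> higher sampling probability
    p = np.exp(-(cutoff - scores[rest]) / temperature)
    p /= p.sum()
    explore = rng.choice(rest, size=n_explore, replace=False, p=p)
    return np.concatenate([top, explore])

# toy usage: stale attention scores for 4096 positions
scores = np.random.default_rng(1).normal(size=4096)
candidates = explore_exploit_candidates(scores)
```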
So, why hasn't anyone built subquadratic attention using a tweaked variant of vector search, which is not fully differentiable, and then added back a suitable amount of fully-differentiable exploration during training?
Or, is it the case that Google (who are very good at Search algorithms, including vector search) has built exactly that, and that's why they have million-token-plus context lengths?
Note that this proposal doesn't help with the KV cache memory, only with compute for attention (unless you combine it with some sort of smart caching). And if you did it on CPU, it would probably be slower than full quadratic attention on GPU. So this algorithm would need to run natively on GPU hardware. Which probably makes this more practicable for a company like Google that designs its own tensor hardware.
A few points, none super confident.
- I like the search algorithm parallel, I had never thought of it that way!
- Since, as you said, it doesn't reduce KV cache size (unless you do it on CPU), how much it can speed up inference is somewhat limited, because it will not increase batch sizes (see my answers to Alex Gibson's comment for why this is important if you don't already know).
- Unclear whether attention being efficient during training matters much because:
-- Pretraining is afaik done at context lengths short enough that it doesn't matter that much that attention is quadratic.
-- Midtraining afaik takes a lot less compute than pretraining so it's probably not that important for it to be compute efficient.
-- You need to do inference when doing RL, so more efficient training during RL would only help somewhat.
- Yeah, Google seems to be good at efficient attention. Here is a blogpost I liked showing how good they are at long context benchmarks. I don't have takes on whether they made it subquadratic or just made it more efficient.
- Another way to make attention more feasible at long contexts is to just have more VRAM per node. Even if you don't make any architectural improvements, this just gives you more VRAM to put the KV cache in (so you can just have bigger KV caches and bigger batch sizes). Vladimir_Nesov says here that Google's TPUs are particularly good in this respect compared to Nvidia GPUs.
There are architectures which have constant memory and compute usage per token, but they are not used in practice. Ultimately I expect something like this to work; the human brain is an existence proof.
Text diffusion LLMs can be more efficient than autoregressive models in practice because it is usually more efficient on GPUs to do one big operation than many small operations in sequence, even when both require the same number of FLOPs[1].
My mental model was that autoregressive models are slow during decoding because of the memory transfers required on each token. You need to swap the entire KV cache in and out of Global Memory on each token. I imagine this is what you mean here, just wanted to check this model is along the right lines?
I feel like there must be a way to allow for more flexibility in memory usage. There are clearly scenarios where what you want is MCMC style, constant memory / markov search, i.e: recall. In cases like this, autoregressive transformers are just painful because they inherently tie up iterations with memory usage. But there are scenarios where constant memory loses as well, i.e: sparse recall tasks.
Yes, your model is correct. I wanted to make things as simple as possible when writing the blogpost but probably went too far with this one and ended up just making it confusing / partially inaccurate. There are two reasons autoregressive LLM inference is inefficient at long contexts:
- You need to load the whole KV cache from VRAM at every forward pass.
- Since you need to store the whole KV cache in VRAM for each sequence and KV caches are big, you can only store a small number of KV caches, so you can only have small batch sizes. This makes inference inefficient because you have to load the weights from VRAM at every forward pass.
-- Explanation of why big batch sizes are important for making LLM inference efficient (skip if you already know): GPUs have a lot more FLOPs than memory bandwidth. If you multiply batch_size vectors of dimension d_model by a d_model x d_model (or d_model x d_mlp or whatever) matrix, you need O(d_model * d_model + batch_size * d_model) memory reads and O(batch_size * d_model * d_model) FLOPs, so at small batch sizes this is bottlenecked by VRAM reads and most compute units just stay idle, while at big batch sizes it is bottlenecked by FLOPs (see the sketch below for rough numbers).
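A rough back-of-the-envelope version of this. The throughput and bandwidth numbers are assumptions in the right ballpark for a modern datacenter GPU, not exact specs:

```python
# Rough arithmetic-intensity estimate for one d_model x d_model matmul applied
# to a batch of token vectors (illustrative numbers, not exact hardware specs).
d_model = 8192
bytes_per_param = 2            # FP16/BF16 weights
flops_per_s = 1e15             # ~1 PFLOP/s of usable matmul throughput (assumed)
bytes_per_s = 3e12             # ~3 TB/s of memory bandwidth (assumed)

for batch_size in (1, 16, 256, 4096):
    flops = 2 * batch_size * d_model * d_model                        # multiply-accumulates
    bytes_read = bytes_per_param * (d_model * d_model + batch_size * d_model)
    compute_time = flops / flops_per_s
    memory_time = bytes_read / bytes_per_s
    bound = "memory-bound" if memory_time > compute_time else "compute-bound"
    print(f"batch {batch_size:5d}: {bound} (compute {compute_time:.2e}s, memory {memory_time:.2e}s)")
```

With these assumed numbers the crossover from memory-bound to compute-bound happens at a batch size of a few hundred, which is why small batches leave most compute units idle.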
I also am somewhat surprised that it's so hard to make attention more efficient.
TL;DR: In the last couple years, there have been multiple hype moments of the form "<insert paper> figured out subquadratic/linear attention, this is a game changer!" However, all the subquadratic attention mechanisms I'm aware of either are quadratic the way they are implemented in practice (with efficiency improved by only a constant factor) or underperform quadratic attention on downstream capability benchmarks.
A central issue with attention is that its FLOP complexity is quadratic in the context length (number of tokens in a sequence) and its memory complexity during inference is linear in the context length. In the last couple years, there have been multiple claims, and hype around those claims, that new architectures solved some (often all) of those problems by making alternatives to attention whose FLOP complexity is linear and/or whose memory complexity during inference is constant. These are often called subquadratic/linear attention (as opposed to regular attention which I’ll call quadratic attention). The ones I’m aware of are Kimi Linear, DeepSeek Sparse Attention (DSA), Mamba (and variants), RWKV (and variants), and text diffusion. If this were true, it would be a big deal because it would make transformer inference a lot more efficient at long contexts.
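As a concrete reminder of where the quadratic FLOPs and linear memory come from, here is the usual rough per-layer accounting. The head count, head dimension, and byte width are arbitrary assumptions; real models (especially with MLA or GQA) have much smaller KV caches than this naive version.

```python
# Rough per-layer attention cost as a function of context length n
# (standard accounting; constants approximate, softmax etc. ignored).
def attention_costs(n, d_head=128, n_heads=64, bytes_per_param=2):
    score_flops = 2 * n_heads * n * n * d_head    # Q @ K^T: quadratic in n
    value_flops = 2 * n_heads * n * n * d_head    # scores @ V: also quadratic in n
    kv_cache_bytes = 2 * n_heads * n * d_head * bytes_per_param  # K and V: linear in n
    return score_flops + value_flops, kv_cache_bytes

for n in (8_000, 128_000, 1_000_000):
    flops, kv_bytes = attention_costs(n)
    print(f"n={n:>9,}: ~{flops:.1e} attention FLOPs/layer, ~{kv_bytes/1e9:.1f} GB KV cache/layer")
```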
In this blogpost, I argue that they are all better thought of as “incremental improvement number 93595 to the transformer architecture” than as “subquadratic attention, a more than incremental improvement to the transformer architecture”. This is because the implementations that work in practice are quadratic and only improve attention by a constant factor, while the subquadratic implementations underperform quadratic attention on downstream benchmarks. I think some of them are still important and impressive - for instance, Kimi Linear’s 6.3x inference speedup at 1-million-token context lengths. I just argue that they are not particularly special among incremental improvements to the transformer architecture and not game changers.
Appendix: Short explanation of how each subquadratic attention mechanism works and why it is not actually subquadratic
RWKV and Mamba
Those are entirely different mechanisms from attention that can be thought of as (much) better RNNs. They are actually subquadratic (in fact, linear) but they seem to underperform attention at frontier LLM scale, as argued for above. Mamba-attention hybrids do scale but are quadratic, as explained below for Kimi Linear.
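A minimal sketch of why these are linear: the per-token update of a generic linear-attention / modern-RNN-style layer. This is heavily simplified; real Mamba/RWKV/Kimi-Linear layers add gating and decay terms that are omitted here.

```python
import numpy as np

def linear_attention_decode(qs, ks, vs):
    """Process tokens one by one with a fixed-size state matrix S.
    Per-token compute and memory are O(d_key * d_value), independent of how
    many tokens came before - which is why the overall cost is linear.
    Real RWKV/Mamba/Kimi-Linear updates add gating/decay; this is the bare skeleton."""
    d_key, d_value = ks.shape[1], vs.shape[1]
    S = np.zeros((d_key, d_value))   # the entire "KV cache" is this fixed-size matrix
    outputs = []
    for q, k, v in zip(qs, ks, vs):
        S += np.outer(k, v)          # fold the new token into the state
        outputs.append(q @ S)        # read out with the query
    return np.array(outputs)
```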
Kimi Linear
Similar to Mamba and RWKV, Kimi Linear can be thought of as a (much) better RNN, and it does actually have linear FLOP complexity and constant memory complexity during inference. However, as said in the Kimi Linear paper, they use Kimi Linear on ¾ of layers and Multi-head Latent Attention (MLA, which is quadratic) on the remaining ¼ of layers. They say in the paper that when they tried using Kimi Linear on every layer, the hit to performance from doing this was too big.
Thus, Kimi Linear as done in practice reduces the FLOPs and memory used by the attention mechanism by a constant factor: only the fraction of layers that keep quadratic attention, ¼ in the paper’s case, still pays the quadratic cost (the reduction is smaller at shorter context lengths).
(Note on why the improvement in speed is 6.3x, which is bigger than 4x, at a context length of 1 million tokens: in addition to making attention faster by a factor of almost 4x at big context lengths, Kimi Linear makes the KV cache smaller by a factor of almost 4x, which allows bigger batch sizes (by a factor of almost 4x) and thus faster inference beyond the 4x improvement in attention FLOPs.)
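Spelling that reasoning out as numbers (same assumptions as the paragraph above; the linear layers’ cost is treated as negligible at 1M tokens, which is only roughly true):

```python
# Rough accounting for a hybrid with quadratic attention on 1/4 of layers
# (the Kimi Linear setup described above).
frac_quadratic_layers = 1 / 4

attention_flops_factor = 1 / frac_quadratic_layers   # ~4x fewer attention FLOPs
kv_cache_factor = 1 / frac_quadratic_layers          # ~4x smaller KV cache
# A smaller KV cache allows roughly proportionally bigger batch sizes, which also
# cuts the per-token cost of loading weights, so the end-to-end speedup can exceed
# 4x (the paper reports 6.3x at 1M tokens).
print(attention_flops_factor, kv_cache_factor)
```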
DeepSeek Sparse Attention (DSA)
DSA was introduced in the DeepSeek V3.2 paper, and DeepSeek V3.2, a frontier model, uses it. It works roughly as follows: for each query token, a cheap “lightning indexer” scores all previous tokens, the 2048 highest-scoring tokens are selected, and the (quadratic) MLA attention then attends only to those 2048 tokens.
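A minimal sketch of that two-stage structure. The real indexer uses its own small set of heads, FP8, and a ReLU-weighted score, none of which is modelled here; names and shapes are illustrative.

```python
import numpy as np

def dsa_like_attention(q, q_idx, K, V, K_idx, k_select=2048):
    """Two-stage sparse attention in the style described above.
    Stage 1 (still quadratic in context length, but cheap): an indexer scores
    every previous token. Stage 2 (linear at long context): full attention over
    only the top k_select tokens, i.e. min(n**2, k_select * n) work overall."""
    index_scores = K_idx @ q_idx                    # cheap score for all n previous tokens
    k_select = min(k_select, len(index_scores))
    top = np.argpartition(index_scores, -k_select)[-k_select:]
    scores = K[top] @ q / np.sqrt(q.shape[-1])      # expensive attention, but only over k_select tokens
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ V[top]
```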
Thus, DSA’s FLOP complexity has two components: the lightning indexer has (up to a constant factor) the same complexity as regular MLA (which is quadratic), and the subsequent MLA has linear complexity at big context lengths - min(context_length**2, 2048 * context_length).
So if the lightning indexer is in practice hugely cheaper than the subsequent MLA, the complexity is linear, but if it is only cheaper by a small constant factor, the complexity is still quadratic, just smaller by a small constant factor.
And the theoretical FLOP usage of the lightning indexer is only smaller by a factor of 8, so the complexity is still quadratic (at least in terms of theoretical FLOP usage). Here is the calculation that leads to 8: first, the lightning indexer’s n_heads * d_head is half that of the subsequent MLA. This is not written in the paper, but can be seen by inspecting the model’s config on HuggingFace. Then, the lightning indexer only has keys and queries, no values and outputs, so that’s another factor of 2. Finally, the lightning indexer is in FP8, not FP16, which is another factor of 2.
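The same estimate as explicit arithmetic (the three factors are exactly the ones listed above):

```python
# Factor-of-8 estimate for the lightning indexer vs. the subsequent MLA.
heads_times_dim = 2       # indexer's n_heads * d_head is half of MLA's
no_values_or_outputs = 2  # indexer computes only queries and keys
fp8_vs_fp16 = 2           # indexer runs in FP8 rather than FP16
print(heads_times_dim * no_values_or_outputs * fp8_vs_fp16)   # -> 8
```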
For prefill (prompt) tokens, this calculation matches DeepSeek’s empirical findings: figure 3 in the DeepSeek V3.2 paper shows that the slope of cost (in dollars) per token as a function of position in the sequence is about 8x smaller than for MLA at big context lengths. For decoding (output) tokens, the slope is about 20x smaller, not 8x, but this is still a constant-factor improvement. The improvements in per-token cost for the token at position 128k are 3.5x for prefill tokens and 9x for decoding tokens (if you look at the average token at context length 128k and not only at the last one, they go down to 3x and 7x). Note that in November 2025 (the latest date for which data is available as of writing this blogpost), OpenRouter processed 8x more prompt tokens than output tokens.
Furthermore, DSA does not reduce the KV cache size (because the 2048 tokens it attends to are different for every generated token and only known when that token is generated). This matters because an important way in which subquadratic attention helps (with capabilities) is by reducing KV cache size, which allows bigger batch sizes during inference (thus making inference cheaper) and allows longer context lengths by fitting the KV cache for more tokens per gigabyte of GPU memory.
Text Diffusion
Autoregressive LLMs (that is, all LLMs except for text diffusion LLMs) generate output tokens one by one in sequence, doing one forward pass per output token. A text diffusion LLM generates all the tokens at once in a single forward pass, but leaves X% of tokens blank. Then, it generates tokens in place of Y% of the blank tokens, also in a single forward pass. It repeats this a fixed number of times, after which no blank tokens remain.
Thus, while text diffusion eliminates the need for KV caches, it multiplies the FLOP usage on output tokens by a constant factor - the number of forward passes needed until no blank tokens remain.
(But wait, don’t autoregressive LLMs do one forward pass per output token, thus using more FLOPs than text diffusion models if the number of output tokens is big enough? No. Autoregressive LLMs do indeed do one forward pass per output token and thus usually do more forward passes than diffusion models. But they do each forward pass on only one token, whereas text diffusion LLMs do each forward pass on all the output tokens at once. Thus, each forward pass of a text diffusion LLM requires as many FLOPs as all the forward passes of an autoregressive LLM combined. Text diffusion LLMs can be more efficient than autoregressive models in practice because it is usually more efficient on GPUs to do one big operation than many small operations in sequence, even when both require the same number of FLOPs[1]. However, these efficiency improvements only help up to the point where inference becomes bottlenecked by FLOPs.)
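A toy FLOP-accounting sketch of that comparison. Purely illustrative: the 2 * n_params per-token FLOP heuristic, the 70B parameter count, and the 20-pass unmasking schedule are all assumptions, and attention’s own dependence on context length is ignored.

```python
# Toy FLOP accounting for generating n_out output tokens, treating the cost of
# a forward pass as flops_per_token * (number of token positions processed).
flops_per_token = 2 * 70e9        # ~2 * n_params, for a hypothetical 70B dense model (assumed)

n_out = 1000
autoregressive_flops = n_out * flops_per_token                   # n_out passes, 1 new token each
n_diffusion_passes = 20                                          # assumed fixed unmasking schedule
diffusion_flops = n_diffusion_passes * n_out * flops_per_token   # each pass touches all n_out positions

print(f"autoregressive: {autoregressive_flops:.2e} FLOPs ({n_out} small forward passes)")
print(f"text diffusion: {diffusion_flops:.2e} FLOPs ({n_diffusion_passes} big forward passes)")
# Diffusion uses more raw FLOPs here, but its passes are large batched matmuls,
# which GPUs can run much closer to peak throughput than 1-token-at-a-time decoding.
```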
[1] This last sentence is oversimplified - another thing that matters here is the shapes of the matrices that GPUs multiply. But this is out of the scope of this blogpost.