I was trying to understand the tensor product formulation in transformer circuits and I had basically forgotten all I ever knew about tensor products, if I ever knew anything. This very brief post is aimed at me from Wednesday 22nd when I didn't understand why that formulation of attention was true. It basically just gives a bit more background and includes a few more steps. I hope it will be helpful to someone else, too.
For understanding this, it is necessary to understand tensor products. Given two finite-dimensional vector spaces $V$ and $W$ we can construct the tensor product space $V \otimes W$ as the span of all matrices $v w^\top$, where $v \in V, w \in W$, with the property $\lambda\,(v w^\top) = (\lambda v)\, w^\top = v\, (\lambda w)^\top$. We can equivalently define it as a vector space with basis elements $e_i \otimes f_j$, where we used the basis elements $e_i$ of $V$ and $f_j$ of $W$ respectively.
But not only can we define tensor products between vectors, but also between linear maps that map from one vector space to another (i.e. matrices!):
Given two linear maps (matrices) $A: V \to V'$ and $B: W \to W'$ we can define $A \otimes B: V \otimes W \to V' \otimes W'$, where each map simply operates on its own vector space, not interacting with the other:

$(A \otimes B)(v \otimes w) = (Av) \otimes (Bw)$
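This factor-wise action can be checked numerically: for coordinate vectors, $v \otimes w$ is computed by numpy's Kronecker product `np.kron`. A minimal sketch (all matrices and vectors here are arbitrary random examples):

```python
import numpy as np

rng = np.random.default_rng(0)

# Two linear maps A: V -> V' and B: W -> W' as matrices.
A = rng.normal(size=(3, 2))   # maps R^2 -> R^3
B = rng.normal(size=(4, 5))   # maps R^5 -> R^4

v = rng.normal(size=2)
w = rng.normal(size=5)

# The tensor product of maps acts factor-wise on simple tensors:
# (A ⊗ B)(v ⊗ w) = (A v) ⊗ (B w).
lhs = np.kron(A, B) @ np.kron(v, w)
rhs = np.kron(A @ v, B @ w)
print(np.allclose(lhs, rhs))  # True
```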
For more information on the tensor product, I recommend this intuitive explanation and the Wikipedia entry.
How does this connect to the attention-only transformer?
In the "attention-only" formulation of the transformer we can write the "residual" of a fixed head as $r = A\, x\, W_V^\top W_O^\top$, with the values weight matrix $W_V$, the attention matrix $A$, the output weight matrix $W_O$, and the current embeddings at each position $x$.
Let $d$ be the embedding dimension, $n$ the total context length and $d_v$ the dimension of the values; then we have that
- $x$ is an $n \times d$ matrix,
- $W_V$ is a $d_v \times d$ matrix,
- $A$ is an $n \times n$ matrix, and
- $W_O$ is a $d \times d_v$ matrix
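Under the convention that $x$ stores one position per row, these shapes can be sanity-checked with a small numpy sketch (the sizes below are arbitrary toy values):

```python
import numpy as np

n, d, d_v = 6, 8, 4   # toy sizes: context length, embedding dim, value dim
rng = np.random.default_rng(0)

x   = rng.normal(size=(n, d))      # embeddings, one row per position
W_V = rng.normal(size=(d_v, d))    # values weight matrix
A   = rng.normal(size=(n, n))      # attention pattern
W_O = rng.normal(size=(d, d_v))    # output weight matrix

r = A @ x @ W_V.T @ W_O.T          # the residual as a plain matrix product
print(r.shape)  # (6, 8), i.e. n x d -- same shape as x, as a residual must be
```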
Let's identify the participating vector spaces:
$A$ maps from the "position" space back to the "position" space, which we will call $P$ (and which is isomorphic to $\mathbb{R}^n$). Similarly, we have the "embedding" space $E \cong \mathbb{R}^d$ and the "value" space $V \cong \mathbb{R}^{d_v}$.
It might become clear now that we can identify $x$ with an element from $P \otimes E$, i.e. that we can write $x = \sum_{i,j} x_{ij}\; p_i \otimes e_j$.
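A quick numerical illustration of this identification, with the standard basis playing the role of the $p_i$ and $e_j$ (a sketch, assuming $x$ is an $n \times d$ matrix):

```python
import numpy as np

n, d = 3, 4
rng = np.random.default_rng(0)
x = rng.normal(size=(n, d))

# Standard basis p_i of the position space P and e_j of the embedding space E.
p = np.eye(n)
e = np.eye(d)

# Reassemble x as sum_ij x_ij (p_i ⊗ e_j); for coordinate vectors the simple
# tensor p_i ⊗ e_j is just the rank-one matrix np.outer(p_i, e_j).
x_rebuilt = sum(x[i, j] * np.outer(p[i], e[j])
                for i in range(n) for j in range(d))
print(np.allclose(x, x_rebuilt))  # True
```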
From that lens, we can see that right-multiplying with $W_V^\top$ is equivalent to multiplying with $\mathrm{Id} \otimes W_V$, which maps an element from $P \otimes E$ to an element from $P \otimes V$, by applying $W_V$ to the $E$-part of the tensor:

$(\mathrm{Id} \otimes W_V)(p \otimes e) = p \otimes W_V e$
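Concretely, with $x$ flattened position-major into one long vector, the Kronecker product $\mathrm{Id} \otimes W_V$ reproduces the right-multiplication. A small numpy check (toy sizes, assuming the row-per-position convention for $x$):

```python
import numpy as np

n, d, d_v = 5, 3, 2
rng = np.random.default_rng(0)
x = rng.normal(size=(n, d))
W_V = rng.normal(size=(d_v, d))

# Right-multiplying x by W_V^T applies W_V to each position's embedding...
direct = x @ W_V.T                                   # shape (n, d_v)

# ...which is (Id ⊗ W_V) acting on x flattened row-major (position-major).
via_kron = (np.kron(np.eye(n), W_V) @ x.reshape(-1)).reshape(n, d_v)
print(np.allclose(direct, via_kron))  # True
```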
Identical arguments hold for $A$ and $W_O$, so that we get the formulation from the paper:

$r = (\mathrm{Id} \otimes W_O) \cdot (A \otimes \mathrm{Id}) \cdot (\mathrm{Id} \otimes W_V) \cdot x = (A \otimes W_O W_V) \cdot x$
Note that there is nothing special about this in terms of what these matrices represent. So it seems that a takeaway message is that whenever you have a matrix product of the form $A\, x\, B^\top$ you can re-write it as $(A \otimes B) \cdot x$, with $x$ flattened into a vector. (Sorry to everyone who thought that was blatantly obvious from the get-go ;P)
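This identity is easy to verify numerically if we take $x$ flattened row-major as the vector that the Kronecker product acts on (a sketch with arbitrary random matrices):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))
B = rng.normal(size=(3, 3))
x = rng.normal(size=(5, 3))

# A @ x @ B.T, rewritten with the Kronecker product acting on the
# row-major flattening of x.
lhs = A @ x @ B.T
rhs = (np.kron(A, B) @ x.reshape(-1)).reshape(5, 3)
print(np.allclose(lhs, rhs))  # True
```

(With the opposite, column-major flattening the same product would instead be written $(B \otimes A)\,\mathrm{vec}(x)$, which is the form of the identity usually quoted in linear algebra references.)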
A previous edition of this post said that it was the space of all such matrices, which is inaccurate. The span of a set of vectors/matrices is the space of all linear combinations of elements from that set. ↩︎
I'm limiting myself to finite-dim spaces because that's what is relevant to the transformer circuits paper. The actual formal definition is more general/stricter but imo doesn't add much to understanding the application in this paper ↩︎
Note that the 'linear map' that we use here is basically right-multiplying with $W_V^\top$, so that it maps $x \mapsto x\, W_V^\top$ ↩︎
I should note that this is also what is mentioned in the paper's introduction on tensor products, but it didn't click with me, whereas going through the above steps did. ↩︎
Can't say much about transformers, but the tensor product definition seems off. There can be many elements in V⊗W that aren't expressible as v⊗w, only as a linear combination of multiple such. That can be seen from dimensionality: if v and w have dimensions n and m, then the set of all such pairs has only n+m dimensions (like the Cartesian product), but the full tensor product has nm dimensions.
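One way to see this concretely: identifying v⊗w with the rank-one matrix formed by the outer product of v and w, a sum of two independent simple tensors has rank 2 and therefore cannot be any single v⊗w. A small numpy sketch:

```python
import numpy as np

# Identify v ⊗ w with the rank-one matrix np.outer(v, w).  The sum
# e1 ⊗ f1 + e2 ⊗ f2 of two "independent" simple tensors has rank 2,
# so it cannot be written as a single outer product v ⊗ w (those all
# have rank <= 1).
e1, e2 = np.eye(3)[0], np.eye(3)[1]
f1, f2 = np.eye(3)[0], np.eye(3)[1]

t = np.outer(e1, f1) + np.outer(e2, f2)
print(np.linalg.matrix_rank(t))  # 2
```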
Here's an explanation of tensor products that I came up with some time ago in an attempt to make it "click". Imagine you have a linear function that takes in two vectors and spits out a number. But wait, there are two natural but incompatible ways to imagine it:
1. f(a,b) + f(c,d) = f(a+c,b+d), linear in both arguments combined. The space of such functions has dimension n+m, and corresponds to the Cartesian product.
2. f(a,b) + f(a,c) = f(a,b+c) and also f(a,c) + f(b,c) = f(a+b,c), in other words, linear in each argument separately. The space of such functions has dimension nm, and corresponds to the tensor product.
It's especially simple to work through the case n=m=1. In that case all functions satisfying (1) have the form f(x,y)=ax+by, so their space is 2-dimensional, while all functions satisfying (2) have the form f(x,y)=axy, so their space is 1-dimensional. Admittedly this case is a bit funny because nm<n+m, but you can see how in higher dimensions the space of functions of type (2) becomes much bigger, because it will have terms for x1y1, x1y2, etc.
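The two notions can be sketched in code, using the standard fact that every function of type (2) (a bilinear form) can be written as f(x, y) = xᵀ M y for some matrix M; the specific values below are arbitrary:

```python
import numpy as np

# f(x, y) = x^T M y is the general bilinear form on R^2 x R^2; any matrix
# M gives one, so the space of such f has dimension 2*2 = 4 (the tensor
# product count), not 2+2.
M = np.array([[1.0, 2.0],
              [3.0, 4.0]])
f = lambda x, y: x @ M @ y

a, c = np.array([1.0, 0.0]), np.array([0.0, 1.0])
b, d = np.array([1.0, 0.0]), np.array([0.0, 1.0])

# Linear in each argument separately (type 2):
print(f(a, b) + f(a, d) == f(a, b + d))      # True  (1 + 2 == 3)
# But not additive in both arguments combined (type 1):
print(f(a, b) + f(c, d) == f(a + c, b + d))  # False (1 + 4 != 10)
```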
Ah yes that makes sense to me. I'll modify the post accordingly and probably write it in the basis formulation.
ETA: Fixed now; the computation takes a tiny bit longer but is hopefully still readable to everyone.