I was trying to understand the tensor product formulation in transformer circuits and I had basically forgotten all I ever knew about tensor products, if I ever knew anything. This very brief post is aimed at me from Wednesday 22nd when I didn't understand why that formulation of attention was true. It basically just gives a bit more background and includes a few more steps. I hope it will be helpful to someone else, too.
For understanding this, it is necessary to understand tensor products. Given two finite-dimensional vector spaces $V, W$ we can construct the tensor product space $V \otimes W$ as the span of all matrices $v \otimes w$, where $v \in V, w \in W$, with the property $(v \otimes w)_{ij} = v_i w_j$. We can equivalently define it as a vector space with basis elements $e_i^V \otimes e_j^W$, where we used the basis elements of $V$ and $W$ respectively.
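To make the "span of outer products" picture concrete, here is a minimal NumPy sketch. The vectors and their dimensions are made up for illustration; the only thing it relies on is that the elementary tensor of two vectors can be represented as their outer product.

```python
import numpy as np

# Hypothetical example vectors; any v in R^2 and w in R^3 would do.
v = np.array([1.0, 2.0])
w = np.array([3.0, 4.0, 5.0])

# The elementary tensor v ⊗ w as a matrix with entries v_i * w_j.
vw = np.outer(v, w)

# Basis elements e_i ⊗ e_j: matrices with a single 1 at position (i, j).
basis = [np.outer(np.eye(2)[i], np.eye(3)[j]) for i in range(2) for j in range(3)]

# v ⊗ w is a linear combination of the basis elements, with coefficients v_i * w_j.
recombined = sum(v[i] * w[j] * basis[i * 3 + j] for i in range(2) for j in range(3))

# A general element of the span need not be a single outer product:
# e_0 ⊗ e_0 + e_1 ⊗ e_1 has rank 2, whereas every v ⊗ w has rank at most 1.
not_elementary = basis[0] + basis[4]
```

The last two lines are why the definition says "span": sums of elementary tensors are in the space even when they are not themselves of the form "vector ⊗ vector".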
We can define tensor products not only between vectors but also between linear maps from one vector space to another (i.e. matrices!):
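Given two linear maps (matrices) $F: V \to V'$ and $G: W \to W'$ we can define $F \otimes G: V \otimes W \to V' \otimes W'$, where each map simply operates on its own vector space, not interacting with the other:

$$(F \otimes G)(v \otimes w) = F(v) \otimes G(w)$$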
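As a sanity check, the defining property of the tensor product of two maps can be tested numerically. This is only a sketch: the maps `F` and `G`, their shapes, and the random inputs are all hypothetical, and I am assuming the usual identification of the tensor product of two matrices with their Kronecker product (`np.kron`) acting on row-major-flattened tensors.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical maps F: R^2 -> R^3 and G: R^4 -> R^5, acting from the left.
F = rng.normal(size=(3, 2))
G = rng.normal(size=(5, 4))
v = rng.normal(size=2)
w = rng.normal(size=4)

# np.kron(F, G) is the matrix of F ⊗ G acting on row-major-flattened tensors.
F_tensor_G = np.kron(F, G)

lhs = F_tensor_G @ np.outer(v, w).reshape(-1)  # (F ⊗ G)(v ⊗ w)
rhs = np.outer(F @ v, G @ w).reshape(-1)       # F(v) ⊗ G(w)

print(np.allclose(lhs, rhs))
```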
How does this connect to the attention-only transformer?
In the "attention-only" formulation of the transformer we can write the "residual" of a fixed head as $A X W_V W_O$, with the values weight matrix $W_V$, the attention matrix $A$, the output weight matrix $W_O$, and the current embeddings at each position $X$.
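Let $d_{\text{emb}}$ be the embedding dimension, $n_{\text{ctx}}$ the total context length and $d_{\text{val}}$ the dimension of the values, then we have that
- the embeddings $X$ are an $n_{\text{ctx}} \times d_{\text{emb}}$ matrix,
- the attention matrix $A$ is an $n_{\text{ctx}} \times n_{\text{ctx}}$ matrix,
- the values weight matrix $W_V$ is a $d_{\text{emb}} \times d_{\text{val}}$ matrix, and
- the output weight matrix $W_O$ is a $d_{\text{val}} \times d_{\text{emb}}$ matrix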
Let's identify the participating vector spaces:
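$A$ maps from the "position" space back to the "position" space, which we will call $\mathcal{P}$ (and which is isomorphic to $\mathbb{R}^{n_{\text{ctx}}}$). Similarly, we have the "embedding" space $\mathcal{E} \cong \mathbb{R}^{d_{\text{emb}}}$ and the "value" space $\mathcal{V} \cong \mathbb{R}^{d_{\text{val}}}$.
It might become clear now that we can identify $X$ with an element from $\mathcal{P} \otimes \mathcal{E}$, i.e. that we can write $X = \sum_{i,j} X_{ij}\, e_i^{\mathcal{P}} \otimes e_j^{\mathcal{E}}$, with $e_i^{\mathcal{P}}$ and $e_j^{\mathcal{E}}$ the basis vectors of the position and embedding spaces.
Through that lens, we can see that right-multiplying $X$ with $W_V$ is equivalent to multiplying with $\mathrm{Id} \otimes W_V$, which maps an element from $\mathcal{P} \otimes \mathcal{E}$ to an element from $\mathcal{P} \otimes \mathcal{V}$, by applying $W_V$ to the $\mathcal{E}$-part of the tensor $X$:

$$(\mathrm{Id} \otimes W_V)\, X = \sum_{i,j} X_{ij}\, e_i^{\mathcal{P}} \otimes \left(e_j^{\mathcal{E}} W_V\right) = X W_V$$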
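Here is a small sketch of that single step, with made-up sizes and random matrices (the variable names `n_ctx`, `d_emb`, `d_val` are my own labels for context length, embedding dimension and value dimension). It writes the embedding matrix as a sum of "position ⊗ embedding" tensors and then applies the values matrix to the embedding factor only.

```python
import numpy as np

rng = np.random.default_rng(0)
n_ctx, d_emb, d_val = 4, 6, 3          # hypothetical small sizes

X = rng.normal(size=(n_ctx, d_emb))    # one embedding (row) per position
W_V = rng.normal(size=(d_emb, d_val))

# X as a sum of elementary tensors:
# (position basis vector) ⊗ (embedding at that position).
X_as_tensor = sum(np.outer(np.eye(n_ctx)[i], X[i]) for i in range(n_ctx))

# Id ⊗ W_V leaves the position factor alone and sends each embedding e to e @ W_V.
mapped = sum(np.outer(np.eye(n_ctx)[i], X[i] @ W_V) for i in range(n_ctx))

print(np.allclose(X_as_tensor, X), np.allclose(mapped, X @ W_V))
```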
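Identical arguments hold for $A$ and $W_O$, so that we get the formulation from the paper:

$$A X W_V W_O = (\mathrm{Id} \otimes W_O)\,(A \otimes \mathrm{Id})\,(\mathrm{Id} \otimes W_V)\, X = (A \otimes W_O W_V)\, X$$

Here $W_O W_V$ is read as a composition of maps (first $W_V$, then $W_O$); since both act by right-multiplication in this post, that composite is right-multiplication by the matrix $W_V W_O$.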
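The whole chain can be checked numerically. This is a sketch under assumptions: the sizes and random matrices are hypothetical (a real attention matrix would have softmaxed rows), and I represent each tensor-product map as a Kronecker product acting from the left on the row-major flattening of the embeddings, which is why the weight matrices show up transposed.

```python
import numpy as np

rng = np.random.default_rng(0)
n_ctx, d_emb, d_val = 4, 6, 3           # hypothetical small sizes

X = rng.normal(size=(n_ctx, d_emb))     # embeddings, one row per position
A = rng.normal(size=(n_ctx, n_ctx))     # stand-in for the attention matrix
W_V = rng.normal(size=(d_emb, d_val))
W_O = rng.normal(size=(d_val, d_emb))

I_pos = np.eye(n_ctx)

# np.kron builds matrices acting from the left on the row-major flattening of X,
# so a map that right-multiplies by M appears as M.T in the second factor.
step_V = np.kron(I_pos, W_V.T)          # Id ⊗ W_V : P⊗E -> P⊗V
step_A = np.kron(A, np.eye(d_val))      # A ⊗ Id   : P⊗V -> P⊗V
step_O = np.kron(I_pos, W_O.T)          # Id ⊗ W_O : P⊗V -> P⊗E

x = X.reshape(-1)
factored = step_O @ step_A @ step_V @ x           # three separate steps
combined = np.kron(A, (W_V @ W_O).T) @ x          # A ⊗ (W_V then W_O)
direct = (A @ X @ W_V @ W_O).reshape(-1)          # plain matrix product

print(np.allclose(factored, direct), np.allclose(combined, direct))
```

The factored form makes the point of the notation visible: the attention matrix only ever touches the position factor, and the weight matrices only ever touch the embedding/value factor.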
Note that there is nothing special about this in terms of what these matrices represent. So it seems that a takeaway message is that whenever you have a matrix product of the form $ABC$ you can re-write it as $(A \otimes C)\,B$, with $C$ acting by right-multiplication (Sorry to everyone who thought that was blatantly obvious from the get-go ;P).
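The general version of the takeaway, as a sketch: `triple_product_via_kron` is a hypothetical helper name, and the shapes are arbitrary. If you prefer all maps acting from the left on a row-major flattening, the right-hand factor shows up as its transpose, which is what `np.kron(A, C.T)` does below.

```python
import numpy as np

def triple_product_via_kron(A, B, C):
    """Compute A @ B @ C by applying (A ⊗ C) to B, with C acting by right-multiplication."""
    rows, cols = A.shape[0], C.shape[1]
    # Acting from the left on the row-major flattening of B, the map is kron(A, C.T).
    return (np.kron(A, C.T) @ B.reshape(-1)).reshape(rows, cols)

rng = np.random.default_rng(1)
A = rng.normal(size=(2, 3))
B = rng.normal(size=(3, 4))
C = rng.normal(size=(4, 5))

result = triple_product_via_kron(A, B, C)
print(np.allclose(result, A @ B @ C))
```

This is for building intuition only: the Kronecker matrix is much larger than the original factors, so nobody would compute the product this way in practice.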
A previous edition of this post said that it was the space of all such matrices, which is inaccurate. The span of a set of vectors/matrices is the space of all linear combinations of elements from that set. ↩︎
I'm limiting myself to finite-dim spaces because that's what is relevant to the transformer circuits paper. The actual formal definition is more general/stricter but imo doesn't add much to understanding the application in this paper ↩︎
Note that the 'linear map' that we use here is basically right multiplying with $W_V$, so that it maps $e \mapsto e\,W_V$ ↩︎