These are just some notes I wrote while reading about transformers which I thought might be a useful reference to others. Thanks to Aryan Bhatt for a correction to the attention normalization. Further corrections welcome.
Overview of Transformers
Many transformer models have the following architecture, with data flowing as follows:
- We take tokens as inputs and pass them through an embedding layer. The embedding layer writes its output into the residual stream ($x_0$), which has shape $(C, E)$, where $C$ is the number of tokens in the context window and $E$ is the embedding dimension.
- The residual stream is processed by the attention mechanism ($H$) and the result is added back into the residual stream (i.e. $x_1 = H(x_0) + x_0$).
- The residual stream is then processed by an MLP layer and the result is added back into the residual stream (i.e. $x_2 = \mathrm{MLP}(x_1) + x_1$).
- The attention and MLP steps together define a “residual block”. The body of the transformer is formed of a stack of these blocks in series.
- After the final residual block, we apply an unembedding transformation to produce logits, which represent the relative probabilities of different output tokens.
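The data flow above can be sketched in a few lines of NumPy. All weights here are random stand-ins, and the attention and MLP blocks are collapsed to single matrices purely to show the shapes and the residual additions:

```python
import numpy as np

# Toy dimensions: C tokens in the context, embedding dim E, vocab size V.
C, E, V = 4, 8, 50
rng = np.random.default_rng(0)
W_attn = 0.1 * rng.standard_normal((E, E))   # stand-in for the attention block
W_mlp  = 0.1 * rng.standard_normal((E, E))   # stand-in for the MLP block
W_U    = rng.standard_normal((E, V))         # unembedding matrix

x0 = rng.standard_normal((C, E))   # residual stream after embedding
x1 = x0 @ W_attn + x0              # x1 = H(x0) + x0
x2 = np.tanh(x1 @ W_mlp) + x1      # x2 = MLP(x1) + x1
logits = x2 @ W_U                  # shape (C, V): one logit vector per position
print(logits.shape)
```

Each residual block reads from and writes back into the same $(C, E)$ stream, which is why the blocks can be stacked without changing any shapes.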
The attention mechanism ($H$) is divided into multiple attention heads, which act in parallel. That is,

$$H(x) = \sum_i h_i(x),$$

where $h_i$ is the $i$-th attention head.
Note that this decomposition is only useful if attention heads are non-linear. Fortunately they are! Each attention head is of the form

$$h(x) = A(x)\, x\, W_{OV}.$$
That is, $A(x)$ mixes across tokens (which is the first index of $x$) and $W_{OV}$ transforms each token in parallel. Another way we could have written this is

$$h(x)_{te} = \sum_{t'} \sum_{e'} A(x)_{tt'}\, x_{t'e'}\, (W_{OV})_{e'e}.$$
The matrix $W_{OV}$ is also written in more common notation as $W_V W_O$, where $W_V$ and $W_O$ are sometimes called the value and output weights. In general, though, $W_{OV}$ is just some low-rank matrix that we learn. $W_{OV}$ has shape $(E, E)$ because it transforms in the embedding space.
The matrix $A(x)$ is where the nonlinearity of attention comes in. It is given by

$$A(x) = \mathrm{softmax}\!\left(\frac{x\, W_{QK}\, x^T}{\sqrt{d_k}}\right),$$
where $W_{QK}$ is written in more common notation as $W_Q W_K^T$, and $W_Q$ and $W_K$ are sometimes called the query and key weights. The dimension $d_k$ is the dimension of the output of $W_K$ (and $W_Q$), and so is the rank of $W_{QK}$. As with $W_{OV}$, $W_{QK}$ is just some low-rank matrix that we learn. The softmax acts row-wise (i.e. over the second token index), so each token's attention weights sum to one.
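Putting these pieces together, a single attention head can be sketched as follows. The weights are random, and the names `W_Q`, `W_K`, `W_V`, `W_O`, and `softmax_rows` are my own, chosen to match the notation above:

```python
import numpy as np

# A single attention head, using the low-rank factorizations
# W_QK = W_Q @ W_K.T and W_OV = W_V @ W_O.
C, E, d_k = 4, 8, 2
rng = np.random.default_rng(1)
W_Q = rng.standard_normal((E, d_k))
W_K = rng.standard_normal((E, d_k))
W_V = rng.standard_normal((E, d_k))
W_O = rng.standard_normal((d_k, E))
W_QK = W_Q @ W_K.T   # shape (E, E), rank d_k
W_OV = W_V @ W_O     # shape (E, E), rank d_k

def softmax_rows(m):
    m = m - m.max(axis=-1, keepdims=True)  # subtract max for stability
    e = np.exp(m)
    return e / e.sum(axis=-1, keepdims=True)

def head(x):
    A = softmax_rows(x @ W_QK @ x.T / np.sqrt(d_k))  # (C, C) mixing matrix
    return A @ x @ W_OV   # mix across tokens, then transform each token

x = rng.standard_normal((C, E))
print(head(x).shape)  # (4, 8)
```

Note that the only place two different tokens interact is the $(C, C)$ matrix $A(x)$; the $W_{OV}$ factor acts on each row independently.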
The MLP (multilayer perceptron) layer processes the residual stream using the same MLP for each token index. That is, there is no communication between tokens in the MLP layer. All this layer does is transform in the embedding space.
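To make the "no communication between tokens" point concrete, here is a one-hidden-layer MLP (random weights; the ReLU and hidden width are assumptions for illustration) applied to the whole residual stream, checked against applying it to each token separately:

```python
import numpy as np

# The MLP acts on each token independently: applying it to the whole
# residual stream equals applying it row by row.
C, E, hidden = 4, 8, 16
rng = np.random.default_rng(2)
W_in  = rng.standard_normal((E, hidden))
W_out = rng.standard_normal((hidden, E))

def mlp(x):
    return np.maximum(x @ W_in, 0.0) @ W_out  # one hidden layer with ReLU

x = rng.standard_normal((C, E))
per_token = np.stack([mlp(x[t]) for t in range(C)])
assert np.allclose(mlp(x), per_token)  # no communication between tokens
```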
A quirk of the attention mechanism is that it is covariant with respect to shuffling the token index. That is, if $P$ is a permutation matrix then

$$h(Px) = P\, h(x).$$
To see this, we expand the left-hand side:

$$h(Px) = \mathrm{softmax}\!\left(\frac{P x\, W_{QK}\, x^T P^T}{\sqrt{d_k}}\right) P x\, W_{OV}.$$
The permutations don't change any of the values inside the softmax, so they can be pulled outside:

$$h(Px) = P\, \mathrm{softmax}\!\left(\frac{x\, W_{QK}\, x^T}{\sqrt{d_k}}\right) P^T P\, x\, W_{OV}.$$
The transpose of a permutation matrix is its inverse, so $P^T P = I$ and

$$h(Px) = P\, \mathrm{softmax}\!\left(\frac{x\, W_{QK}\, x^T}{\sqrt{d_k}}\right) x\, W_{OV} = P\, h(x).$$
Similarly, the MLP layer acts on each token individually and so doesn’t know anything about their orderings.
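The covariance can also be checked numerically. This sketch builds an attention head from random low-rank weights and verifies $h(Px) = P\,h(x)$ for a random permutation:

```python
import numpy as np

# Numerical check of h(Px) = P h(x) for a random attention head.
C, E, d_k = 4, 8, 2
rng = np.random.default_rng(3)
W_QK = rng.standard_normal((E, d_k)) @ rng.standard_normal((d_k, E))
W_OV = rng.standard_normal((E, d_k)) @ rng.standard_normal((d_k, E))

def softmax_rows(m):
    e = np.exp(m - m.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def head(x):
    return softmax_rows(x @ W_QK @ x.T / np.sqrt(d_k)) @ x @ W_OV

x = rng.standard_normal((C, E))
P = np.eye(C)[rng.permutation(C)]   # random permutation matrix
assert np.allclose(head(P @ x), P @ head(x))  # covariance under shuffling
```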
What this means is that there is no information about token ordering in the transformer unless we put it there in the embedding space. This is what positional encodings do.
A typical positional encoding is given by adding a position-dependent vector $p_t$ to the embedding of each token. A common choice is

$$p_{t,2i} = \sin\!\left(\frac{t}{N^{2i/E}}\right), \qquad p_{t,2i+1} = \cos\!\left(\frac{t}{N^{2i/E}}\right),$$
where $t$ is the token index in the context window and $i$ indexes pairs of dimensions in the embedding space. Here $N$ is chosen large enough (a common value is $N = 10^4$) that this is not a periodic function of $t$ within the context window. The reason this choice is common is that there is a linear transformation for shifting $t \to t + 1$, identical across all $t$, which makes it easy for models to learn to compare adjacent tokens. If this is not apparent, note that each pair of components offset by one gives a representation of a complex number $e^{\mathrm{i} t / N^{2i/E}}$, and we can increment $t$ by multiplying by a diagonal operator $\mathrm{diag}\!\left(e^{\mathrm{i} / N^{2i/E}}\right)$ which is the same for all $t$.
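Here is a short NumPy sketch of this encoding (with the conventional base $N = 10^4$), together with the shift operator: in real coordinates the diagonal complex multiplication becomes a single block-rotation matrix $R$, independent of $t$, that maps the encoding of position $t$ to that of position $t+1$:

```python
import numpy as np

# Sinusoidal positional encoding: p[t, 2i] = sin(t / N**(2i/E)),
# p[t, 2i+1] = cos(t / N**(2i/E)).
C, E, N = 16, 8, 10_000
i = np.arange(E // 2)
freq = 1.0 / N ** (2 * i / E)        # one angular frequency per (sin, cos) pair
t = np.arange(C)[:, None]
p = np.zeros((C, E))
p[:, 0::2] = np.sin(t * freq)
p[:, 1::2] = np.cos(t * freq)

# Shift operator: rotate each (sin, cos) pair by its own frequency.
R = np.zeros((E, E))
for j, w in enumerate(freq):
    c, s = np.cos(w), np.sin(w)
    R[2*j:2*j+2, 2*j:2*j+2] = [[c, -s], [s, c]]

assert np.allclose(p[:-1] @ R, p[1:])   # p(t) @ R == p(t + 1) for every t
```

The assertion confirms the key property: one fixed linear map advances every position by one step, regardless of $t$.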