These are just some notes I wrote while reading about transformers which I thought might be a useful reference to others. Thanks to Aryan Bhatt for a correction to the attention normalization. Further corrections welcome.
Overview of Transformers
Many transformer models have the following architecture, with data flowing as follows:
- We take tokens as inputs and pass them through an embedding layer. The embedding layer writes its output into the residual stream ($x_0$), which has shape $(C, E)$, where $C$ is the number of tokens in the context window and $E$ is the embedding dimension.
- The residual stream is processed by the attention mechanism ($H$) and the result is added back into the residual stream (i.e. $x_1 = H(x_0) + x_0$).
- The residual stream is then processed by an MLP layer and the result is added back into the residual stream (i.e. $x_2 = \mathrm{MLP}(x_1) + x_1$).
- The attention and MLP steps together define a “residual block”. The body of the transformer is formed of a stack of these blocks in series.
- After the final residual block, we apply an unembedding transformation to produce logits, which represent the relative probabilities of different output tokens.
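The data flow above can be sketched in a few lines of NumPy. All weights here are random stand-ins, and the attention and MLP blocks are collapsed to single matrices purely to show the shapes and the residual additions:

```python
import numpy as np

# Toy dimensions: C tokens in the context, embedding dim E, vocab size V.
C, E, V = 4, 8, 50
rng = np.random.default_rng(0)
W_attn = 0.1 * rng.standard_normal((E, E))   # stand-in for the attention block
W_mlp  = 0.1 * rng.standard_normal((E, E))   # stand-in for the MLP block
W_U    = rng.standard_normal((E, V))         # unembedding matrix

x0 = rng.standard_normal((C, E))   # residual stream after embedding
x1 = x0 @ W_attn + x0              # x1 = H(x0) + x0
x2 = np.tanh(x1 @ W_mlp) + x1      # x2 = MLP(x1) + x1
logits = x2 @ W_U                  # shape (C, V): one logit vector per position
print(logits.shape)
```

Each residual block reads from and writes back into the same $(C, E)$ stream, which is why the blocks can be stacked without changing any shapes.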
The attention mechanism ($H$) is divided into multiple attention heads, which act in parallel. That is,

$$H(x) = \sum_i h_i(x),$$

where $h_i$ is the $i$-th attention head.
Note that this decomposition is only useful if attention heads are non-linear. Fortunately they are! Each attention head is of the form

$$h(x) = A(x)\, x\, W_{OV}.$$
That is, $A(x)$ mixes across tokens (which is the first index of $x$) and $W_{OV}$ transforms each token in parallel. Another way we could have written this is

$$h(x)_{te} = \sum_{t'} \sum_{e'} A(x)_{tt'}\, x_{t'e'}\, (W_{OV})_{e'e}.$$
The matrix $W_{OV}$ is also written in more common notation as $W_V W_O$, where $W_V$ and $W_O$ are sometimes called the value and output weights. In general, though, $W_{OV}$ is just some low-rank matrix that we learn. $W_{OV}$ has shape $(E, E)$ because it transforms in the embedding space.
The matrix $A(x)$ is where the nonlinearity of attention comes in. It is given by

$$A(x) = \mathrm{softmax}\!\left(\frac{x\, W_{QK}\, x^T}{\sqrt{d_k}}\right),$$
where $W_{QK}$ is written in more common notation as $W_Q W_K^T$, and $W_Q$ and $W_K$ are sometimes called the query and key weights. The dimension $d_k$ is the dimension of the output of $W_K$ (and $W_Q$), and so is the rank of $W_{QK}$. As with $W_{OV}$, $W_{QK}$ is just some low-rank matrix that we learn. The softmax acts row-wise (i.e. over the second token index), so each token's attention weights sum to one.
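Putting these pieces together, a single attention head can be sketched as follows. The weights are random, and the names `W_Q`, `W_K`, `W_V`, `W_O`, and `softmax_rows` are my own, chosen to match the notation above:

```python
import numpy as np

# A single attention head, using the low-rank factorizations
# W_QK = W_Q @ W_K.T and W_OV = W_V @ W_O.
C, E, d_k = 4, 8, 2
rng = np.random.default_rng(1)
W_Q = rng.standard_normal((E, d_k))
W_K = rng.standard_normal((E, d_k))
W_V = rng.standard_normal((E, d_k))
W_O = rng.standard_normal((d_k, E))
W_QK = W_Q @ W_K.T   # shape (E, E), rank d_k
W_OV = W_V @ W_O     # shape (E, E), rank d_k

def softmax_rows(m):
    m = m - m.max(axis=-1, keepdims=True)  # subtract max for stability
    e = np.exp(m)
    return e / e.sum(axis=-1, keepdims=True)

def head(x):
    A = softmax_rows(x @ W_QK @ x.T / np.sqrt(d_k))  # (C, C) mixing matrix
    return A @ x @ W_OV   # mix across tokens, then transform each token

x = rng.standard_normal((C, E))
print(head(x).shape)  # (4, 8)
```

Note that the only place two different tokens interact is the $(C, C)$ matrix $A(x)$; the $W_{OV}$ factor acts on each row independently.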
The MLP (multilayer perceptron) layer processes the residual stream using the same MLP for each token index. That is, there is no communication between tokens in the MLP layer. All this layer does is transform in the embedding space.
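To make the "no communication between tokens" point concrete, here is a one-hidden-layer MLP (random weights; the ReLU and hidden width are assumptions for illustration) applied to the whole residual stream, checked against applying it to each token separately:

```python
import numpy as np

# The MLP acts on each token independently: applying it to the whole
# residual stream equals applying it row by row.
C, E, hidden = 4, 8, 16
rng = np.random.default_rng(2)
W_in  = rng.standard_normal((E, hidden))
W_out = rng.standard_normal((hidden, E))

def mlp(x):
    return np.maximum(x @ W_in, 0.0) @ W_out  # one hidden layer with ReLU

x = rng.standard_normal((C, E))
per_token = np.stack([mlp(x[t]) for t in range(C)])
assert np.allclose(mlp(x), per_token)  # no communication between tokens
```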
A quirk of the attention mechanism is that it is covariant with respect to shuffling the token index. That is, if $P$ is a permutation matrix then

$$h(Px) = P\, h(x).$$
To see this, we expand the left-hand side:

$$h(Px) = \mathrm{softmax}\!\left(\frac{P x\, W_{QK}\, x^T P^T}{\sqrt{d_k}}\right) P x\, W_{OV}.$$
The permutations don't change any of the values inside the softmax, so they can be pulled outside:

$$h(Px) = P\, \mathrm{softmax}\!\left(\frac{x\, W_{QK}\, x^T}{\sqrt{d_k}}\right) P^T P\, x\, W_{OV}.$$
The transpose of a permutation matrix is its inverse, so $P^T P = I$ and

$$h(Px) = P\, \mathrm{softmax}\!\left(\frac{x\, W_{QK}\, x^T}{\sqrt{d_k}}\right) x\, W_{OV} = P\, h(x).$$
Similarly, the MLP layer acts on each token individually and so doesn’t know anything about their orderings.
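The covariance can also be checked numerically. This sketch builds an attention head from random low-rank weights and verifies $h(Px) = P\,h(x)$ for a random permutation:

```python
import numpy as np

# Numerical check of h(Px) = P h(x) for a random attention head.
C, E, d_k = 4, 8, 2
rng = np.random.default_rng(3)
W_QK = rng.standard_normal((E, d_k)) @ rng.standard_normal((d_k, E))
W_OV = rng.standard_normal((E, d_k)) @ rng.standard_normal((d_k, E))

def softmax_rows(m):
    e = np.exp(m - m.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def head(x):
    return softmax_rows(x @ W_QK @ x.T / np.sqrt(d_k)) @ x @ W_OV

x = rng.standard_normal((C, E))
P = np.eye(C)[rng.permutation(C)]   # random permutation matrix
assert np.allclose(head(P @ x), P @ head(x))  # covariance under shuffling
```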
What this means is that there is no information about token ordering in the transformer unless we put it there in the embedding space. This is what positional encodings do.
A typical positional encoding is given by adding a position-dependent vector $p_t$ to the embedding of each token. A common choice is

$$p_{t,2i} = \sin\!\left(\frac{t}{N^{2i/E}}\right), \qquad p_{t,2i+1} = \cos\!\left(\frac{t}{N^{2i/E}}\right),$$
where $t$ is the token index in the context window and $i$ indexes pairs of dimensions in the embedding space. Here $N$ is chosen large enough (a common value is $N = 10^4$) that this is not a periodic function of $t$ within the context window. The reason this choice is common is that there is a linear transformation for shifting $t \to t + 1$, identical across all $t$, which makes it easy for models to learn to compare adjacent tokens. If this is not apparent, note that each pair of components offset by one gives a representation of a complex number $e^{\mathrm{i} t / N^{2i/E}}$, and we can increment $t$ by multiplying by a diagonal operator $\mathrm{diag}\!\left(e^{\mathrm{i} / N^{2i/E}}\right)$ which is the same for all $t$.
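Here is a short NumPy sketch of this encoding (with the conventional base $N = 10^4$), together with the shift operator: in real coordinates the diagonal complex multiplication becomes a single block-rotation matrix $R$, independent of $t$, that maps the encoding of position $t$ to that of position $t+1$:

```python
import numpy as np

# Sinusoidal positional encoding: p[t, 2i] = sin(t / N**(2i/E)),
# p[t, 2i+1] = cos(t / N**(2i/E)).
C, E, N = 16, 8, 10_000
i = np.arange(E // 2)
freq = 1.0 / N ** (2 * i / E)        # one angular frequency per (sin, cos) pair
t = np.arange(C)[:, None]
p = np.zeros((C, E))
p[:, 0::2] = np.sin(t * freq)
p[:, 1::2] = np.cos(t * freq)

# Shift operator: rotate each (sin, cos) pair by its own frequency.
R = np.zeros((E, E))
for j, w in enumerate(freq):
    c, s = np.cos(w), np.sin(w)
    R[2*j:2*j+2, 2*j:2*j+2] = [[c, -s], [s, c]]

assert np.allclose(p[:-1] @ R, p[1:])   # p(t) @ R == p(t + 1) for every t
```

The assertion confirms the key property: one fixed linear map advances every position by one step, regardless of $t$.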