Positional kernels of attention heads

by Alex Gibson
10th Mar 2025

Introduction:


In this post, I analyze attention heads whose attention patterns are spread out and whose attention scores depend only weakly on content. Using concentration inequalities, I argue that the softmax denominators of these heads are stable when the underlying token distribution is fixed. By sampling softmax denominators from a single "calibration text", the outputs of multiple such heads can be combined and used to identify neurons in the first layer of GPT-2 Small that are sensitive to high-level, human-interpretable properties of the surrounding text.

Decomposition of attention patterns:

To handle LayerNorm, approximate the keys of the input at position i for head h on input x as: 

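$$\frac{\mathrm{key}[i](x)}{\sqrt{d_{\text{model}}}} = \frac{W_E[x_i]W_K + W_{pos}[i]W_K}{|W_E[x_i] + W_{pos}[i]|} \approx \left(\frac{W_E[x_i]W_K + W_{pos}[n]W_K}{|W_E[x_i] + W_{pos}[n]|}\right) + \left(\frac{W_{pos}[i]W_K - W_{pos}[n]W_K}{\sqrt{|W_{pos}[i]|^2 + 3.5^2}}\right) = \frac{E[n, x_i] + P[n, i]}{\sqrt{d_{\text{model}}}}$$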

This approximation writes the keys as a cleanly separated sum of content and position, whereas LayerNorm mixes content and position together nonlinearly.
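A minimal numpy sketch of this decomposition, assuming a simplified LayerNorm $x \mapsto \sqrt{d_{\text{model}}}\,x/|x|$ (no mean-centering, gain or bias), which is the form the formula above takes. The weights are random stand-ins for one head's $W_E$, $W_{pos}$ and $W_K$, and the function names are hypothetical. It checks two properties: the approximation is exact at $i = n$, and the approximate key splits into a token-only part plus a position-only part.

```python
import numpy as np

def exact_key(W_E, W_pos, W_K, token, i):
    # key[i](x) / sqrt(d_model), with LayerNorm simplified to x -> sqrt(d_model) * x / |x|
    x = W_E[token] + W_pos[i]
    return x @ W_K / np.linalg.norm(x)

def approx_key(W_E, W_pos, W_K, token, i, n, emb_norm=3.5):
    # Content term E[n, x_i] / sqrt(d_model): depends on the token and the query position n only
    anchored = W_E[token] + W_pos[n]
    content = anchored @ W_K / np.linalg.norm(anchored)
    # Position term P[n, i] / sqrt(d_model): depends on positions only;
    # emb_norm plays the role of the 3.5 in the formula above
    scale = np.sqrt(np.linalg.norm(W_pos[i]) ** 2 + emb_norm ** 2)
    position = (W_pos[i] - W_pos[n]) @ W_K / scale
    return content + position

# Random stand-ins for the real GPT-2 weights (hypothetical sizes)
rng = np.random.default_rng(0)
d_model, d_head, vocab, n_ctx = 16, 4, 10, 8
W_E = rng.normal(size=(vocab, d_model))
W_pos = 0.3 * rng.normal(size=(n_ctx, d_model))
W_K = rng.normal(size=(d_model, d_head))

n = n_ctx - 1  # query position
# At i = n the position term is exactly zero, so the two keys coincide
print(np.allclose(exact_key(W_E, W_pos, W_K, 3, n), approx_key(W_E, W_pos, W_K, 3, n, n)))  # True
```

Because the content term never looks at $i$, swapping one token for another shifts the approximate key by the same vector at every key position, which is what makes the later content/position factorization possible.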

When these approximate keys are substituted in, the resulting approximate attention pattern is close to the true attention pattern: the total variation (TV) distance between them is typically about 0.05. In other words, moving about 5% of the attention mass is enough to turn one pattern into the other.

Figure: TV distance between the true and reconstructed attention patterns across sequence positions for the 6 first-layer attention heads of GPT-2 Small analyzed in this post. Results are shown for a representative text from OpenWebText. The approximation maintains a low TV distance (typically ∼0.05) across all heads shown, with similar performance observed across all tested texts.
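The TV distance between two distributions is half the $L^1$ distance between them, which equals the amount of probability mass that has to be moved to turn one into the other. A small sketch for comparing a true and a reconstructed attention pattern (the function name is hypothetical):

```python
import numpy as np

def tv_distance(p, q):
    """Total variation distance between attention distributions over the same key positions.
    Works on single rows or row-wise on (n_queries, n_keys) attention matrices."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return 0.5 * np.abs(p - q).sum(axis=-1)

p = [0.50, 0.30, 0.20]
q = [0.45, 0.30, 0.25]  # 5% of the mass moved from the first key to the last
print(tv_distance(p, q))  # approximately 0.05
```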

Adding and subtracting $W_{pos}[n]W_K$ here improves the approximation significantly. The approximation $|W_E[x_i]+W_{pos}[i]| \approx \sqrt{|W_{pos}[i]|^2+3.5^2}$ (which in effect treats the token embedding as having norm about 3.5 and being roughly orthogonal to the positional embedding) is not nearly as precise as $|W_E[x_i]+W_{pos}[i]| \approx |W_E[x_i]+W_{pos}[n]|$ when $i$ is near $n$. After adding and subtracting $W_{pos}[n]W_K$, the only term with $\sqrt{|W_{pos}[i]|^2+3.5^2}$ in its denominator has numerator $W_{pos}[i]W_K - W_{pos}[n]W_K$, which is small in absolute size when $i$ is near $n$. Since absolute errors in attention scores correspond to relative errors in exponentiated attention scores, this keeps the relative error in the exponentiated scores small for $i$ near $n$. Indeed, the approximation is exact when $i = n$. Most of the heads studied in this post attend primarily locally, so keeping the relative error small near $n$ improves the overall approximation.
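To see why absolute score errors become relative errors, write $s$ for a true attention score and $\delta$ for the absolute error introduced by the approximation:

$$\frac{e^{s+\delta}}{e^{s}} = e^{\delta} \approx 1 + \delta \qquad (|\delta| \ll 1)$$

So the exponentiated score is off by a relative factor of roughly $\delta$, whatever the size of $s$ itself.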

In practice, the approximate and true attention patterns end up being visually indistinguishable.

Positional kernels:


Define the positional kernel at position n as:
$$\mathrm{pos}_{n,i,\mathrm{query}[n](x)} = \mathrm{Softmax}\left(\frac{\mathrm{query}[n](x)^T P[1:n+1]}{\sqrt{d_{\text{value}}}}\right)_i$$

And define:

$$\mathrm{content}_{E[i](x),\,\mathrm{query}[n](x)} = e^{\mathrm{query}[n](x)^T E[i](x) / \sqrt{d_{\text{value}}}}$$

The attention pattern attending from position n when using the approximate keys is given by:
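$$\mathrm{attn\_approx}[n,i](x) = \frac{e^{\mathrm{attn\_approx\_score}[n,i](x)}}{\sum_{j=1}^{n} e^{\mathrm{attn\_approx\_score}[n,j](x)}} = \frac{\mathrm{pos}_{n,i,\mathrm{query}[n](x)} \cdot \mathrm{content}_{E[i](x),\,\mathrm{query}[n](x)}}{\sum_{j=1}^{n} \mathrm{pos}_{n,j,\mathrm{query}[n](x)} \cdot \mathrm{content}_{E[j](x),\,\mathrm{query}[n](x)}}$$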
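The second equality holds because the softmax normalizer of the positional kernel appears in both numerator and denominator and cancels. A quick numerical check of this factorization, with random vectors standing in for query[n](x), the content keys $E$ and the positional keys $P$ (all hypothetical), and `d_value` as the scaling dimension:

```python
import numpy as np

def softmax(z):
    z = z - z.max()            # subtract the max for numerical stability
    w = np.exp(z)
    return w / w.sum()

def positional_kernel(q, P, d_value):
    # pos_{n,i} = Softmax(q^T P[i] / sqrt(d_value)) over key positions 1..n
    return softmax(P @ q / np.sqrt(d_value))

def content_weights(q, E, d_value):
    # content_i = exp(q^T E[i] / sqrt(d_value)), left unnormalized
    return np.exp(E @ q / np.sqrt(d_value))

def attn_from_kernel(q, E, P, d_value):
    # pos_i * content_i / sum_j pos_j * content_j
    w = positional_kernel(q, P, d_value) * content_weights(q, E, d_value)
    return w / w.sum()

rng = np.random.default_rng(1)
n, d_value = 12, 8
q = rng.normal(size=d_value)
E = rng.normal(size=(n, d_value))   # content part of each approximate key
P = rng.normal(size=(n, d_value))   # positional part of each approximate key

# Ordinary softmax attention on the approximate keys E + P
direct = softmax((E + P) @ q / np.sqrt(d_value))
print(np.allclose(direct, attn_from_kernel(q, E, P, d_value)))  # True
```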

There are three categories of positional kernels in the first layer of GPT-2 Small, shown below. These kernels are often translation equivariant (the weight on position $i$ depends mainly on the offset $n - i$ rather than on $n$ itself), and they depend only weakly on the particular query token.

Figure: examples of the three categories of positional kernels, in panels titled "Sharp positional decay", "Slow positional decay" and "Close to uniform".

Contextual circuit:

Stability of softmax denominators:

We have the softmax probability formula:

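$$\mathrm{attn\_approx}[n,i](x) = \frac{\mathrm{pos}_{n,i,\mathrm{query}[n](x)} \cdot \mathrm{content}_{E[i](x),\,\mathrm{query}[n](x)}}{\sum_{j=1}^{n} \mathrm{pos}_{n,j,\mathrm{query}[n](x)} \cdot \mathrm{content}_{E[j](x),\,\mathrm{query}[n](x)}}$$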

If we model the tokens in the sequence as drawn i.i.d. from some underlying distribution, we can apply concentration inequalities such as Hoeffding's or Chebyshev's inequality to bound the probability that the softmax denominator deviates far from its mean. In both cases, concentration is governed by $\sum_{i=1}^{n} \mathrm{pos}_i^2$ together with a term measuring the content dependence of the attention scores.
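As a sketch of how the standard bounds read under this model: treat the query as fixed, write $\mathrm{pos}_i$ for $\mathrm{pos}_{n,i,\mathrm{query}[n](x)}$, and let $c_i$ be the content term of token $x_i$, so the denominator is $D = \sum_{i=1}^{n} \mathrm{pos}_i\, c_i$. If the $c_i$ are i.i.d. with variance $\sigma^2$ and take values in $[a, b]$, then

$$\operatorname{Var}(D) = \sigma^2 \sum_{i=1}^{n} \mathrm{pos}_i^2, \qquad \Pr\big(|D - \mathbb{E}[D]| \ge t\big) \le \frac{\sigma^2 \sum_{i=1}^{n} \mathrm{pos}_i^2}{t^2} \quad \text{(Chebyshev)},$$

$$\Pr\big(|D - \mathbb{E}[D]| \ge t\big) \le 2\exp\left(-\frac{2t^2}{(b-a)^2 \sum_{i=1}^{n} \mathrm{pos}_i^2}\right) \quad \text{(Hoeffding)}.$$

Here $\sigma^2$ and $(b-a)^2$ play the role of the content-dependence term: both shrink as the attention scores depend less on content.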

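$\sum_{i=1}^{n} \mathrm{pos}_i^2$ is small when the positional kernel is spread out: a uniform kernel over $n$ positions gives $\sum_{i} \mathrm{pos}_i^2 = 1/n$, while a kernel concentrated on a single position gives 1. So heads with wide positional kernels and attention scores that depend weakly on content should have stable softmax denominators.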
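A small simulation of this claim, under assumptions that are mine rather than the post's: exponentially decaying, translation-equivariant kernels of different widths (the `kernel` function is hypothetical), and i.i.d. lognormal content weights, as you would get from Gaussian content scores. The spread of the denominator across simulated texts should track $\sqrt{\operatorname{Var}(c)\sum_i \mathrm{pos}_i^2}$ and shrink as the kernel widens.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 512
offsets = np.arange(n)[::-1]   # offset n - i for key positions i = 1..n (0 at the query position)
SIGMA = 0.5                    # hypothetical spread of the (Gaussian) content scores

def kernel(decay):
    """Hypothetical translation-equivariant kernel: pos_i proportional to exp(-decay * (n - i))."""
    w = np.exp(-decay * offsets)
    return w / w.sum()

def empirical_std(pos, n_texts=2000):
    # Each row is one simulated text with i.i.d. lognormal content weights c_i;
    # its softmax denominator is D = sum_i pos_i * c_i.
    c = rng.lognormal(mean=0.0, sigma=SIGMA, size=(n_texts, len(pos)))
    return (c @ pos).std()

def predicted_std(pos):
    # sqrt(Var(c) * sum_i pos_i^2), using the variance of a lognormal with mu = 0
    var_c = (np.exp(SIGMA ** 2) - 1.0) * np.exp(SIGMA ** 2)
    return np.sqrt(var_c * np.sum(pos ** 2))

for name, decay in [("sharp", 0.5), ("slow", 0.02), ("near-uniform", 0.001)]:
    pos = kernel(decay)
    print(f"{name:>12}: sum pos^2 = {np.sum(pos ** 2):.4f}, "
          f"std(D) = {empirical_std(pos):.4f} (predicted {predicted_std(pos):.4f})")
```

The near-uniform kernel should give a denominator spread roughly ten times smaller than the sharp one, in line with the square root of the ratio of their $\sum_i \mathrm{pos}_i^2$ values.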

The sequences used for these denominator plots are drawn either from OpenWebText or from well-known books. Many of the heads show an overall, input-independent decay in the denominator, but the stability is visible nonetheless. The worst-behaved heads are Head 6 and Head 8, which depend on local keyword density, and that density fluctuates within a text; even so, these heads are significantly better behaved than heads with local positional kernels.

The relative stability of each head's softmax denominator tells us that it is a macroscopic property of the surrounding text: it is set by the input distribution rather than by the weights alone. So if we want to find the "effective circuit" that the model uses with high probability, we need to inject some information about the input distribution. In other words, we can't work in a purely weights-based manner.

We can inject this information by picking a representative calibration text for the distribution we are interested in and sampling softmax denominators from it. Once these denominators are fixed, and because heads 0, 2, 6, 8, 9 and 10 have very similar positional kernels, we can approximate their combined output by a positionally weighted summary $\sum_{i=1}^{n} \mathrm{pos}_{n,i,\mathrm{query}[n](x)}\,\mathrm{contribution}[n,\mathrm{query}[n](x),x_i]$. This summary depends on our calibrated softmax denominators, and I call this approximation the "contextual circuit".
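One way to read this construction, as a sketch under my own assumptions: each head's value output is treated as a function of the token alone, all heads share one positional kernel, and each head's denominator $D_h$ is frozen at its calibrated value. Then $\mathrm{attn}_h[n,i] \approx \mathrm{pos}_i\, c_h(x_i) / D_h$, and the summed head outputs collapse into a per-token table weighted by the kernel. All names below are hypothetical.

```python
import numpy as np

def contribution_table(content, values, denominators):
    """
    Hypothetical per-token contribution table for heads sharing one positional kernel.
    content:      (H, V) content weights exp(q^T E[t] / sqrt(d_value)), query token held fixed.
    values:       (H, V, d_model) each head's output vector for token t (assumed position-independent).
    denominators: (H,) calibrated softmax denominators, one per head.
    Returns (V, d_model): contribution[t] = sum_h content[h, t] * values[h, t] / D_h.
    """
    return np.einsum("hv,hvd->vd", content / denominators[:, None], values)

def contextual_circuit(pos, tokens, table):
    """Positionally weighted summary: sum_i pos_i * contribution[x_i]."""
    return pos @ table[tokens]

# Consistency check with random stand-ins: if the frozen denominators equal the true ones
# for this sequence, the summary reproduces the summed head outputs exactly.
rng = np.random.default_rng(2)
H, V, d_model, n = 3, 20, 5, 30
content = np.exp(rng.normal(size=(H, V)))
values = rng.normal(size=(H, V, d_model))
tokens = rng.integers(0, V, size=n)
pos = np.exp(-0.1 * np.arange(n)[::-1])
pos /= pos.sum()

exact = np.zeros(d_model)
true_denoms = np.empty(H)
for h in range(H):
    w = pos * content[h, tokens]          # unnormalized attention weights of head h
    true_denoms[h] = w.sum()
    exact += (w / w.sum()) @ values[h, tokens]

approx = contextual_circuit(pos, tokens, contribution_table(content, values, true_denoms))
print(np.allclose(exact, approx))  # True
```

With denominators taken from a different (calibration) text, the summary is only approximate, and its quality rests on the denominator stability argued above.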


Contextual neurons:

Using the contextual circuit, we can efficiently identify neurons that respond to high-level contextual properties.

First, we calibrate the circuit using a text with average keyword density, fixing the query token to ' the' (which appears in many contexts). This one-time calibration gives us the softmax denominators we need.
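A sketch of the calibration step, assuming each head's content weights have already been computed with the query fixed to ' the'. The post does not spell out how the denominators are sampled, so averaging them over query positions is my assumption, and the function names are hypothetical.

```python
import numpy as np

def calibrate_denominators(calib_tokens, content, kernel_fn, start=64):
    """
    Hypothetical calibration: read each head's softmax denominator off a calibration text
    and average it over query positions n >= start (the averaging is an assumption).
    calib_tokens: token ids of the calibration text.
    content:      (H, V) content weights with the query token held fixed.
    kernel_fn:    maps a 0-based query position n to the positional kernel over positions 0..n.
    Returns (H,) denominators D_h.
    """
    calib_tokens = np.asarray(calib_tokens)
    samples = []
    for n in range(start, len(calib_tokens)):
        pos = kernel_fn(n)                                        # shape (n + 1,)
        samples.append(content[:, calib_tokens[: n + 1]] @ pos)   # shape (H,)
    return np.mean(samples, axis=0)

def uniform_kernel(n):
    # Stand-in for a close-to-uniform positional kernel
    return np.full(n + 1, 1.0 / (n + 1))

rng = np.random.default_rng(3)
content = np.exp(rng.normal(size=(6, 100)))   # 6 heads, toy vocabulary of 100 token ids
calib = rng.integers(0, 100, size=1024)       # stand-in for a tokenized calibration text
print(calibrate_denominators(calib, content, uniform_kernel))
```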

Once calibrated, we can work entirely with the model weights: no more forward passes are needed. The frozen contextual circuit becomes a purely mathematical function that we can evaluate for any token.

For each of the 3072 first-layer MLP neurons, compute the "maximum token contribution" $\max[j] = \sup_{0 \le t < d_{\text{voc}}} \mathrm{contribution}[n,\mathrm{query}[n](x),t] \cdot \mathrm{mlp}[:,j]$ for the $j$th neuron. Neurons with a high maximum token contribution (I use a threshold of 5.0, which leaves about 100 neurons) are sensitive to broad contextual patterns rather than just to immediately neighboring tokens.
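The scan is a single matrix product followed by a max over the vocabulary. A sketch, assuming the contextual circuit has been tabulated as a (vocabulary × d_model) array `table` whose row $t$ is contribution[n, query[n](x), t], and that mlp[:, j] is the $j$th column of the first-layer MLP input weights `W_in`. LayerNorm before the MLP is ignored here, and the names and toy numbers are hypothetical.

```python
import numpy as np

def max_token_contribution(table, W_in):
    """
    table: (V, d_model) array; row t is the contextual circuit's contribution for token t.
    W_in:  (d_model, n_neurons) MLP input weights; column j plays the role of mlp[:, j].
    Returns (n_neurons,) array: max over tokens t of table[t] . W_in[:, j].
    """
    return (table @ W_in).max(axis=0)

def contextual_neurons(table, W_in, threshold=5.0):
    """Indices of neurons whose maximum token contribution exceeds the threshold."""
    return np.flatnonzero(max_token_contribution(table, W_in) > threshold)

table = np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 1.0]])   # 3 toy tokens, d_model = 2
W_in = np.array([[1.0, 0.0, 2.0], [0.0, 1.0, 1.0]])      # 3 toy neurons
print(max_token_contribution(table, W_in))               # [3. 2. 7.]
print(contextual_neurons(table, W_in, threshold=2.5))    # [0 2]
```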

To understand what each contextual neuron detects, we can examine which tokens contribute most and least to it through our circuit. The vast majority of these neurons correspond to interpretable patterns in the training data, such as particular topics, languages, or sentiment.
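Ranking tokens for one neuron then amounts to sorting the same scores. A sketch with hypothetical names and made-up numbers; the vocabulary strings are borrowed from the lists below purely for illustration.

```python
import numpy as np

def top_and_bottom_tokens(table, W_in, j, vocab, k=50):
    """
    Rank tokens by their contribution to neuron j through the contextual circuit.
    vocab: list mapping token id -> string (e.g. decoded with the model's tokenizer).
    Returns (top k tokens, largest first; bottom k tokens, smallest first).
    """
    scores = table @ W_in[:, j]
    order = np.argsort(scores)   # ascending
    top = [vocab[t] for t in order[::-1][:k]]
    bottom = [vocab[t] for t in order[:k]]
    return top, bottom

vocab = [" colour", " color", " the", " favour"]
table = np.array([[2.0], [-1.5], [0.1], [1.0]])   # toy 1-dimensional contributions
W_in = np.array([[1.0]])                          # a single toy neuron
print(top_and_bottom_tokens(table, W_in, 0, vocab, k=2))
# ([' colour', ' favour'], [' color', ' the'])
```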

So, after calibrating on one text, we can discover around a hundred contextual neurons using only weight-based analysis. The neurons found fire on contexts completely unrelated to the calibration text.

Commonwealth vs American English (Neuron 704):

Top 50 token contributions:  
[' pract', ' foc', ' recogn', ' UK', ' British', ' London', ' £', ' Australia', ' Britain', 'isation', ' Australian', ' emphas', ' favour', ' Labour', ' centre', ' util', ' BBC', ' Scotland', ' behaviour', ' defence', ' Manchester', ' colour', ' €', ' labour', ' analys', ' programme', ' Liverpool', ' Wales', ' Sydney', ' Scottish', ' neighbour', ' favourite', ' organisation', ' keen', ' organis', ' offence', ' whilst', ' Melbourne', ' MPs', '£', ' honour', ' summar', ' organisations', ' Isis', ' travelling', ' Defence', ' licence', ' NHS', ' Dublin', ' armour']

Bottom 50 token contributions: 
[' program', ' mom', ' favor', ' color', ' Center', ' center', ' defense', ' toward', ' Texas', ' favorite', ' organization', ' programs', ' behavior', ' analy', 'izes', 'izations', ' neighbor', ' labor', ' Color', 'avor', ' marijuana', ' organizations', ' Defense', ' license', ' attorney', ' neighborhood', ' realize', ' GOP', ' offense', ' realized', ' Seattle', ' honor', ' recognize', ' colors', ' gotten', ' folks', ' recognized', 'color', ' armor', ' organized', ' §', ' Oregon', 'sylvania', ' baseball', ' transportation', ' Iowa', ' downtown', ' flavor', '.--', ' accomplish']

This neuron classifies text as Commonwealth or American English, activating on Commonwealth English. It is an example where the bottom token contributions matter as much to the neuron's function as the top ones. Because each token adds its own contribution independently, weighted by the positional kernel, these contributions let the model perform a kind of Naive Bayes classification of the text.
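A toy illustration of the Naive Bayes reading. In Naive Bayes, each observed token adds an independent log-likelihood-ratio term to a class score; here, each token adds $\mathrm{pos}_i\, w(x_i)$ to the neuron's input, and the sign of the total decides the class. The weights below are hypothetical ±1 values for a few tokens from the lists above, not neuron 704's actual contributions.

```python
import numpy as np

# Hypothetical token weights standing in for neuron 704's token contributions:
# positive for Commonwealth spellings, negative for American ones, zero for anything else.
WEIGHTS = {" colour": 1.0, " favour": 1.0, " centre": 1.0,
           " color": -1.0, " favor": -1.0, " center": -1.0}

def neuron_score(tokens, pos=None):
    """Sum of per-token votes sum_i pos_i * w(x_i); defaults to a uniform positional kernel."""
    w = np.array([WEIGHTS.get(t, 0.0) for t in tokens])
    if pos is None:
        pos = np.full(len(tokens), 1.0 / len(tokens))
    return float(pos @ w)

print(neuron_score([" the", " colour", " of", " the", " centre"]))   # 0.4
print(neuron_score([" the", " color", " of", " the", " center"]))    # -0.4
```

An American spelling pulls the score down exactly as hard as a Commonwealth spelling pushes it up, which is why the bottom list carries as much information as the top one.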