Cross-attention: queries come from one sequence; keys and values from another. Self-attention: queries, keys, and values all come from the same sequence.
Every position looks at every other position. This is how transformers see relationships.
## From input to Q, K, V
Take input embeddings X. Project to queries, keys, values:
$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$
Three different views of the same sequence.
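A minimal sketch of the projections, using toy sizes chosen for illustration (not from the original):

```python
import torch

torch.manual_seed(0)
seq_len, d_model, d_k = 3, 8, 4  # toy sizes, assumed for illustration

X = torch.randn(seq_len, d_model)    # input embeddings, one row per position
W_Q = torch.randn(d_model, d_k)
W_K = torch.randn(d_model, d_k)
W_V = torch.randn(d_model, d_k)

# three views of the same X; each is (seq_len, d_k):
# one query/key/value vector per position
Q, K, V = X @ W_Q, X @ W_K, X @ W_V
```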
## Scaled dot-product attention
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
## Step by step
- Compute scores: $QK^T$ gives (seq_len × seq_len) matrix
- Scale: Divide by $\sqrt{d_k}$
- Softmax: each row sums to 1. Row i shows how much position i attends to every position.
- Weight values: multiply by V. Position i's output is a weighted sum of all value vectors.
## Example walkthrough
Sentence: “The cat sat”
After projection:
- Position 0 "The": q0, k0, v0
- Position 1 "cat": q1, k1, v1
- Position 2 "sat": q2, k2, v2
Attention for “sat” (position 2):
- Score with “The”: q2 · k0 = 0.5
- Score with “cat”: q2 · k1 = 2.1
- Score with “sat”: q2 · k2 = 1.3
After softmax (treating these scores as already scaled): [0.12, 0.61, 0.27]
Output: 0.12·v0 + 0.61·v1 + 0.27·v2
"sat" attends mostly to "cat" (makes sense: "cat" is the subject doing the sitting)
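The softmax step above can be checked numerically (treating the three dot-product scores as already scaled):

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([0.5, 2.1, 1.3])  # q2·k0, q2·k1, q2·k2
weights = F.softmax(scores, dim=0)
# ≈ [0.12, 0.61, 0.27]: "sat" places most of its weight on "cat"
```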
## Code
```python
import numpy as np
import torch.nn.functional as F

def self_attention(x, w_q, w_k, w_v):
    # x: (batch, seq_len, d_model)
    Q = x @ w_q  # (batch, seq_len, d_k)
    K = x @ w_k
    V = x @ w_v

    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / np.sqrt(d_k)  # (batch, seq_len, seq_len)

    attention_weights = F.softmax(scores, dim=-1)
    output = attention_weights @ V  # (batch, seq_len, d_v)
    return output, attention_weights
```
## Masking
For autoregressive models (GPT), position i shouldn’t see positions > i.
Apply mask before softmax:
```python
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
scores = scores.masked_fill(mask, float('-inf'))
```
The -∞ entries get weight 0 after softmax, so future positions are ignored.
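A runnable sketch of causal masking on toy scores (random values, illustrative only):

```python
import torch
import torch.nn.functional as F

seq_len = 4
scores = torch.randn(seq_len, seq_len)  # toy attention scores

# upper triangle (strictly above the diagonal) = future positions
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
weights = F.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)
# row i has zero weight on positions > i; row 0 attends only to itself
```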
## Multi-head attention
Run multiple self-attention operations in parallel, each with different projections.
```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch, seq_len, _ = x.shape

        # Project and split into heads: (batch, num_heads, seq_len, d_k)
        Q = self.w_q(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Attention per head
        scores = (Q @ K.transpose(-2, -1)) / np.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        out = attn @ V

        # Concat heads and project
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.w_o(out)
```
Different heads learn different relationship patterns.
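The head split/merge bookkeeping in `forward` can be verified in isolation (toy sizes, assumed for illustration):

```python
import torch

batch, seq_len, d_model, num_heads = 2, 5, 16, 4
d_k = d_model // num_heads

x = torch.randn(batch, seq_len, d_model)
# split: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
heads = x.view(batch, seq_len, num_heads, d_k).transpose(1, 2)
# merge: the inverse of the split
merged = heads.transpose(1, 2).contiguous().view(batch, seq_len, -1)
# merged equals x: splitting into heads and concatenating back is lossless
```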