LSTMs helped with long-range dependencies. Attention does something more powerful: direct connections from anywhere to anywhere. Look at what matters, ignore what doesn't.
The limitation of sequence models
RNNs and LSTMs process sequences step by step. To connect word 1 to word 50, information must flow through 49 intermediate steps.
Attention: skip all that. Word 1 directly attends to word 50.
Core idea
Given a query, compare it against all keys. Use similarity scores to weight the values.
$$\text{Attention}(Q, K, V) = \text{softmax}(\text{scores}) \cdot V$$
Softmax ensures weights sum to 1. Output is weighted average of values.
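As a toy numeric sketch of that formula (the vectors and values here are made up for illustration, not from the text):

```python
import torch
import torch.nn.functional as F

# One query vector compared against three key vectors.
q = torch.tensor([1.0, 0.0])
K = torch.tensor([[1.0, 0.0],    # similar to q -> high score
                  [0.0, 1.0],    # orthogonal to q -> low score
                  [1.0, 1.0]])
V = torch.tensor([[10.0], [20.0], [30.0]])

scores = K @ q                        # similarity of q with each key
weights = F.softmax(scores, dim=0)    # nonnegative, sums to 1
output = weights @ V                  # weighted average of values
print(weights)   # keys 1 and 3 dominate, key 2 gets little weight
print(output)
```

The output is a blend of all three values, weighted by how well each key matches the query.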
Query, Key, Value
Think of it like a dictionary lookup, but soft/fuzzy:
- Query: What am I looking for?
- Key: What does each item offer?
- Value: What information does each item have?
Query-key similarity determines how much of each value to retrieve.
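The soft-lookup idea can be sketched in plain Python (the `soft_lookup` helper and its toy vectors are hypothetical, for illustration only):

```python
import math

def soft_lookup(query, keys, values):
    # A hard dict lookup would return exactly one value.
    # A soft lookup blends ALL values by query-key similarity.
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]       # softmax over scores
    return [sum(w * v[i] for w, v in zip(weights, values))
            for i in range(len(values[0]))]

# Query matches the first key more than the second,
# so the result sits closer to the first value.
result = soft_lookup([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]],
                     [[10.0], [20.0]])
print(result)
```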
Simple attention
Original sequence-to-sequence attention:
- Encoder produces states: $h_1, h_2, …, h_n$
- Decoder has current state: $s_t$
- Compare $s_t$ (query) against all $h_i$ (keys)
- Weight $h_i$ (values) by similarity
- Use weighted sum to help decode
```python
import torch
import torch.nn.functional as F

def attention(query, keys, values):
    # query: (batch, hidden)
    # keys/values: (batch, seq_len, hidden)
    scores = torch.matmul(query.unsqueeze(1), keys.transpose(-2, -1))
    # scores: (batch, 1, seq_len)
    weights = F.softmax(scores, dim=-1)
    output = torch.matmul(weights, values)
    # output: (batch, 1, hidden)
    return output.squeeze(1), weights
```
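A quick shape check with random tensors (the function is repeated so this snippet runs standalone; the batch and sizes are arbitrary):

```python
import torch
import torch.nn.functional as F

def attention(query, keys, values):
    # Same function as above, repeated for a standalone example.
    scores = torch.matmul(query.unsqueeze(1), keys.transpose(-2, -1))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, values).squeeze(1), weights

torch.manual_seed(0)
query = torch.randn(2, 16)      # batch of 2, hidden size 16
keys = torch.randn(2, 5, 16)    # sequence length 5
values = torch.randn(2, 5, 16)

out, w = attention(query, keys, values)
print(out.shape)   # torch.Size([2, 16])
print(w.shape)     # torch.Size([2, 1, 5])
```

Each row of the weights sums to 1: one query distributes its attention over the 5 positions.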
Scoring functions
How to compute similarity?
Dot product: $$\text{score}(q, k) = q \cdot k$$
Scaled dot product (used in transformers): $$\text{score}(q, k) = \frac{q \cdot k}{\sqrt{d_k}}$$
Additive (Bahdanau attention): $$\text{score}(q, k) = v^T \tanh(W_q q + W_k k)$$
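The three scoring functions can be sketched in PyTorch. The `AdditiveScore` module and its hidden size `d_hidden` are illustrative names, not from the original:

```python
import math
import torch

def dot_score(q, k):
    return q @ k

def scaled_dot_score(q, k):
    return (q @ k) / math.sqrt(k.shape[-1])

class AdditiveScore(torch.nn.Module):
    # Bahdanau-style: project q and k, combine with tanh,
    # reduce to a scalar with a learned vector v.
    def __init__(self, d_q, d_k, d_hidden):
        super().__init__()
        self.W_q = torch.nn.Linear(d_q, d_hidden, bias=False)
        self.W_k = torch.nn.Linear(d_k, d_hidden, bias=False)
        self.v = torch.nn.Linear(d_hidden, 1, bias=False)

    def forward(self, q, k):
        return self.v(torch.tanh(self.W_q(q) + self.W_k(k))).squeeze(-1)

q, k = torch.ones(4), torch.ones(4)
print(dot_score(q, k).item())         # 4.0
print(scaled_dot_score(q, k).item())  # 4.0 / sqrt(4) = 2.0
s = AdditiveScore(4, 4, 8)(q, k)      # a learned scalar score
```

The additive form costs more compute but lets queries and keys have different dimensions; the dot-product forms require them to match.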
Why scale?
Dot products grow with dimension: for vectors with unit-variance components, $q \cdot k$ has standard deviation $\sqrt{d_k}$. Large scores push the softmax into a very peaked regime where gradients vanish.
Dividing by $\sqrt{d_k}$ keeps scores at roughly unit scale regardless of dimension.
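A quick empirical check of that scaling (the dimension and sample count are arbitrary choices for illustration):

```python
import torch

torch.manual_seed(0)
d_k = 512
q = torch.randn(10000, d_k)   # 10k random query vectors
k = torch.randn(10000, d_k)   # 10k random key vectors

raw = (q * k).sum(-1)         # raw dot products
scaled = raw / d_k ** 0.5     # scaled dot products

print(raw.std())     # roughly sqrt(512) ~ 22.6
print(scaled.std())  # roughly 1
```

Without scaling, the typical score magnitude grows with $\sqrt{d_k}$; after scaling it stays near 1 no matter the dimension.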
Visualizing attention
Attention weights show what the model “looks at”:
```
Translating: "The cat sat on the mat"
               ↓   ↓   ↓   ↓  ↓   ↓
Output word "chat" attends strongly to "cat"
Output word "sur" attends strongly to "on"
```
Great for interpretability.
What comes next
This is basic attention: one query attends to a sequence.
- Self-attention: sequence attends to itself
- Multi-head: multiple attention patterns in parallel
These build up to transformers.