LSTMs were a big step forward for handling sequences. But attention does something more powerful - it creates direct connections from anywhere to anywhere. Instead of information flowing through dozens of steps, word 1 can directly “look at” word 50.
The problem with sequence models
RNNs and LSTMs process sequences step by step, like reading one word at a time. To connect information from the beginning of a sentence to the end, everything has to flow through all the intermediate steps.
Think about translating a long sentence. By the time you reach the end, information about the beginning has been passed through dozens of transforms. Important details get lost in the shuffle.
Attention says: forget the middlemen. Let any position directly attend to any other position.
The core idea (it’s simpler than you think)
Attention is basically a smart lookup:
- You have a query (what you’re looking for)
- You compare it against all keys (what each item offers)
- You use those similarity scores to weight the values (the actual information)
$$\text{Attention}(Q, K, V) = \text{softmax}(\text{scores}) \cdot V$$
The softmax makes the weights sum to 1, so the output is just a weighted average. Items that match your query well get more weight.
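Here's that formula on toy numbers (values picked by hand, purely for illustration): one query, three keys, and a scalar value attached to each key.

```python
import torch
import torch.nn.functional as F

query  = torch.tensor([1.0, 0.0])                 # what we're looking for
keys   = torch.tensor([[1.0, 0.0],                # matches the query exactly
                       [0.0, 1.0],                # orthogonal to the query
                       [0.5, 0.5]])               # partial match
values = torch.tensor([[10.0], [20.0], [30.0]])   # the information attached to each key

scores  = keys @ query                # dot-product similarity: tensor([1.0, 0.0, 0.5])
weights = F.softmax(scores, dim=-1)   # ≈ [0.51, 0.19, 0.31]; sums to 1
output  = weights @ values            # ≈ 18.0: below the plain average of 20, because the
                                      # best-matching key's value (10.0) gets the most weight
```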
Interactive demo: Attention Animation - watch how different queries focus on different parts of the input.
Query, Key, Value - a helpful analogy
Think of it like searching a library:
- Query: “I want books about cats” (what you’re looking for)
- Key: Each book’s title and description (what each item offers)
- Value: The actual book content (what you get back)
You compare your query against all keys, then retrieve a blend of values weighted by how well they matched.
The key insight: query-key similarity determines how much of each value to retrieve.
A simple attention example
Let’s say we’re translating a sentence:
- The encoder produces a hidden state for each input word: $h_1, h_2, \dots, h_n$
- The decoder is currently generating a word and has state $s_t$
- Use $s_t$ as query, compare against all $h_i$ as keys
- Weight the $h_i$ values by similarity
- This weighted sum helps the decoder focus on relevant input words
In PyTorch:

```python
import torch
import torch.nn.functional as F

def attention(query, keys, values):
    # query:        (batch, hidden)          e.g. the decoder state s_t
    # keys, values: (batch, seq_len, hidden) e.g. the encoder states h_1..h_n
    scores = torch.matmul(query.unsqueeze(1), keys.transpose(-2, -1))  # (batch, 1, seq_len)
    weights = F.softmax(scores, dim=-1)      # one weight per input position, summing to 1
    output = torch.matmul(weights, values)   # (batch, 1, hidden): weighted sum of the values
    return output.squeeze(1), weights        # (batch, hidden), (batch, 1, seq_len)
```
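A quick shape check with random tensors, reusing the `attention` function above (no trained model here, just verifying the plumbing):

```python
batch, seq_len, hidden = 2, 5, 16
query = torch.randn(batch, hidden)                    # e.g. the decoder state s_t
keys = values = torch.randn(batch, seq_len, hidden)   # e.g. the encoder states h_1..h_n

context, weights = attention(query, keys, values)
print(context.shape)          # torch.Size([2, 16])
print(weights.shape)          # torch.Size([2, 1, 5]): one weight per input position
print(weights.sum(dim=-1))    # every row sums to 1
```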
How do we compute similarity?
Several options:
Dot product (simplest): $$\text{score}(q, k) = q \cdot k$$
Scaled dot product (used in transformers): $$\text{score}(q, k) = \frac{q \cdot k}{\sqrt{d_k}}$$
Additive (Bahdanau attention): $$\text{score}(q, k) = v^T \tanh(W_q q + W_k k)$$
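A rough sketch of the three in code (the additive weight shapes below are one common choice, not the only one):

```python
import torch
import torch.nn as nn

d_k = 64
q = torch.randn(d_k)      # a single query vector
k = torch.randn(d_k)      # a single key vector

score_dot    = q @ k                     # dot product
score_scaled = (q @ k) / d_k ** 0.5      # scaled dot product

# Additive (Bahdanau): project q and k, mix through tanh, reduce to a scalar with v
W_q = nn.Linear(d_k, d_k, bias=False)
W_k = nn.Linear(d_k, d_k, bias=False)
v   = nn.Linear(d_k, 1, bias=False)
score_additive = v(torch.tanh(W_q(q) + W_k(k)))   # shape (1,)
```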
Why scale?
Dot products grow with dimension: for roughly unit-variance vectors, the variance of $q \cdot k$ scales with $d_k$. Large scores make the softmax extremely peaked, and the gradients through it vanish.
Dividing by $\sqrt{d_k}$ keeps the scores in a reasonable range.
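You can see the effect directly with random vectors (illustrative only):

```python
import torch
import torch.nn.functional as F

d_k = 512
q = torch.randn(d_k)
keys = torch.randn(10, d_k)

scores = keys @ q                                # entries have std ≈ sqrt(d_k) ≈ 22.6
print(F.softmax(scores, dim=-1))                 # typically close to one-hot: almost all weight on one key
print(F.softmax(scores / d_k ** 0.5, dim=-1))    # scaled: a much softer distribution
```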
Visualizing attention
Attention weights show what the model “looks at”:
Translating: "The cat sat on the mat"
↓ ↓ ↓ ↓ ↓ ↓
Output word "chat" attends strongly to "cat"
Output word "sur" attends strongly to "on"
Great for interpretability.
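A minimal plotting sketch, assuming you have a `(target_len, source_len)` array of attention weights plus the two token lists (the function name here is my own):

```python
import matplotlib.pyplot as plt

def plot_attention(weights, source_tokens, target_tokens):
    """Heatmap of attention weights: rows = output tokens, columns = input tokens."""
    fig, ax = plt.subplots()
    ax.imshow(weights, cmap="viridis")
    ax.set_xticks(range(len(source_tokens)))
    ax.set_xticklabels(source_tokens, rotation=45)
    ax.set_yticks(range(len(target_tokens)))
    ax.set_yticklabels(target_tokens)
    ax.set_xlabel("input (source)")
    ax.set_ylabel("output (target)")
    fig.tight_layout()
    plt.show()
```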
What comes next
This is basic attention: one query attends to a sequence.
- Self-attention: the sequence attends to itself
- Multi-head attention: multiple attention patterns in parallel
These build up to transformers.
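As a tiny preview, here is self-attention in its most stripped-down form: the same weighted-average machinery, but every position in one sequence acts as a query over that same sequence (no learned Q/K/V projections yet; a real transformer layer adds those).

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 16)                        # (batch, seq_len, hidden): one sequence
scores = (x @ x.transpose(-2, -1)) / 16 ** 0.5   # (batch, seq_len, seq_len): every position scores every other
weights = F.softmax(scores, dim=-1)
out = weights @ x                                # (batch, seq_len, hidden): a new representation per position
```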
Attention mechanism clicked? Help others discover this by starring ML Animations and sharing on social media!