LSTMs were a big step forward for handling sequences. But attention does something more powerful - it creates direct connections from anywhere to anywhere. Instead of information flowing through dozens of steps, word 1 can directly “look at” word 50.

The problem with sequence models

RNNs and LSTMs process sequences step by step, like reading one word at a time. To connect information from the beginning of a sentence to the end, everything has to flow through all the intermediate steps.

Think about translating a long sentence. By the time you reach the end, information about the beginning has been passed through dozens of transforms. Important details get lost in the shuffle.

Attention says: forget the middlemen. Let any position directly attend to any other position.

The core idea (it’s simpler than you think)

Attention is basically a smart lookup:

  1. You have a query (what you’re looking for)
  2. You compare it against all keys (what each item offers), which gives a similarity score for each item
  3. You use those scores to weight the values (the actual information)

$$\text{Attention}(Q, K, V) = \text{softmax}(\text{scores}) \cdot V$$

The softmax makes the weights sum to 1, so the output is just a weighted average. Items that match your query well get more weight.
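To make the weighted average concrete, here is a tiny numeric sketch in PyTorch (the vectors are made up purely for illustration):

import torch
import torch.nn.functional as F

# One query and three key/value pairs (toy numbers)
query  = torch.tensor([1.0, 0.0])
keys   = torch.tensor([[1.0, 0.0],    # very similar to the query
                       [0.0, 1.0],    # orthogonal to the query
                       [0.5, 0.5]])   # somewhere in between
values = torch.tensor([[10.0, 0.0],
                       [0.0, 10.0],
                       [5.0, 5.0]])

scores  = keys @ query               # dot-product similarity, shape (3,)
weights = F.softmax(scores, dim=0)   # roughly [0.51, 0.19, 0.31], sums to 1
output  = weights @ values           # weighted average of the values
print(output)                        # dominated by the first value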

Interactive demo: Attention Animation - watch how different queries focus on different parts of the input.

Query, Key, Value - a helpful analogy

Think of it like searching a library:

  • Query: “I want books about cats” (what you’re looking for)
  • Key: Each book’s title and description (what each item offers)
  • Value: The actual book content (what you get back)

You compare your query against all keys, then retrieve a blend of values weighted by how well they matched.

The key insight: query-key similarity determines how much of each value to retrieve.

A simple attention example

Let’s say we’re translating a sentence:

  1. Encoder produces a hidden state for each input word: $h_1, h_2, …, h_n$
  2. Decoder is currently generating a word and has state $s_t$
  3. Use $s_t$ as the query and compare it against each $h_i$ (the keys)
  4. Weight the $h_i$ (which also serve as the values) by their similarity scores
  5. The weighted sum gives a context vector that helps the decoder focus on the relevant input words
In code, a minimal PyTorch sketch of this (unscaled) dot-product attention looks like:

import torch
import torch.nn.functional as F

def attention(query, keys, values):
    # query: (batch, hidden)
    # keys/values: (batch, seq_len, hidden)

    # Similarity of the query to every key: (batch, 1, seq_len)
    scores = torch.matmul(query.unsqueeze(1), keys.transpose(-2, -1))

    # Normalize so the weights sum to 1 over the sequence
    weights = F.softmax(scores, dim=-1)

    # Weighted average of the values: (batch, 1, hidden)
    output = torch.matmul(weights, values)

    return output.squeeze(1), weights
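A quick shape check with random tensors, using the function above (the sizes here are arbitrary):

batch, seq_len, hidden = 2, 5, 8
query  = torch.randn(batch, hidden)
keys   = torch.randn(batch, seq_len, hidden)
values = torch.randn(batch, seq_len, hidden)

context, weights = attention(query, keys, values)
print(context.shape)          # torch.Size([2, 8])
print(weights.shape)          # torch.Size([2, 1, 5])
print(weights.sum(dim=-1))    # each row sums to 1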

How do we compute similarity?

Several options:

Dot product (simplest): $$\text{score}(q, k) = q \cdot k$$

Scaled dot product (used in transformers): $$\text{score}(q, k) = \frac{q \cdot k}{\sqrt{d_k}}$$

Additive (Bahdanau attention): $$\text{score}(q, k) = v^T \tanh(W_q q + W_k k)$$
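Here is a sketch of the three scoring functions in PyTorch (the dimensions and module names are my own choices for illustration, not taken from any particular library):

import math
import torch
import torch.nn as nn

def dot_score(q, k):
    # q, k: (d_k,)
    return torch.dot(q, k)

def scaled_dot_score(q, k):
    # Same dot product, divided by sqrt(d_k)
    return torch.dot(q, k) / math.sqrt(k.shape[-1])

class AdditiveScore(nn.Module):
    # Additive (Bahdanau) scoring uses learned parameters W_q, W_k and v
    def __init__(self, d_q, d_k, d_hidden):
        super().__init__()
        self.W_q = nn.Linear(d_q, d_hidden, bias=False)
        self.W_k = nn.Linear(d_k, d_hidden, bias=False)
        self.v   = nn.Linear(d_hidden, 1, bias=False)

    def forward(self, q, k):
        # q: (d_q,), k: (d_k,) -> scalar score
        return self.v(torch.tanh(self.W_q(q) + self.W_k(k))).squeeze(-1)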

Why scale?

When the dimension $d_k$ is large, dot products tend to grow large in magnitude (their variance scales with $d_k$). Large scores make the softmax extremely peaked, and the gradients through it vanish.

Dividing by $\sqrt{d_k}$ keeps the scores in a reasonable range.
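You can see the effect directly with a small experiment (exact numbers vary from run to run):

import torch
import torch.nn.functional as F

d_k = 512
q = torch.randn(d_k)
k = torch.randn(5, d_k)

raw    = k @ q               # dot products with std around sqrt(d_k), roughly 22 here
scaled = raw / d_k ** 0.5    # std around 1

print(F.softmax(raw, dim=0))     # typically near one-hot: almost all weight on one key
print(F.softmax(scaled, dim=0))  # noticeably smoother distribution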

Visualizing attention

Attention weights show what the model “looks at”:

Translating: "The cat sat on the mat"
             ↓   ↓    ↓   ↓   ↓   ↓
Output word "chat" attends strongly to "cat"
Output word "sur"  attends strongly to "on"

Great for interpretability.
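If you have an attention matrix, a heatmap is the usual way to look at it. A sketch with matplotlib (the weights here are random, just to show the plotting pattern):

import torch
import matplotlib.pyplot as plt

src = ["The", "cat", "sat", "on", "the", "mat"]
tgt = ["Le", "chat", "s'est", "assis", "sur", "le", "tapis"]

# Toy attention matrix (target x source), made up for illustration
weights = torch.rand(len(tgt), len(src)).softmax(dim=-1)

plt.imshow(weights, cmap="viridis")
plt.xticks(range(len(src)), src, rotation=45)
plt.yticks(range(len(tgt)), tgt)
plt.xlabel("Source word")
plt.ylabel("Generated word")
plt.colorbar(label="attention weight")
plt.tight_layout()
plt.show()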

What comes next

This is basic attention: one query attends to a sequence.

  • Self-attention: the sequence attends to itself
  • Multi-head attention: multiple attention patterns in parallel

These build up to transformers.


Attention mechanism clicked? Help others discover this by starring ML Animations and sharing on social media!