Deep networks have unstable activations. Values explode or vanish as they pass through layers. Normalization fixes this. Layer norm is the flavor transformers use.
The problem
Without normalization, activations can:
- Grow exponentially (exploding)
- Shrink to near-zero (vanishing)
- Shift distribution layer by layer
This makes training unstable: a learning rate that works for one layer fails for another.
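The exploding case is easy to reproduce; a minimal sketch (the 0.5 weight scale and 50-layer depth are arbitrary choices for illustration):

```python
import torch

torch.manual_seed(0)
x = torch.randn(64)
for _ in range(50):
    w = torch.randn(64, 64) * 0.5   # deliberately poorly scaled weights
    x = torch.relu(w @ x)

# The norm grows by a roughly constant factor per layer, so it blows up fast
print(x.norm())
```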
Layer Norm formula
For each sample, normalize across the feature dimension:
$$\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$
Where:
- μ = mean across features
- σ² = variance across features
- γ, β = learnable scale and shift
- ε = small constant for stability
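To make the formula concrete, a tiny worked example with γ = 1, β = 0 (values chosen arbitrarily):

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])
mu = x.mean()                 # 2.0
var = x.var(unbiased=False)   # 2/3, the biased variance as in the formula
x_norm = (x - mu) / torch.sqrt(var + 1e-5)
# x_norm ≈ [-1.2247, 0.0, 1.2247]: zero mean, unit variance
```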
Code
```python
import torch

def layer_norm(x, gamma, beta, eps=1e-5):
    # Statistics over the last (feature) dimension, per sample
    mean = x.mean(dim=-1, keepdim=True)
    # Biased variance (unbiased=False), matching nn.LayerNorm
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    x_norm = (x - mean) / torch.sqrt(var + eps)
    return gamma * x_norm + beta
```
Or just:
```python
import torch.nn as nn

layer_norm = nn.LayerNorm(hidden_size)
output = layer_norm(x)
```
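As a sanity check, the built-in module matches the manual formula (a sketch; `hidden_size` and the input shape are arbitrary):

```python
import torch
import torch.nn as nn

hidden_size = 8
x = torch.randn(2, 4, hidden_size)   # (batch, seq, features)
ln = nn.LayerNorm(hidden_size)       # γ starts at 1, β at 0

out = ln(x)
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
manual = (x - mean) / torch.sqrt(var + ln.eps)

# The two agree up to floating-point error
print(torch.allclose(out, manual, atol=1e-5))
```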
Layer Norm vs Batch Norm
Batch Norm:
- Normalize across batch dimension
- Needs batch statistics
- Different behavior train vs eval
- Works great for CNNs
Layer Norm:
- Normalize across feature dimension
- Each sample independent
- Same behavior train and eval
- Works great for transformers
```text
Batch Norm: normalize each feature across the batch
  [sample1_feat1, sample2_feat1, ...] → normalize together

Layer Norm: normalize each sample across its features
  [sample1_feat1, sample1_feat2, ...] → normalize together
```
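The axis difference can be checked directly on an unparameterized version of each (a sketch on a `(batch, features)` tensor):

```python
import torch

x = torch.randn(32, 64)   # (batch, features)

# Batch-norm style: per-feature statistics, computed across the batch (dim=0)
bn_out = (x - x.mean(dim=0)) / x.std(dim=0, unbiased=False)

# Layer-norm style: per-sample statistics, computed across features (dim=-1)
ln_out = (x - x.mean(dim=-1, keepdim=True)) / x.std(dim=-1, keepdim=True, unbiased=False)

# Each feature column of bn_out, and each sample row of ln_out, is ~zero mean
```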
Why transformers use Layer Norm
Batch norm fails with:
- Variable sequence lengths (different pad amounts)
- Small batches
- Recurrent processing
Layer norm handles all of these because each sample is normalized independently, using only its own features.
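That independence is easy to verify: a sample's output does not change when other samples join the batch (a quick sketch):

```python
import torch
import torch.nn as nn

ln = nn.LayerNorm(16)
a = torch.randn(1, 5, 16)   # one sequence
b = torch.randn(1, 5, 16)   # an unrelated sequence

solo = ln(a)
batched = ln(torch.cat([a, b], dim=0))[:1]

# a's normalization never looks at b
print(torch.allclose(solo, batched, atol=1e-6))
```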
Pre-norm vs Post-norm
Post-norm (original transformer):

```python
x = x + sublayer(x)
x = layer_norm(x)
```

Pre-norm (GPT-2 style):

```python
x = x + sublayer(layer_norm(x))
```

Pre-norm trains more stably. Most modern models use it.
RMSNorm
Simplified version used in LLaMA:
$$\text{RMSNorm}(x) = \gamma \cdot \frac{x}{\sqrt{\frac{1}{n}\sum_i x_i^2 + \epsilon}}$$
Skip the mean subtraction. Just normalize by root mean square.
It is slightly faster and usually works just as well.
```python
import torch

def rms_norm(x, gamma, eps=1e-5):
    # Root mean square over the feature dimension
    rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
    return gamma * x / rms
```
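One way to see the relationship: when each sample already has zero mean, the RMS denominator equals LayerNorm's (a quick check):

```python
import torch

x = torch.randn(4, 16)
x = x - x.mean(dim=-1, keepdim=True)   # force zero mean per sample

eps = 1e-5
rms = torch.sqrt((x**2).mean(dim=-1, keepdim=True) + eps)
std = torch.sqrt(x.var(dim=-1, keepdim=True, unbiased=False) + eps)

# With zero mean, mean(x²) equals the biased variance
print(torch.allclose(rms, std, atol=1e-6))
```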
Where to put it
In a transformer block (pre-norm):

```python
# Attention block
x = x + attention(layer_norm(x))
# FFN block
x = x + ffn(layer_norm(x))
```
Some architectures add a final layer norm after the last block.
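Putting the pieces together, a minimal pre-norm block as an `nn.Module`; the attention here is a stand-in linear layer so the sketch stays self-contained:

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    # Sketch of a pre-norm transformer block; self.attn is a placeholder
    # linear layer standing in for real multi-head attention.
    def __init__(self, hidden_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(hidden_size)
        self.attn = nn.Linear(hidden_size, hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))   # attention sub-block
        x = x + self.ffn(self.ln2(x))    # FFN sub-block
        return x

block = PreNormBlock(16)
out = block(torch.randn(2, 5, 16))   # shape preserved: (2, 5, 16)
```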