Deep networks have unstable activations. Values explode or vanish as they pass through layers. Normalization fixes this. Layer norm is the flavor transformers use.

The problem

Without normalization, activations can:

  • Grow exponentially (exploding)
  • Shrink to near-zero (vanishing)
  • Shift distribution layer by layer

This makes training unstable: a learning rate that works for one layer fails for another.

Layer Norm formula

For each sample, normalize across the feature dimension:

$$\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

Where:

  • μ = mean across features
  • σ² = variance across features
  • γ, β = learnable scale and shift
  • ε = small constant for stability
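To make the formula concrete, here is a hand-worked example (a sketch; the 4-element vector and γ = 1, β = 0 are arbitrary choices):

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
mu = x.mean()                 # 2.5
var = x.var(unbiased=False)   # population variance: 1.25
x_norm = (x - mu) / torch.sqrt(var + 1e-5)
# with gamma = 1, beta = 0 the output is just x_norm
print(x_norm)  # roughly [-1.34, -0.45, 0.45, 1.34]
```

The output has mean 0 and (up to ε) unit variance; γ and β then let the network undo the normalization if that helps.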

Code

import torch

def layer_norm(x, gamma, beta, eps=1e-5):
    # statistics over the last (feature) dimension, per sample
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)  # biased variance, matching nn.LayerNorm
    x_norm = (x - mean) / torch.sqrt(var + eps)
    return gamma * x_norm + beta

Or just:

import torch.nn as nn

layer_norm = nn.LayerNorm(hidden_size)
output = layer_norm(x)
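A quick sanity check (a sketch; the shapes and scaling are made up): with freshly initialized γ = 1 and β = 0, every position ends up with mean ≈ 0 and variance ≈ 1 across the hidden dimension, no matter how badly scaled the input was.

```python
import torch
import torch.nn as nn

hidden_size = 8
layer_norm = nn.LayerNorm(hidden_size)   # gamma init 1, beta init 0
x = torch.randn(2, 5, hidden_size) * 10 + 3  # deliberately badly scaled
out = layer_norm(x)
print(out.mean(dim=-1))                  # ~0 everywhere
print(out.var(dim=-1, unbiased=False))   # ~1 everywhere
```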

Layer Norm vs Batch Norm

Batch Norm:

  • Normalize across batch dimension
  • Needs batch statistics
  • Different behavior train vs eval
  • Works great for CNNs

Layer Norm:

  • Normalize across feature dimension
  • Each sample independent
  • Same behavior train and eval
  • Works great for transformers

Batch Norm: normalize each feature across batch
           [sample1_feat1, sample2_feat1, ...] → normalize

Layer Norm: normalize each sample across features
           [sample1_feat1, sample1_feat2, ...] → normalize
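The axis difference can be shown directly on a 2D (batch, features) tensor (a sketch; the shapes are arbitrary):

```python
import torch

x = torch.randn(4, 3)  # (batch=4, features=3)

# Batch norm direction: statistics per feature, across the batch (dim=0)
bn_mean = x.mean(dim=0)   # shape (3,): one mean per feature
# Layer norm direction: statistics per sample, across features (dim=-1)
ln_mean = x.mean(dim=-1)  # shape (4,): one mean per sample
print(bn_mean.shape, ln_mean.shape)  # torch.Size([3]) torch.Size([4])
```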

Why transformers use Layer Norm

Batch norm fails with:

  • Variable sequence lengths (different pad amounts)
  • Small batches
  • Recurrent processing

Layer norm handles these because each sample is normalized independently.
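One concrete failure mode, as a sketch: in training mode, PyTorch's BatchNorm1d cannot compute statistics from a single-sample batch at all, while LayerNorm is unaffected.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 16)  # batch size 1

ln = nn.LayerNorm(16)
print(ln(x).shape)      # works fine: torch.Size([1, 16])

bn = nn.BatchNorm1d(16)
try:
    bn(x)  # training mode needs batch statistics, but batch has 1 sample
except ValueError as e:
    print("batch norm failed:", e)
```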

Pre-norm vs Post-norm

Post-norm (original transformer):

x = x + sublayer(x)
x = layer_norm(x)

Pre-norm (GPT-2 style):

x = x + sublayer(layer_norm(x))

Pre-norm trains more stably because the residual path stays unnormalized, giving gradients a direct route through the network. Most modern models use it.
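A minimal sketch of both orderings, using a generic sublayer (the linear layer is a stand-in for attention or the FFN; names and shapes are illustrative, not from any particular model):

```python
import torch
import torch.nn as nn

hidden = 16
norm = nn.LayerNorm(hidden)
sublayer = nn.Linear(hidden, hidden)  # stand-in for attention or FFN

def post_norm_block(x):
    return norm(x + sublayer(x))   # residual first, then normalize

def pre_norm_block(x):
    return x + sublayer(norm(x))   # normalize only the sublayer's input

x = torch.randn(2, hidden)
print(post_norm_block(x).shape, pre_norm_block(x).shape)
```

Note that in the pre-norm version, `x` itself flows to the output untouched by any norm, which is exactly the residual path that keeps gradients well-behaved.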

RMSNorm

Simplified version used in LLaMA:

$$\text{RMSNorm}(x) = \gamma \cdot \frac{x}{\sqrt{\frac{1}{n}\sum_i x_i^2 + \epsilon}}$$

Skip the mean subtraction. Just normalize by root mean square.

It is slightly faster and usually matches plain layer norm in quality.

import torch

def rms_norm(x, gamma, eps=1e-5):
    # normalize by root mean square; no mean subtraction
    rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
    return gamma * x / rms
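A self-contained check of that property (the function is repeated here for completeness; the input offset is an arbitrary choice): with γ = 1, the output's root mean square is ≈ 1, but the mean is left wherever it was, since RMSNorm never subtracts it.

```python
import torch

def rms_norm(x, gamma, eps=1e-5):
    rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
    return gamma * x / rms

torch.manual_seed(0)
x = torch.randn(3, 8) + 5.0  # nonzero mean on purpose
out = rms_norm(x, gamma=torch.ones(8))
print(out.pow(2).mean(dim=-1).sqrt())  # ~1: RMS is normalized
print(out.mean(dim=-1))                # clearly nonzero: mean untouched
```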

Where to put it

In transformer:

# Attention block
x = x + attention(layer_norm(x))

# FFN block  
x = x + ffn(layer_norm(x))

Some architectures also add a final layer norm after the last block.