Architecture ready. Now train it. Diffusion training is simple in principle but requires careful implementation for good results.

Part 5 of 7 in the Diffusion Models series.

The training objective

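Train the network $\epsilon_\theta$ to predict the noise $\epsilon$ that was added to a clean image $x_0$:

$$\mathcal{L} = \mathbb{E}_{t \sim \mathcal{U}\{1, \dots, T\},\; x_0 \sim q(x_0),\; \epsilon \sim \mathcal{N}(0, \mathbf{I})}\left[\|\epsilon - \epsilon_\theta(x_t, t)\|^2\right]$$

where $x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon$ is the noised image produced by the forward process.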
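The objective depends on $x_t$ from the closed-form forward process. For reference, here is a minimal sketch of `T`, `alpha_bar`, and `q_sample`. It assumes the linear $\beta$ schedule from the original DDPM paper ($\beta$ from $10^{-4}$ to $0.02$, $T = 1000$). Earlier parts of this series may define these differently.

```python
import torch

# Assumption: DDPM-style linear beta schedule with T = 1000 steps.
# Only alpha_bar (the cumulative product of alphas) is needed for q_sample.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bar = torch.cumprod(alphas, dim=0)  # alpha_bar[i] = prod of alphas[0..i]


def q_sample(x_0, t, noise):
    """Closed-form forward process: sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * noise."""
    # Move the schedule to the data's device before indexing with t,
    # then reshape so the per-sample value broadcasts over (C, H, W)
    ab = alpha_bar.to(x_0.device)[t].view(-1, *([1] * (x_0.dim() - 1)))
    return ab.sqrt() * x_0 + (1.0 - ab).sqrt() * noise
```

Indexing is 0-based: index 0 corresponds to $t = 1$ in the formula. This is why the training step draws timesteps with `torch.randint(0, T, ...)`.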

Training loop

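import torch
import torch.nn.functional as F

# Assumed from the forward-process code earlier in the series:
# T (number of diffusion steps) and q_sample(x_0, t, noise)

def train_step(model, optimizer, x_0):
    batch_size = x_0.shape[0]

    # Sample random timesteps; indices 0..T-1 stand for t = 1..T
    t = torch.randint(0, T, (batch_size,), device=x_0.device)

    # Sample Gaussian noise with the same shape as the images
    noise = torch.randn_like(x_0)

    # Create the noisy image x_t in closed form
    x_t = q_sample(x_0, t, noise)

    # Predict the noise from the noisy image and its timestep
    predicted_noise = model(x_t, t)

    # Plain (unweighted) MSE on the noise
    loss = F.mse_loss(predicted_noise, noise)

    # Update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()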

Full training script

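# UNet is the architecture from the previous part; device, num_epochs
# and dataloader are assumed to be defined already
model = UNet().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# Cosine decay over the whole run, stepped once per epoch
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
    running_loss = 0.0
    for batch in dataloader:
        # torchvision-style datasets yield (images, labels); keep only the images
        x_0 = batch[0] if isinstance(batch, (list, tuple)) else batch
        x_0 = x_0.to(device)
        running_loss += train_step(model, optimizer, x_0)

    scheduler.step()
    # Report the mean loss over the epoch, not just the last batch
    print(f"Epoch {epoch}, Loss: {running_loss / len(dataloader):.4f}")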

Exponential Moving Average (EMA)

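Keep an exponential moving average of the model weights during training, and sample from the averaged copy. The averaged weights change more smoothly than the raw ones and usually give better samples: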

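import torch

class EMA:
    def __init__(self, model, decay=0.9999):
        self.decay = decay
        # Detached copies, so the shadow weights never join the autograd graph
        self.shadow = {name: param.detach().clone()
                       for name, param in model.named_parameters()}
        self.backup = {}

    @torch.no_grad()
    def update(self, model):
        # shadow <- decay * shadow + (1 - decay) * param, updated in place
        for name, param in model.named_parameters():
            self.shadow[name].mul_(self.decay).add_(param, alpha=1 - self.decay)

    @torch.no_grad()
    def apply(self, model):
        # Swap in the EMA weights, keeping the raw weights so training can resume
        for name, param in model.named_parameters():
            self.backup[name] = param.detach().clone()
            param.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self, model):
        # Put the raw training weights back after sampling
        for name, param in model.named_parameters():
            param.copy_(self.backup[name])
        self.backup = {}

# Usage
ema = EMA(model)
for batch in dataloader:
    loss = train_step(model, optimizer, batch.to(device))
    ema.update(model)  # once per optimizer step

# For sampling, swap in the EMA weights, then restore the raw ones
ema.apply(model)
# ... generate samples with model here ...
ema.restore(model)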
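The class above stores shadow tensors by name. An alternative that is often more convenient for sampling is to keep a whole frozen copy of the network. With `decay = 0.9999`, the average effectively spans about 1 / (1 − 0.9999) = 10,000 steps, so early in training it is still dominated by the random initialization. The sketch below therefore ramps the decay up with `min(decay, (1 + n) / (10 + n))`, a rule borrowed from TensorFlow's `ExponentialMovingAverage` (its `num_updates` option). That warmup and the `EMAModel` name are assumptions of this sketch, not part of the recipe above.

```python
import copy

import torch
import torch.nn as nn


class EMAModel:
    """Hypothetical helper: a frozen copy of the model whose weights track an EMA."""

    def __init__(self, model: nn.Module, decay: float = 0.9999):
        self.decay = decay
        self.num_updates = 0
        self.ema_model = copy.deepcopy(model).eval()
        for p in self.ema_model.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: nn.Module) -> None:
        self.num_updates += 1
        n = self.num_updates
        # Small decay early on, approaching self.decay as training proceeds
        d = min(self.decay, (1 + n) / (10 + n))
        for ema_p, p in zip(self.ema_model.parameters(), model.parameters()):
            ema_p.mul_(d).add_(p.detach(), alpha=1 - d)
        # Buffers (e.g. BatchNorm running stats) are copied, not averaged
        for ema_b, b in zip(self.ema_model.buffers(), model.buffers()):
            ema_b.copy_(b)
```

Sampling then uses `ema.ema_model` directly, so the training model never has to be swapped out and restored.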

Mixed precision training

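Speed up training by running the forward pass in float16 with automatic mixed precision (AMP). A GradScaler scales the loss so that small float16 gradients don't underflow: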

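import torch
import torch.nn.functional as F

# torch.amp.GradScaler is the current API in recent PyTorch versions;
# older releases use torch.cuda.amp.GradScaler() instead
scaler = torch.amp.GradScaler("cuda")

def train_step_amp(model, optimizer, x_0):
    t = torch.randint(0, T, (x_0.shape[0],), device=x_0.device)
    noise = torch.randn_like(x_0)
    x_t = q_sample(x_0, t, noise)

    # Forward pass in float16 where safe; autocast runs mse_loss in float32
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        predicted_noise = model(x_t, t)
        loss = F.mse_loss(predicted_noise, noise)

    optimizer.zero_grad()
    # Backprop the scaled loss so small float16 gradients don't underflow
    scaler.scale(loss).backward()
    # Unscales the gradients and skips the step if they contain inf/NaN
    scaler.step(optimizer)
    # Adjust the scale factor for the next iteration
    scaler.update()

    return loss.item()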

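On GPUs with tensor cores this is often up to about 2x faster, with sample quality typically on par with full-precision training.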

Loss weighting

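Not all timesteps pose the same task. The plain noise MSE gives every $t$ equal weight, even though nearly clean inputs (high signal-to-noise ratio) and nearly pure noise (low SNR) are very different problems. One option is to reweight each sample by a function of $\mathrm{SNR}(t) = \bar\alpha_t / (1 - \bar\alpha_t)$:

def weighted_loss(predicted, target, t):
    # alpha_bar: cumulative product of alphas from the noise schedule (a tensor)
    ab = alpha_bar.to(t.device)[t]
    snr = ab / (1 - ab)
    # Simple: weight by 1/SNR. For epsilon-prediction this equals an MSE on the
    # implied x_0 prediction, since the x_0 error is the eps error / sqrt(SNR)
    weight = (1 / snr).view(-1, 1, 1, 1)  # broadcast over (C, H, W)
    return (weight * (predicted - target) ** 2).mean()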

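Or use Min-SNR-γ weighting (Hang et al., 2023). In terms of the $x_0$ loss, it weights each timestep by $\min(\mathrm{SNR}(t), \gamma)$, which for ε-prediction becomes $\min(\mathrm{SNR}(t), \gamma) / \mathrm{SNR}(t)$. This keeps the many easy, high-SNR timesteps from dominating training.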
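Here is a minimal sketch of Min-SNR-γ for ε-prediction. It assumes `alpha_bar` is the cumulative-product tensor from the noise schedule. The function names and the default γ = 5 are choices made for this sketch.

```python
import torch
import torch.nn.functional as F


def min_snr_weights(alpha_bar, t, gamma=5.0):
    # SNR(t) = alpha_bar_t / (1 - alpha_bar_t)
    ab = alpha_bar.to(t.device)[t]
    snr = ab / (1.0 - ab)
    # Min-SNR-gamma for epsilon-prediction: min(SNR, gamma) / SNR
    return torch.clamp(snr, max=gamma) / snr


def min_snr_loss(predicted_noise, noise, alpha_bar, t, gamma=5.0):
    # Per-sample MSE, averaged over all non-batch dimensions
    per_sample = F.mse_loss(predicted_noise, noise, reduction="none")
    per_sample = per_sample.mean(dim=tuple(range(1, noise.dim())))
    return (min_snr_weights(alpha_bar, t, gamma) * per_sample).mean()
```

Low-SNR (very noisy) timesteps keep weight 1. High-SNR timesteps, where predicting the noise is easy but the gradients are large, are scaled down.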

Training tips

Learning rate: 1e-4 to 3e-4 works well

Batch size: Larger is generally better, because gradients have lower variance. If the batch you want doesn't fit in memory, use gradient accumulation.

Warmup: A linear learning-rate warmup over the first 1,000 to 5,000 steps helps stabilize early training.
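One way to combine warmup with a cosine decay is a single `LambdaLR`, stepped once per optimizer step (not once per epoch). This is a sketch, and `warmup_cosine` is a hypothetical helper name:

```python
import math

import torch


def warmup_cosine(optimizer, warmup_steps, total_steps):
    """Linear warmup to the base lr, then cosine decay to 0 at total_steps."""

    def lr_lambda(step):
        if step < warmup_steps:
            # +1 so the very first step uses a small but nonzero lr
            return (step + 1) / warmup_steps
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```

Call `scheduler.step()` right after each `optimizer.step()`. The lambda returns a multiplier on the optimizer's base learning rate.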

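Gradient clipping: Clip the global gradient norm to 1.0 for stability. Call this after loss.backward() and before optimizer.step(). With AMP, call scaler.unscale_(optimizer) first.

torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)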
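Gradient accumulation and clipping fit together like this. Accumulate gradients over several micro-batches, dividing each loss by the number of micro-batches, then clip once and step. This is a sketch with a hypothetical `accumulate_and_step` helper. A plain regression loss stands in for the diffusion loss, and the micro-batches are assumed to be equal-sized so the update matches one large batch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def accumulate_and_step(model: nn.Module, optimizer, micro_batches, max_norm=1.0):
    """One optimizer step over a list of (inputs, targets) micro-batches."""
    optimizer.zero_grad(set_to_none=True)
    n = len(micro_batches)
    total = 0.0
    for inputs, targets in micro_batches:
        # Divide by n so the summed gradients equal the large-batch mean
        loss = F.mse_loss(model(inputs), targets) / n
        loss.backward()  # gradients accumulate in .grad across iterations
        total += loss.item()
    # Clip once, on the accumulated gradient, right before the update
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    return total  # mean loss over the effective batch
```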

Data augmentation: Random horizontal flip, slight color jitter

Monitoring training

Track:

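  • Loss (should trend downward; it is noisy and often flattens early, so don't judge progress by loss alone)
  • Sample quality (generate a grid periodically)
  • FID score (if you have the compute)

if step % 1000 == 0:
    model.eval()
    with torch.no_grad():
        # sample() and save_grid() are helpers assumed to exist elsewhere
        samples = sample(model, n=16)
        save_grid(samples, f"samples_{step}.png")
    model.train()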
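Because `t` is resampled every step, the aggregate loss mixes easy and hard timesteps, which makes the curve noisy. Logging the loss per timestep range makes trends easier to read. Below is a sketch of a hypothetical `TimestepLossTracker` that buckets per-sample losses by their 0-indexed `t`:

```python
import torch


class TimestepLossTracker:
    """Accumulates per-sample losses into equal-width timestep buckets."""

    def __init__(self, T: int, n_buckets: int = 10):
        self.T = T
        self.n_buckets = n_buckets
        self.sums = torch.zeros(n_buckets)
        self.counts = torch.zeros(n_buckets)

    def update(self, t: torch.Tensor, per_sample_loss: torch.Tensor) -> None:
        # Map timesteps in [0, T) to bucket indices in [0, n_buckets)
        idx = (t.detach().cpu() * self.n_buckets // self.T).long()
        loss = per_sample_loss.detach().float().cpu()
        self.sums.index_add_(0, idx, loss)
        self.counts.index_add_(0, idx, torch.ones_like(loss))

    def means(self) -> torch.Tensor:
        # 0 / 0 gives NaN, which marks buckets that received no samples
        return self.sums / self.counts

    def reset(self) -> None:
        self.sums.zero_()
        self.counts.zero_()
```

To use it, the training step has to expose `t` and a per-sample loss, for example `F.mse_loss(predicted_noise, noise, reduction="none").mean(dim=(1, 2, 3))`. Print `means()` alongside the sample grids and call `reset()` afterwards.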

How long to train?

  • CIFAR-10: ~500k steps
  • ImageNet 256×256: ~1M steps
  • High-quality models: 2M+ steps
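Step counts depend on the batch size, so converting them to epochs makes recipes easier to compare. In this sketch (hypothetical helpers), assume a batch size of 128, since the post does not specify one. Then ~500k steps on CIFAR-10's 50,000 training images is 500,000 × 128 / 50,000 = 1,280 epochs.

```python
def steps_to_epochs(steps: int, dataset_size: int, batch_size: int,
                    accum_steps: int = 1) -> float:
    # Each optimizer step sees batch_size * accum_steps examples
    examples_seen = steps * batch_size * accum_steps
    return examples_seen / dataset_size


def epochs_to_steps(epochs: float, dataset_size: int, batch_size: int,
                    accum_steps: int = 1) -> int:
    # Inverse conversion, rounded to the nearest whole optimizer step
    return round(epochs * dataset_size / (batch_size * accum_steps))


print(steps_to_epochs(500_000, 50_000, 128))  # 1280.0
```

With gradient accumulation, pass `accum_steps` so that the effective batch (micro-batch × accumulation steps) is what gets counted.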

Start small (a small model on a small dataset) to validate the pipeline, then scale up.