Architecture ready. Now train it. Diffusion training is simple in principle but requires careful implementation for good results.
Part 5 of 7 in the Diffusion Models series.
The training objective
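Train the network $\epsilon_\theta$ to predict the noise $\epsilon$ that was added:
$$\mathcal{L} = \mathbb{E}_{t \sim \mathcal{U}\{1,\dots,T\},\; x_0 \sim \text{data},\; \epsilon \sim \mathcal{N}(0, \mathbf{I})}\left[\|\epsilon - \epsilon_\theta(x_t, t)\|^2\right]$$
where $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon$ is the noisy sample produced by the forward process at step $t$.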
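For reference, here is a minimal sketch of the forward process the objective relies on. It assumes a linear β schedule from 1e-4 to 0.02 over T = 1000 steps (the DDPM defaults; the schedule used earlier in this series may differ). Indices are 0-based, so index `t` corresponds to step t + 1 in the formula.

```python
import torch

# Assumed schedule (hypothetical here): T = 1000, betas linear from 1e-4 to 0.02
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bar = torch.cumprod(alphas, dim=0)  # alpha_bar[t] = product of alphas[0..t]


def q_sample(x_0, t, noise):
    """Jump straight to step t: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise."""
    # Gather one coefficient per sample and reshape to (B, 1, 1, ...) so it broadcasts over C, H, W
    a_bar = alpha_bar.to(x_0.device)[t].view(-1, *([1] * (x_0.dim() - 1)))
    return a_bar.sqrt() * x_0 + (1.0 - a_bar).sqrt() * noise


x_0 = torch.randn(4, 3, 8, 8)
noise = torch.randn_like(x_0)
t = torch.tensor([0, 10, 500, 999])
x_t = q_sample(x_0, t, noise)
print(x_t.shape, float(alpha_bar[-1]))
```

Because $\bar\alpha_t$ shrinks toward zero, `x_t` at the last index is almost pure noise, which is what lets sampling start from $\mathcal{N}(0, \mathbf{I})$.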
Training loop
```python
import torch
import torch.nn.functional as F

# T (number of diffusion steps) and q_sample (forward noising) come from the forward-process part
def train_step(model, optimizer, x_0):
    batch_size = x_0.shape[0]
    # Sample random timesteps (0-based indices 0..T-1)
    t = torch.randint(0, T, (batch_size,), device=x_0.device)
    # Sample noise
    noise = torch.randn_like(x_0)
    # Create noisy image
    x_t = q_sample(x_0, t, noise)
    # Predict noise
    predicted_noise = model(x_t, t)
    # Compute loss
    loss = F.mse_loss(predicted_noise, noise)
    # Update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```
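A handy reference point when reading the loss that `train_step` returns: the target is standard Gaussian noise, so a network that always outputs zeros scores an MSE of about 1.0. A loss that stays near 1.0 suggests the model has not learned anything yet. A quick check (the tensor shape is an arbitrary assumption):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
noise = torch.randn(64, 3, 32, 32)  # same distribution as the noise in train_step

# Always predicting zero gives MSE = mean(noise ** 2), which is about Var(noise) = 1
baseline_loss = F.mse_loss(torch.zeros_like(noise), noise).item()
print(f"zero-prediction baseline: {baseline_loss:.3f}")
```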
Full training script
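```python
import torch

# Assumes UNet (built earlier in the series), train_step above,
# and device, num_epochs, dataloader defined for your dataset
model = UNet().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# T_max=num_epochs because the cosine schedule is stepped once per epoch below
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
    for batch in dataloader:
        # If the dataloader yields (image, label) pairs, use batch[0] instead
        x_0 = batch.to(device)
        loss = train_step(model, optimizer, x_0)
    scheduler.step()
    print(f"Epoch {epoch}, Loss: {loss:.4f}")  # loss of the last batch in the epoch
```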
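Before a long run, it can be worth smoke-testing the loop on random data. The sketch below is self-contained: it re-creates an assumed linear schedule and `q_sample`, replaces the UNet with a hypothetical `TinyEpsNet`, and feeds a `TensorDataset` (which yields tuples, hence `batch[0]`). It only checks that shapes, devices and the scheduler line up; it says nothing about sample quality.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assumed linear schedule and forward process (stand-ins for the series' versions)
T = 1000
alpha_bar = torch.cumprod(1.0 - torch.linspace(1e-4, 0.02, T), dim=0).to(device)


def q_sample(x_0, t, noise):
    a = alpha_bar[t].view(-1, 1, 1, 1)
    return a.sqrt() * x_0 + (1 - a).sqrt() * noise


class TinyEpsNet(nn.Module):
    """Hypothetical stand-in for UNet: two convs plus a learned timestep embedding."""

    def __init__(self, channels=3, hidden=32):
        super().__init__()
        self.t_embed = nn.Embedding(T, hidden)
        self.conv_in = nn.Conv2d(channels, hidden, 3, padding=1)
        self.conv_out = nn.Conv2d(hidden, channels, 3, padding=1)

    def forward(self, x, t):
        # Add the timestep embedding to every spatial position of each channel
        h = self.conv_in(x) + self.t_embed(t)[:, :, None, None]
        return self.conv_out(F.silu(h))


def train_step(model, optimizer, x_0):
    t = torch.randint(0, T, (x_0.shape[0],), device=x_0.device)
    noise = torch.randn_like(x_0)
    loss = F.mse_loss(model(q_sample(x_0, t, noise), t), noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# 256 random "images" scaled to [-1, 1]
dataloader = DataLoader(TensorDataset(torch.rand(256, 3, 16, 16) * 2 - 1), batch_size=32, shuffle=True)

num_epochs = 3
model = TinyEpsNet().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

losses = []
for epoch in range(num_epochs):
    for batch in dataloader:
        losses.append(train_step(model, optimizer, batch[0].to(device)))
    scheduler.step()
    print(f"Epoch {epoch}, Loss: {losses[-1]:.4f}")
```

Once this runs cleanly, swapping in the real UNet and dataloader should only change the model and data lines.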
Exponential Moving Average (EMA)
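Keep an exponential moving average of the model weights during training and sample from the averaged weights. They change more smoothly than the raw weights and usually give better samples: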
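```python
import torch

class EMA:
    def __init__(self, model, decay=0.9999):
        self.decay = decay
        # Detached copies: the shadow weights must not be part of the autograd graph
        self.shadow = {name: param.detach().clone()
                       for name, param in model.named_parameters()}
        self.backup = {}

    @torch.no_grad()
    def update(self, model):
        for name, param in model.named_parameters():
            # shadow = decay * shadow + (1 - decay) * param, updated in place
            self.shadow[name].mul_(self.decay).add_(param, alpha=1 - self.decay)

    @torch.no_grad()
    def apply(self, model):
        # Load the EMA weights, keeping a copy of the training weights
        for name, param in model.named_parameters():
            self.backup[name] = param.detach().clone()
            param.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self, model):
        # Put the training weights back before training continues
        for name, param in model.named_parameters():
            param.copy_(self.backup[name])
        self.backup = {}

# Usage
ema = EMA(model)
for batch in dataloader:
    loss = train_step(model, optimizer, batch.to(device))
    ema.update(model)

# For sampling, use EMA weights
ema.apply(model)
# ... generate samples ...
ema.restore(model)
```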
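With `decay=0.9999` the average is dominated by roughly the last 1 / (1 - 0.9999) = 10,000 updates, so early in training the EMA weights still mostly reflect the random initialization. One common remedy, borrowed here as an assumption from TensorFlow's `ExponentialMovingAverage` (its `num_updates` option), ramps the decay up over the first updates. `ema_decay` and `effective_window` are hypothetical helpers:

```python
def ema_decay(step, max_decay=0.9999):
    """EMA decay for a given update count: starts at 0.1 and ramps up to max_decay."""
    # Same ramp TensorFlow's ExponentialMovingAverage applies when num_updates is given
    return min(max_decay, (1 + step) / (10 + step))


def effective_window(decay):
    """Approximate number of recent updates the average is dominated by."""
    return 1.0 / (1.0 - decay)


print(ema_decay(0), ema_decay(90), ema_decay(1_000_000))
print(round(effective_window(0.9999)))
```

To use it with the EMA class above, set `ema.decay = ema_decay(step)` before each `ema.update(model)`, where `step` is a counter of optimizer updates.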
Mixed precision training
Speed up with fp16:
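```python
import torch
import torch.nn.functional as F
# Newer PyTorch versions deprecate these in favor of torch.amp.autocast("cuda") / torch.amp.GradScaler("cuda")
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

def train_step_amp(model, optimizer, x_0):
    t = torch.randint(0, T, (x_0.shape[0],), device=x_0.device)
    noise = torch.randn_like(x_0)
    x_t = q_sample(x_0, t, noise)
    # Forward pass and loss under autocast: fp16 where safe, fp32 for sensitive ops
    with autocast():
        predicted_noise = model(x_t, t)
        loss = F.mse_loss(predicted_noise, noise)
    optimizer.zero_grad()
    # Scale the loss so small fp16 gradients don't underflow; scaler.step unscales before updating
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```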
Often around 2x faster on GPUs with fp16 tensor cores, with similar sample quality.
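Combining the fp16 step with gradient clipping (see Training tips) needs one extra call: the gradients are still multiplied by the loss scale after `backward()`, so they must be unscaled with `scaler.unscale_(optimizer)` before clipping. A sketch of that pattern from PyTorch's AMP documentation, written (as an assumption, for portability) to fall back to plain fp32 when no GPU is present. `train_step_amp_clipped` is a hypothetical name, and it takes `x_t`, `t` and `noise` already prepared as in `train_step`:

```python
import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"  # in this sketch, fp16 autocast is only enabled on GPU
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # becomes a pass-through when disabled


def train_step_amp_clipped(model, optimizer, x_t, t, noise, max_norm=1.0):
    """Mixed-precision step with gradient clipping."""
    with torch.autocast(device_type=device.type, enabled=use_amp):
        loss = F.mse_loss(model(x_t, t), noise)
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    # Unscale the gradients in place so clip_grad_norm_ sees their true norm
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    scaler.step(optimizer)  # skips the update if the gradients contain inf/NaN
    scaler.update()
    return loss.item()
```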
Loss weighting
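Timesteps contribute unequally to the loss. One option is to weight each sample by a function of its signal-to-noise ratio, $\text{SNR}(t) = \bar\alpha_t / (1 - \bar\alpha_t)$: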
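```python
def weighted_loss(predicted, target, t):
    # SNR(t) = alpha_bar_t / (1 - alpha_bar_t), shaped (B, 1, 1, 1) to broadcast over C, H, W
    a_bar = alpha_bar.to(predicted.device)[t].view(-1, 1, 1, 1)
    snr = a_bar / (1 - a_bar)
    # Simple: weight by 1/SNR (equals an MSE on the implied x_0; very large at the noisiest steps)
    weight = 1 / snr
    return (weight * (predicted - target) ** 2).mean()
```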
Or use min-SNR-γ weighting (Hang et al., 2023), which clamps the SNR at a constant γ when computing the weight.
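For an ε-prediction loss, min-SNR-γ weighting (Hang et al., 2023) gives each sample the weight $\min(\text{SNR}(t), \gamma) / \text{SNR}(t)$, with γ = 5 as the value the paper suggests. A sketch, where `alpha_bar_t` is assumed to be `alpha_bar[t]` already gathered per sample (shape `(B,)`); the function names are hypothetical:

```python
import torch


def min_snr_weights(alpha_bar_t, gamma=5.0):
    """Per-sample min-SNR-gamma weights for an epsilon-prediction loss."""
    snr = alpha_bar_t / (1.0 - alpha_bar_t)  # SNR(t) = alpha_bar / (1 - alpha_bar)
    return torch.clamp(snr, max=gamma) / snr  # min(SNR, gamma) / SNR


def min_snr_loss(predicted_noise, noise, alpha_bar_t, gamma=5.0):
    # Per-sample MSE over all non-batch dimensions, then weight and average over the batch
    per_sample = ((predicted_noise - noise) ** 2).flatten(1).mean(dim=1)
    return (min_snr_weights(alpha_bar_t, gamma) * per_sample).mean()


print(min_snr_weights(torch.tensor([0.5, 0.9])))
```

High-noise samples (SNR below γ) keep weight 1, while low-noise samples are down-weighted to γ / SNR.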
Training tips
- Learning rate: 1e-4 to 3e-4 works well.
- Batch size: larger is better. Use gradient accumulation if a large batch doesn't fit in memory.
- Warmup: 1,000 to 5,000 steps of linear warmup helps.
- Gradient clipping: clip the gradient norm to 1.0 for stability with `torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)`, called after `loss.backward()` and before `optimizer.step()`.
- Data augmentation: random horizontal flips, slight color jitter.
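Warmup, gradient accumulation and clipping fit together in one update step. A minimal sketch, assuming the schedule is stepped once per optimizer update (not per epoch) via `LambdaLR`, equal-sized micro-batches, and a hypothetical `loss_fn(model, x_0)` that returns the diffusion loss for a batch without stepping the optimizer:

```python
import math
import torch


def make_warmup_cosine(optimizer, warmup_steps, total_steps):
    """Linear warmup over warmup_steps, then cosine decay to 0 by total_steps."""
    def lr_lambda(step):
        if step < warmup_steps:
            return (step + 1) / warmup_steps
        progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def accumulation_step(model, optimizer, scheduler, micro_batches, loss_fn, max_norm=1.0):
    """One optimizer update from several micro-batches (gradient accumulation)."""
    optimizer.zero_grad()
    total = 0.0
    for x_0 in micro_batches:
        # Divide by the number of micro-batches so the summed gradient matches
        # the gradient of the mean loss over one large batch
        loss = loss_fn(model, x_0) / len(micro_batches)
        loss.backward()
        total += loss.item()
    # Clip the accumulated gradient, then update the weights and the learning rate
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    scheduler.step()
    return total  # mean loss over the micro-batches
```

The effective batch size is the micro-batch size times the number of micro-batches, so four micro-batches of 32 behave like one batch of 128 as far as the gradient is concerned.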
Monitoring training
Track:
- Loss (should decrease smoothly)
- Sample quality (generate periodically)
- FID score (if you have compute)
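```python
# Inside the training loop; `step` counts optimizer updates.
# `sample` and `save_grid` are helper functions assumed to exist elsewhere.
if step % 1000 == 0:
    model.eval()  # eval mode, e.g. so dropout is off while sampling
    with torch.no_grad():
        samples = sample(model, n=16)
    save_grid(samples, f"samples_{step}.png")
    model.train()  # back to training mode
```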
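Because `t` is drawn uniformly, each batch's loss averages over very different noise levels, which makes the curve noisy and can hide which part of the timestep range is still improving. One optional addition (an assumption, not part of the recipe above) is to log the loss separately for each quarter of the range. `loss_by_quartile` is a hypothetical helper:

```python
import torch
import torch.nn.functional as F


def loss_by_quartile(predicted_noise, noise, t, T):
    """Mean noise-prediction MSE for each quarter of the timestep range [0, T)."""
    # Per-sample MSE over all non-batch dimensions
    per_sample = F.mse_loss(predicted_noise, noise, reduction="none").flatten(1).mean(dim=1)
    quartile = (4 * t) // T  # 0 for the least-noisy quarter ... 3 for the noisiest
    stats = {}
    for q in range(4):
        mask = quartile == q
        if mask.any():
            stats[f"loss_q{q}"] = per_sample[mask].mean().item()
    return stats
```

Quartiles absent from a batch are simply left out, so the logger should average each key over however many batches reported it.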
How long to train?
- CIFAR-10: ~500k steps
- ImageNet 256×256: ~1M steps
- High-quality: 2M+ steps
Start small, scale up.
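Step counts depend on the batch size, so converting them into epochs helps when planning compute. A quick calculation, assuming a batch size of 128 (a hypothetical choice; the step counts above don't state one) and CIFAR-10's 50,000 training images:

```python
def steps_to_epochs(steps, batch_size, dataset_size):
    """Number of passes over the dataset that a step budget amounts to."""
    return steps * batch_size / dataset_size


print(steps_to_epochs(500_000, 128, 50_000))
```

At that batch size, ~500k steps is 1,280 passes over CIFAR-10; doubling the batch size doubles the epoch count for the same number of steps.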