How does a neural network learn? It adjusts its weights to make better predictions. But how does it know which direction to adjust them? That’s where gradient descent comes in: it’s the core algorithm behind training almost every deep learning model.
The basic idea
Imagine you’re blindfolded on a hilly landscape and need to find the lowest point. You can feel the slope under your feet. Gradient descent says: always step downhill.
In math terms:
- You have a loss function L(w) measuring how wrong your predictions are
- The gradient ∇L(w) points in the direction in which the loss goes UP fastest
- So you step in the opposite direction to make it go DOWN
$$w_{\text{new}} = w_{\text{old}} - \alpha \nabla L(w_{\text{old}})$$
Where α (alpha) is the learning rate, which sets how big your steps are.
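To make the update rule concrete, here is a minimal sketch that minimizes a hypothetical one-dimensional loss L(w) = (w - 3)², chosen only for illustration. Its gradient is 2(w - 3) and its minimum sits at w = 3; `loss`, `grad`, and `gradient_descent` are illustrative helpers, not part of any library.

```python
def loss(w):
    # Hypothetical toy loss: a parabola with its minimum at w = 3
    return (w - 3.0) ** 2

def grad(w):
    # Derivative of the toy loss: dL/dw = 2(w - 3)
    return 2.0 * (w - 3.0)

def gradient_descent(w, alpha=0.1, steps=100):
    for _ in range(steps):
        w = w - alpha * grad(w)  # w_new = w_old - alpha * gradient at w_old
    return w

w_final = gradient_descent(w=0.0)
print(f"w = {w_final:.6f}, loss = {loss(w_final):.2e}")
```

Each step here shrinks the distance to the minimum by a factor of 1 - 2α = 0.8, so after 100 steps the error is about 3 · 0.8¹⁰⁰, well under 10⁻⁹.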
Why this actually works
Think of the loss as a landscape:
- Mountains = bad predictions (high loss)
- Valleys = good predictions (low loss)
The gradient is like feeling the slope: it tells you which direction is steepest uphill. Go the opposite way, and you head downhill toward better predictions.
Keep stepping downhill until you reach a valley. That’s your trained model!
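Why does stepping against the gradient lower the loss? A first-order Taylor expansion around the current weights w, writing g = ∇L(w), makes the intuition precise. It only holds for a small enough step α, where the linear approximation is accurate:

$$L(w - \alpha g) \approx L(w) - \alpha\, g^\top g = L(w) - \alpha \lVert g \rVert^2$$

Since ‖g‖² > 0 whenever the gradient is nonzero, a small enough step always lowers the loss. Progress stops only where ∇L(w) = 0: at the bottom of a valley, or at another flat point such as a saddle.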
Learning rate: the most important hyperparameter
The learning rate controls your step size, and getting it right is crucial:
| Learning Rate | What Happens |
|---|---|
| Too small | Very slow progress; can stall on flat regions before reaching a good minimum |
| Too large | Overshoots the valley, bounces around, might explode |
| Just right | Smooth convergence to a good minimum |
```python
# Typical learning rates to try
candidate_lrs = [
    1e-3,  # 0.001: common starting point
    1e-4,  # 0.0001: smaller, more stable
    3e-4,  # 0.0003: often works well with the Adam optimizer
]
```
Pro tip: if your loss goes to NaN or infinity, the first suspect is a learning rate that is too high; try reducing it.
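The three regimes in the table are easy to reproduce on a toy problem. This sketch uses a hypothetical loss L(w) = w² (gradient 2w), where each update multiplies w by 1 - 2α. It runs 1100 steps per learning rate, which is enough for the too-large setting to overflow: the loss becomes inf, and then inf - inf turns the weight into NaN. `run` is an illustrative helper.

```python
import math

def run(alpha, steps=1100, w=1.0):
    """Gradient descent on the toy loss L(w) = w**2 (gradient 2w); returns the final loss."""
    for _ in range(steps):
        w = w - alpha * 2.0 * w
    return w * w

for label, alpha in [("too small", 0.001), ("just right", 0.1), ("too large", 1.5)]:
    final = run(alpha)
    status = "finite" if math.isfinite(final) else "diverged (inf/NaN)"
    print(f"{label:>10}: alpha={alpha:<6} final loss={final:.3g}  [{status}]")
```

With α = 0.001 the weight has only shrunk to about 0.11 after all those steps. With α = 0.1 the loss is vanishingly small. With α = 1.5 the weight doubles in size every step until it overflows.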
Three flavors of gradient descent
Batch (full) gradient descent: Compute the gradient using the entire dataset. The direction is exact for the training loss, but every single step needs a full pass over the data, so it’s slow.
```python
for epoch in range(epochs):
    gradient = compute_gradient(all_data)  # expensive: a full pass over the data per step
    weights -= lr * gradient
```
Stochastic gradient descent (SGD): Compute the gradient on a single sample. Fast but noisy (it zigzags a lot).
```python
for sample in dataset:
    gradient = compute_gradient(sample)  # cheap!
    weights -= lr * gradient
```
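Mini-batch gradient descent: Compute the gradient on a small batch of samples. This is the best of both worlds: each step is cheap, the gradient is much less noisy than with a single sample, and batches map well onto GPU hardware.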
```python
for batch in dataloader:  # batch_size = 32, 64, 128...
    gradient = compute_gradient(batch)
    weights -= lr * gradient
```
This is what nearly everyone uses in practice. When deep learning papers and libraries say “SGD”, they usually mean mini-batch SGD.
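All three flavors are the same loop with a different batch size. The pure-Python sketch below fits a line y ≈ w·x + b to synthetic data. The data and the helpers `make_data`, `gradient`, and `train` are hypothetical and chosen for illustration. Setting `batch_size=len(data)` gives batch gradient descent, `batch_size=1` gives SGD, and anything in between is mini-batch.

```python
import random

def make_data(n=256, seed=0):
    # Hypothetical synthetic data: y = 2x + 1 plus a little Gaussian noise
    rng = random.Random(seed)
    xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    return [(x, 2.0 * x + 1.0 + rng.gauss(0.0, 0.1)) for x in xs]

def gradient(batch, w, b):
    # Gradient of the mean squared error of the prediction w*x + b over one batch
    gw = gb = 0.0
    for x, y in batch:
        err = w * x + b - y
        gw += 2.0 * err * x
        gb += 2.0 * err
    return gw / len(batch), gb / len(batch)

def train(data, batch_size, lr=0.1, epochs=200, seed=0):
    # batch_size=len(data): batch GD; batch_size=1: SGD; in between: mini-batch
    rng = random.Random(seed)
    data = list(data)      # copy so shuffling doesn't touch the caller's list
    w = b = 0.0
    for _ in range(epochs):
        rng.shuffle(data)  # reshuffle every epoch, as a typical data loader does
        for start in range(0, len(data), batch_size):
            gw, gb = gradient(data[start:start + batch_size], w, b)
            w -= lr * gw
            b -= lr * gb
    return w, b

data = make_data()
for name, batch_size in [("batch", len(data)), ("SGD", 1), ("mini-batch", 32)]:
    w, b = train(data, batch_size)
    print(f"{name:>10}: w = {w:.3f}, b = {b:.3f}")
```

All three runs should land near w = 2, b = 1. They differ in how much work goes into each update. With 256 samples, one epoch is a single update for batch GD, 256 updates for SGD, and 8 updates for mini-batches of 32.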
Momentum
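Plain SGD is noisy and tends to oscillate, especially in valleys that are steep in one direction and shallow in another. Momentum smooths this out.
Keep an exponentially decaying sum of past gradients (the “velocity” v) and step along it:
$$v_t = \beta v_{t-1} + \nabla L(w_{t-1})$$
$$w_t = w_{t-1} - \alpha v_t$$
where β (typically around 0.9) controls how much of the previous velocity is kept. It behaves like a ball rolling downhill: it builds up speed in directions where the gradient is consistent, while components that keep flipping sign largely cancel out.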
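```python
beta = 0.9  # momentum coefficient (a common default)
v = 0       # velocity: decaying sum of past gradients
for batch in dataloader:
    gradient = compute_gradient(batch)
    v = beta * v + gradient  # v_t = beta * v_{t-1} + gradient
    weights -= lr * v        # w_t = w_{t-1} - lr * v_t
```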
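To see why momentum helps, this sketch compares plain gradient descent with the momentum update above on a hypothetical elongated bowl, L(w) = ½(w[0]² + 100·w[1]²). For plain GD the learning rate must stay below 2/100 = 0.02, or the steep direction diverges, and that cap leaves the shallow direction crawling. Momentum builds up speed along the shallow direction. The bowl, the tolerance, and `steps_to_converge` are assumptions for illustration.

```python
# Hypothetical elongated bowl: L(w) = 0.5 * (1 * w[0]**2 + 100 * w[1]**2)
CURVATURE = (1.0, 100.0)

def loss(w):
    return 0.5 * sum(h * wi * wi for h, wi in zip(CURVATURE, w))

def grad(w):
    return [h * wi for h, wi in zip(CURVATURE, w)]

def steps_to_converge(lr, beta, tol=1e-6, max_steps=10_000):
    """Gradient descent with momentum; beta=0 reduces to plain gradient descent."""
    w = [1.0, 1.0]
    v = [0.0, 0.0]
    for step in range(1, max_steps + 1):
        g = grad(w)
        v = [beta * vi + gi for vi, gi in zip(v, g)]  # v_t = beta * v_{t-1} + grad
        w = [wi - lr * vi for wi, vi in zip(w, v)]    # w_t = w_{t-1} - lr * v_t
        if loss(w) < tol:
            return step
    return max_steps

plain = steps_to_converge(lr=0.015, beta=0.0)
momentum = steps_to_converge(lr=0.015, beta=0.9)
print(f"plain GD: {plain} steps, with momentum: {momentum} steps")
```

With the same learning rate, the momentum run should need fewer than half as many steps as plain GD. The exact counts depend on the tolerance and the starting point.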
Adam optimizer
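Adam adapts the learning rate for each parameter individually and is one of the most popular optimizers in deep learning.
It combines momentum (a running average of the gradients) with adaptive scaling (dividing each step by the square root of a running average of the squared gradients):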
- Parameters with large gradients: smaller effective learning rate
- Parameters with small gradients: larger effective learning rate
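A minimal pure-Python sketch of the standard Adam update makes the adaptive scaling visible. It includes the bias-correction terms and uses the usual defaults β1 = 0.9, β2 = 0.999, ε = 1e-8. The toy problem has two parameters whose gradients differ by a factor of 100. This is an illustration, not PyTorch's implementation; `adam_step` and the toy gradient are hypothetical helpers.

```python
import math

def adam_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a list of parameters; t is the 1-based step count."""
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]       # running mean of gradients
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, g)]  # running mean of squared gradients
    m_hat = [mi / (1 - beta1 ** t) for mi in m]                       # bias correction: m and v start at 0
    v_hat = [vi / (1 - beta2 ** t) for vi in v]
    w = [wi - lr * mh / (math.sqrt(vh) + eps) for wi, mh, vh in zip(w, m_hat, v_hat)]
    return w, m, v

# Hypothetical toy gradient: the second parameter's gradient is 100x larger
CURVATURE = (1.0, 100.0)

def grad(w):
    return [h * wi for h, wi in zip(CURVATURE, w)]

w, m, v = [1.0, 1.0], [0.0, 0.0], [0.0, 0.0]
w, m, v = adam_step(w, grad(w), m, v, t=1, lr=0.01)
print(w)  # roughly [0.99, 0.99]: both moved by about lr despite the 100x gradient gap
```

On the first step the bias-corrected ratio m̂/√v̂ is just g/|g|, so every parameter moves by about lr regardless of its gradient's size. Plain gradient descent would instead move the second parameter 100 times further than the first.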
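```python
import torch

# Assumes `model`, `dataloader`, and a `compute_loss` helper are defined elsewhere
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for batch in dataloader:
    optimizer.zero_grad()       # clear gradients left over from the previous step
    loss = compute_loss(batch)  # forward pass
    loss.backward()             # backpropagation fills in each parameter's .grad
    optimizer.step()            # apply the Adam update
```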
Adam usually “just works” with little tuning, but SGD with momentum sometimes generalizes better.
Learning rate scheduling
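It often helps to decrease the learning rate during training. Large steps make fast progress early on, while smaller steps later let the optimizer settle into a minimum instead of bouncing around it.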
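Step decay: Multiply the learning rate by a fixed factor every N epochs.
```python
from torch.optim.lr_scheduler import StepLR
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)  # lr *= 0.1 every 30 epochs; call scheduler.step() once per epoch
```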
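Cosine annealing: Smoothly decrease the learning rate along a half cosine curve.
```python
from torch.optim.lr_scheduler import CosineAnnealingLR
scheduler = CosineAnnealingLR(optimizer, T_max=100)  # reaches the minimum (eta_min, default 0) after 100 scheduler steps
```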
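Warmup: Start small, increase the learning rate linearly, then decrease it.
```python
from torch.optim.lr_scheduler import LambdaLR
warmup_steps = 1000
# lr multiplier: linear warmup for the first 1000 steps, then 1/sqrt(step) decay (one possible choice of decay)
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: min((step + 1) / warmup_steps, (warmup_steps / (step + 1)) ** 0.5))
```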
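Under the hood, these schedules are simple formulas of the epoch or step count. In the pure-Python sketch below, the function names and the warmup-plus-cosine combination are hypothetical. The step-decay and cosine formulas are intended to match the closed forms described in the PyTorch documentation for `StepLR` and `CosineAnnealingLR`.

```python
import math

def step_decay(base_lr, epoch, step_size=30, gamma=0.1):
    # Multiply by gamma once every step_size epochs
    return base_lr * gamma ** (epoch // step_size)

def cosine_annealing(base_lr, epoch, t_max=100, eta_min=0.0):
    # Follow half a cosine from base_lr (epoch 0) down to eta_min (epoch t_max)
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

def warmup_then_cosine(base_lr, step, warmup_steps=1000, total_steps=10_000):
    # Hypothetical combined schedule: linear warmup, then cosine decay to zero
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return base_lr * (1 + math.cos(math.pi * progress)) / 2

for epoch in (0, 29, 30, 60, 99):
    print(epoch, step_decay(0.1, epoch), round(cosine_annealing(0.1, epoch), 5))
```

Step decay drops abruptly at epochs 30, 60, and so on, while the cosine schedule decays slowly at first, fastest in the middle, and flattens out near the end.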
Local minima and saddle points
The loss surface of a neural network isn’t a simple bowl. It has:
- Local minima (not globally optimal)
- Saddle points (minimum in some directions, maximum in others)
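This sounds scary, but in high dimensions most “bad” critical points (where the gradient is zero) tend to be saddle points rather than poor local minima. The noise in mini-batch gradients helps the optimizer escape them.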
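A toy example shows both halves of that claim. The sketch below runs gradient descent on the hypothetical saddle f(x, y) = x² - y², which has a minimum along x and a maximum along y, starting exactly on the line y = 0. Gaussian noise added to the gradient stands in for mini-batch noise; that substitution, and the helpers `grad` and `descend`, are assumptions for illustration. This toy function is unbounded below along y, so “escaping” here just means leaving the saddle. A real loss would eventually flatten out.

```python
import random

def grad(x, y):
    # Hypothetical saddle f(x, y) = x^2 - y^2, with its critical point at (0, 0)
    return 2.0 * x, -2.0 * y

def descend(x, y, lr=0.1, steps=100, noise=0.0, seed=0):
    rng = random.Random(seed)
    for _ in range(steps):
        gx, gy = grad(x, y)
        if noise:
            # Gaussian noise as a stand-in for mini-batch gradient noise
            gx += rng.gauss(0.0, noise)
            gy += rng.gauss(0.0, noise)
        x, y = x - lr * gx, y - lr * gy
    return x, y

print(descend(1.0, 0.0))              # exact gradients: slides into the saddle and stays there
print(descend(1.0, 0.0, noise=1e-3))  # noisy gradients: y drifts off zero and the iterate escapes
```

With exact gradients the y-component of the gradient is exactly zero, so nothing ever pushes the iterate off the saddle. A tiny random nudge is enough, because the saddle amplifies any displacement along y by a factor of 1.2 per step here.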
Practical tips
- Start with Adam, lr=1e-3 or 3e-4
- Use learning rate warmup for large models
- Monitor loss curves: the training loss should trend steadily downward (mini-batch noise makes it somewhat jagged)
- If loss explodes, reduce learning rate
- Try SGD+momentum for final fine-tuning