Why do we use cross-entropy for classification? What does it actually measure? Knowing the answer helps explain why it works so well as a training objective for neural networks.
Definition
For true distribution P and predicted distribution Q:
$$H(P, Q) = -\sum_x P(x) \log Q(x)$$
In classification, P is one-hot (true label), Q is softmax output.
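As a quick numeric check of the definition, here is a minimal plain-Python sketch. The helper name cross_entropy and the two example distributions are illustrative assumptions, not part of any library.

```python
import math

def cross_entropy(p, q):
    """H(P, Q) = -sum_x P(x) * log Q(x), using the natural log.

    Terms with P(x) == 0 are skipped, following the convention 0 * log(q) = 0.
    """
    return -sum(px * math.log(qx) for px, qx in zip(p, q) if px > 0)

p = [0.5, 0.5]        # hypothetical true distribution
q_good = [0.5, 0.5]   # prediction that matches P exactly
q_bad = [0.9, 0.1]    # prediction that disagrees with P

print(cross_entropy(p, q_good))  # equals the entropy of P, log(2)
print(cross_entropy(p, q_bad))   # larger, because Q differs from P
```

Cross-entropy is smallest when Q matches P, which is why minimizing it pulls the predicted distribution toward the true one.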
For classification
True label: class 2, counting from 0 (one-hot: [0, 0, 1, 0])
Prediction: [0.1, 0.2, 0.6, 0.1]
$$\text{Loss} = -\sum_i y_i \log(\hat{y}_i) = -\log(0.6) \approx 0.51$$
Only the true class matters! Every other term is multiplied by 0.
So with a one-hot label the loss reduces to $-\log(\text{predicted probability of true class})$.
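The worked example can be reproduced in a few lines of plain Python. This is only a sketch using the numbers from the text; the variable names are made up for illustration.

```python
import math

y = [0, 0, 1, 0]              # one-hot label for class 2
y_hat = [0.1, 0.2, 0.6, 0.1]  # predicted probabilities from the example

# Full sum over classes; terms with y_i == 0 vanish.
full_sum = -sum(yi * math.log(pi) for yi, pi in zip(y, y_hat))

# Shortcut: look up the probability of the true class only.
true_class = y.index(1)
shortcut = -math.log(y_hat[true_class])

print(round(full_sum, 2), round(shortcut, 2))  # 0.51 0.51
```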
Why it works
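The loss depends only on the probability the model assigns to the true class:
- High confidence, correct: -log(0.99) ≈ 0.01. Low loss.
- Low confidence in the true class: -log(0.1) ≈ 2.3. Higher loss.
- Confidently wrong (true class gets 0.01): -log(0.01) ≈ 4.6. Very high loss.
The loss grows without bound as that probability approaches 0, so confident wrong predictions are penalized heavily.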
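To see how quickly the penalty grows, the following sketch tabulates the loss for a few true-class probabilities. The function name nll is a hypothetical helper chosen for this example.

```python
import math

def nll(p_true):
    """Negative log-likelihood of the true class (natural log)."""
    return -math.log(p_true)

# Probability assigned to the true class, from confident-correct to confident-wrong.
for p_true in [0.99, 0.6, 0.1, 0.01]:
    print(f"p(true class) = {p_true:<5} loss = {nll(p_true):.2f}")
```

Going from 0.99 to 0.01 multiplies the loss by several hundred, which is the asymmetry described above.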
Binary cross-entropy
For binary classification with a sigmoid output $\hat{y} \in (0, 1)$ and label $y \in \{0, 1\}$:
$$\text{BCE} = -y\log(\hat{y}) - (1-y)\log(1-\hat{y})$$
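import torch
import torch.nn.functional as F

# Binary: predictions are probabilities (after sigmoid), targets are floats 0.0 or 1.0
loss = F.binary_cross_entropy(predictions, targets)

# Or with raw logits (more stable: the sigmoid is folded into the loss)
loss = F.binary_cross_entropy_with_logits(logits, targets)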
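To illustrate why the logits version is more stable, here is a plain-Python sketch of the BCE formula computed two ways. The function names are hypothetical, and the rearranged form is a standard algebraic identity; the text does not say which exact form any particular library uses.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def bce_from_prob(y_hat, y):
    """BCE on a probability y_hat; fails if y_hat rounds to exactly 0 or 1."""
    return -y * math.log(y_hat) - (1 - y) * math.log(1 - y_hat)

def bce_from_logit(z, y):
    """Same loss computed from the logit z, where y_hat = sigmoid(z).

    Rearranged as max(z, 0) - z*y + log(1 + exp(-|z|)), so exp never
    receives a large positive argument and log never receives 0.
    """
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))

# Moderate logit: both routes agree.
print(bce_from_prob(sigmoid(2.0), 1), bce_from_logit(2.0, 1))

# Extreme logit: sigmoid(1000.0) rounds to 1.0, so the probability route
# would hit log(0); the logit route stays finite.
print(bce_from_logit(1000.0, 0))  # 1000.0
```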
Multi-class cross-entropy
$$\text{CE} = -\sum_{c=1}^{C} y_c \log(\hat{y}_c)$$
In PyTorch, CrossEntropyLoss expects raw logits (it applies log-softmax internally):
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
# logits: (batch, num_classes), raw scores, NOT softmax'd
# targets: (batch,) with integer class indices
loss = loss_fn(logits, targets)
Why logits are preferred
Computing softmax then log is numerically unstable:
- softmax can underflow to 0
- log(0) = -∞
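Combining them (log-softmax) is stable because the tiny probability is never formed; the log is taken analytically, and the sum is evaluated after subtracting the largest logit: $$\log\text{softmax}(z_i) = z_i - \log\sum_j e^{z_j}$$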
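# Unstable: a probability that underflows to 0 turns into log(0) = -inf
probs = F.softmax(logits, dim=-1)
# Pick each example's true-class probability (probs[target] would index rows instead)
loss = -torch.log(probs.gather(-1, target.unsqueeze(-1))).mean()
# Stable: log-softmax and the negative log-likelihood are computed together
loss = F.cross_entropy(logits, target)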
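The failure mode can be reproduced without any tensor library. The sketch below uses plain Python floats and hypothetical helper names; note that math.log(0.0) raises an error here, where a tensor library would typically return -inf instead.

```python
import math

def naive_log_softmax(logits, i):
    exps = [math.exp(z) for z in logits]
    prob = exps[i] / sum(exps)   # can underflow to exactly 0.0
    return math.log(prob)        # math.log(0.0) raises ValueError

def stable_log_softmax(logits, i):
    # Subtract the max logit so every exponent is <= 0 (log-sum-exp trick).
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return logits[i] - log_sum_exp

logits = [0.0, -1000.0]  # hypothetical extreme logits

print(stable_log_softmax(logits, 1))  # -1000.0
try:
    naive_log_softmax(logits, 1)
except ValueError:
    print("naive version failed: log(0)")
```

The stable version returns the exact answer, so the corresponding loss is a large but finite 1000 rather than infinity.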
Label smoothing
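Instead of hard one-hot [0, 0, 1, 0], use soft [0.025, 0.025, 0.925, 0.025]. With smoothing 0.1 and 4 classes, every class receives 0.1/4 = 0.025 and the true class keeps the remaining 0.9 on top.
This discourages overconfident predictions and acts as a regularizer.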
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
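The smoothed target above can be derived by mixing the one-hot vector with a uniform distribution. This is a minimal sketch in plain Python; smooth_one_hot and soft_cross_entropy are hypothetical helper names, and the overconfident prediction is made-up data.

```python
import math

def smooth_one_hot(true_class, num_classes, eps=0.1):
    """Mix the one-hot target with a uniform distribution over all classes."""
    uniform = eps / num_classes
    return [(1.0 - eps) + uniform if c == true_class else uniform
            for c in range(num_classes)]

def soft_cross_entropy(target, probs):
    """Cross-entropy against a soft (non one-hot) target."""
    return -sum(t * math.log(p) for t, p in zip(target, probs))

target = smooth_one_hot(true_class=2, num_classes=4, eps=0.1)
print([round(t, 3) for t in target])  # [0.025, 0.025, 0.925, 0.025]

# Against the smoothed target, an extremely confident prediction scores
# worse than a prediction that simply matches the target.
overconfident = [0.001, 0.001, 0.997, 0.001]
print(soft_cross_entropy(target, overconfident))
print(soft_cross_entropy(target, target))
```

Because the loss is minimized at the smoothed target rather than at probability 1, the model has no incentive to push logits toward infinity.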
Class imbalance
Some classes rare? Weight the loss:
# One weight per class, weighted inversely to frequency (rarer class, larger weight)
weights = torch.tensor([0.1, 0.3, 1.0, 0.5])
loss_fn = nn.CrossEntropyLoss(weight=weights)
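One way such weights might be derived from class counts is sketched below in plain Python. The formula N / (C * count) is just one common heuristic, the counts are invented, and both helper names are hypothetical. The weighted mean normalizes by the summed weights, which is my understanding of how a mean reduction usually treats class weights.

```python
def inverse_frequency_weights(class_counts):
    """One common heuristic: weight_c = N / (C * count_c)."""
    total = sum(class_counts)
    num_classes = len(class_counts)
    return [total / (num_classes * n) for n in class_counts]

def weighted_mean_loss(per_example_losses, targets, weights):
    """Weighted average: each example counts in proportion to its class weight."""
    num = sum(weights[t] * l for l, t in zip(per_example_losses, targets))
    den = sum(weights[t] for t in targets)
    return num / den

counts = [100, 50, 10, 40]   # hypothetical number of examples per class
weights = inverse_frequency_weights(counts)
print(weights)  # [0.5, 1.0, 5.0, 1.25]

# An example from rare class 2 dominates one from common class 0.
print(weighted_mean_loss([1.0, 2.0], targets=[0, 2], weights=weights))
```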
Focal loss
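For extreme imbalance. Down-weight easy examples, where $\hat{y}$ is the predicted probability of the true class and $\gamma \ge 0$ sets how strongly they are down-weighted:
$$\text{FL} = -(1-\hat{y})^\gamma \log(\hat{y})$$
Easy examples (high ŷ) contribute less. With γ = 0 this is ordinary cross-entropy.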
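# Not in PyTorch by default, but easy to implement
def focal_loss(logits, targets, gamma=2.0):
    # Per-example cross-entropy, i.e. -log(pt)
    ce = F.cross_entropy(logits, targets, reduction='none')
    # Recover pt, the predicted probability of the true class
    pt = torch.exp(-ce)
    # Down-weight easy examples (pt close to 1) before averaging
    return ((1 - pt) ** gamma * ce).mean()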
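To get a feel for the modulating factor, this plain-Python sketch evaluates the focal loss formula for a single example at a few true-class probabilities. focal_term is a hypothetical helper, and gamma = 2.0 simply mirrors the default used above.

```python
import math

def focal_term(p_true, gamma=2.0):
    """Focal loss for one example: -(1 - p)^gamma * log(p)."""
    return -((1.0 - p_true) ** gamma) * math.log(p_true)

for p_true in [0.9, 0.5, 0.1]:
    ce = -math.log(p_true)
    fl = focal_term(p_true)
    # "kept" is the fraction of the cross-entropy that survives, (1 - p)^gamma.
    print(f"p = {p_true}: CE = {ce:.4f}, FL = {fl:.4f}, kept = {fl / ce:.2f}")
```

An easy example at p = 0.9 keeps about 1% of its cross-entropy, while a hard one at p = 0.1 keeps about 81%, so training signal concentrates on the hard cases.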