Training Transformers
Detailed solutions for the exercises in Chapter 15. Try solving them yourself before checking the answers.
Solution
The loss is −Σ_t log P(x_{t+1} | x_{≤t}): at each position the target is the NEXT token, so inputs and targets are the same sequence shifted by one. The causal mask lets every position predict its successor simultaneously in a single forward pass, so an input window of length T yields T separate next-token training signals. This dense supervision is what makes language-model pretraining so sample-efficient per forward pass.
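The shifted-target loss can be sketched in plain Python. This is a minimal illustration, not a training implementation: `next_token_loss` and the uniform toy `uniform` model are hypothetical stand-ins for a Transformer's softmax output. A 5-token sequence gives an input window of 4 and therefore 4 next-token targets.

```python
import math

def next_token_loss(tokens, prob):
    """Summed causal LM loss -sum_t log P(x_{t+1} | x_{<=t}) over one sequence.

    tokens: token ids; inputs are tokens[:-1] and targets are tokens[1:].
    prob(context, target): probability the model assigns to `target` given
    `context` (a hypothetical stand-in for a Transformer's softmax output).
    Returns (total loss, number of next-token training signals).
    """
    inputs, targets = tokens[:-1], tokens[1:]   # same sequence shifted by one
    total = 0.0
    for t in range(len(inputs)):
        context = inputs[: t + 1]               # x_{<=t}: what the causal mask exposes
        total += -math.log(prob(context, targets[t]))
    return total, len(targets)

def uniform(context, target):
    # Toy "model" (assumption): uniform over a 4-token vocabulary.
    return 0.25

total, n = next_token_loss([3, 1, 2, 0, 1], uniform)
print(n)                     # 4
print(round(total / n, 4))   # 1.3863, i.e. ln 4 per token
```

The Python loop only makes the per-position targets explicit; a real Transformer computes all positions in one batched forward pass.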
Solution
Perplexity = e^{loss} = e^{1.6} ≈ 4.95 ≈ 5. A perplexity of 5 means the model is on average as uncertain as if it were choosing uniformly among ~5 equally-likely next tokens — its effective branching factor. Lower perplexity = sharper, more confident (and usually more accurate) next-token predictions.
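A quick check of this arithmetic, assuming the loss is the mean per-token cross-entropy in nats (natural log); if it were measured in bits, perplexity would be 2**loss instead. The helper name `perplexity` is illustrative.

```python
import math

def perplexity(mean_loss_nats):
    """Perplexity from a mean per-token cross-entropy measured in nats."""
    return math.exp(mean_loss_nats)

print(round(perplexity(1.6), 2))          # 4.95

# "Effective branching factor" check: a model that is uniform over k tokens
# has loss ln(k), so its perplexity is exactly k (up to float rounding).
k = 5
print(round(perplexity(math.log(k)), 6))  # 5.0
```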
Solution
Adam's second-moment estimate v is unreliable in the first steps (it is built from only a few gradient samples), so early updates can be erratically large. A Transformer is sensitive to such steps because of its LayerNorms and residual connections, and they can destabilize training. Warmup ramps the learning rate up gradually so the moment estimates stabilize before large steps are taken. Plain SGD has no adaptive denominator to mis-estimate, and CNNs are less sensitive, so they often train fine without warmup.
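One concrete facet of the early-step problem: with standard bias-corrected Adam (the usual defaults β₁ = 0.9, β₂ = 0.999, ε = 1e-8 are assumed here), the very first update divides each gradient by its own magnitude. Every parameter therefore moves by about lr, however small its gradient. The sketch below computes that step-1 update for scalar gradients; `adam_first_update` is a hypothetical helper.

```python
import math

def adam_first_update(grads, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Update Adam applies at step t = 1 (with bias correction) for each scalar gradient."""
    t = 1
    updates = []
    for g in grads:
        m = (1 - beta1) * g              # first moment after one step (m_0 = 0)
        v = (1 - beta2) * g * g          # second moment after one step (v_0 = 0)
        m_hat = m / (1 - beta1 ** t)     # bias correction: m_hat == g
        v_hat = v / (1 - beta2 ** t)     # bias correction: v_hat == g^2
        updates.append(-lr * m_hat / (math.sqrt(v_hat) + eps))
    return updates

# Gradients spanning five orders of magnitude all get a step of size ~lr.
print(adam_first_update([1e-4, 1e-2, 1.0, -10.0]))
```

Warmup keeps lr tiny during exactly these steps, while v accumulates enough history to reflect each parameter's real gradient scale.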
Solution
For step t with warmup W and total T: lr(t) = lr_max·(t/W) for t ≤ W (linear warmup), and lr(t) = lr_min + ½(lr_max−lr_min)(1 + cos(π·(t−W)/(T−W))) for t > W (cosine decay). At t = W, the warmup branch gives lr_max, and the cosine branch gives lr_min + ½(lr_max−lr_min)(1+cos0) = lr_min + (lr_max−lr_min) = lr_max. The two branches agree at the boundary, so the schedule is continuous.
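The piecewise schedule and its boundary check translate directly into code. This is a sketch: the step counts and learning rates below are illustrative assumptions, and `lr_at` / `cosine_branch` are hypothetical names.

```python
import math

def cosine_branch(t, W, T, lr_max, lr_min):
    """Cosine decay from lr_max (at t = W) down to lr_min (at t = T)."""
    progress = (t - W) / (T - W)          # 0 at the end of warmup, 1 at step T
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))

def lr_at(t, W, T, lr_max, lr_min=0.0):
    """Linear warmup to lr_max over W steps, then cosine decay."""
    if t <= W:
        return lr_max * t / W
    return cosine_branch(t, W, T, lr_max, lr_min)

W, T, lr_max, lr_min = 100, 1000, 3e-4, 3e-5       # illustrative values
print(lr_at(W, W, T, lr_max, lr_min))              # warmup branch at t = W: lr_max
print(cosine_branch(W, W, T, lr_max, lr_min))      # cosine branch at t = W: also lr_max
print(lr_at(T, W, T, lr_max, lr_min))              # lr_min at the final step
```

As written, lr(0) = 0, so the very first optimizer step does nothing; if that matters, offset t by one.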
Solution
In Adam with L2 regularization, the weight-decay term λw is added to the gradient and is then divided by √v (the per-parameter second-moment estimate) along with it. Parameters with large gradients have large v, so their effective decay (λw/√v) shrinks: they get less regularization precisely where you might want more. AdamW DECOUPLES weight decay, applying it directly to the weights (w ← w(1−ηλ)) outside the adaptive scaling, so every parameter is decayed at the same rate regardless of its gradient magnitude. This is a key reason AdamW typically generalizes better.
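The decay/√v argument can be put in numbers. This is a deliberately simplified sketch. It assumes Adam's second moment has settled at v = (gradient RMS)², and it isolates the shrink contributed by the decay term instead of simulating full optimizer trajectories. The values of lr, wd and the two gradient scales are illustrative.

```python
import math

def decay_per_step(lr, wd, w, grad_rms, eps=1e-8):
    """Shrink that the weight-decay term alone applies to weight w in one step.

    Simplifying assumption: v has settled at grad_rms**2.
    """
    adam_l2 = lr * (wd * w) / (math.sqrt(grad_rms ** 2) + eps)  # decay rides in the gradient, divided by sqrt(v)
    adamw = lr * wd * w                                          # decoupled: w <- w * (1 - lr * wd)
    return adam_l2, adamw

lr, wd, w = 1e-3, 0.1, 1.0
for grad_rms in (0.01, 10.0):
    l2, decoupled = decay_per_step(lr, wd, w, grad_rms)
    print(f"grad rms {grad_rms:>5}: Adam+L2 shrink {l2:.1e}, AdamW shrink {decoupled:.1e}")
```

With these values, Adam+L2 shrinks the small-gradient weight about 1000× more per step than the large-gradient one (roughly 1e-2 versus 1e-5). AdamW applies the same 1e-4 shrink to both.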
Solution
Weight matrices benefit from shrinkage (regularization toward simpler functions). LayerNorm gains (γ) and biases set the SCALE and SHIFT of normalized activations; decaying γ toward 0 would suppress the layer's output magnitude, effectively killing the signal the normalization is meant to pass through. Biases are a tiny fraction of the parameters and only shift activations, so decaying them adds little useful regularization. Decay is therefore applied to the large weight matrices only and excluded from norms and biases.
Solution
When the total norm g_norm exceeds the threshold, every gradient is multiplied by the same positive scalar c = max_norm/g_norm < 1. Scaling all components of a vector by one positive constant changes its magnitude (to exactly max_norm) but not its direction (it still points the same way). So clipping rescales the step without changing where it points — it tames magnitude while keeping the descent direction intact.
Solution
bf16 keeps fp32's 8-bit exponent (same dynamic range, roughly 10^−38 to 10^38) but has only 7 mantissa bits instead of 23, so it rarely overflows or underflows and needs NO loss scaling. fp16 has only a 5-bit exponent (narrow range: largest finite value 65504, smallest normal value ≈ 6.1×10^−5), so small gradients underflow and require dynamic loss scaling to survive. bf16's wide range makes mixed-precision training simpler and more robust, which is why it became the default on hardware that supports it, despite its lower precision (7 mantissa bits versus fp16's 10).
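The range figures follow from the bit layouts. The sketch below derives them and then shows a tiny gradient underflowing in fp16 but surviving in bf16. It relies on Python's `struct` half-precision format `'e'` for fp16. bf16 is only approximated by truncating the fp32 encoding to its top 16 bits, which is round-toward-zero rather than the round-to-nearest that real conversions use. The value 1e-8 and the scale factor 1024 are illustrative.

```python
import struct

def format_range(exp_bits, mant_bits):
    """Largest finite value and smallest normal value of an IEEE-style float format."""
    bias = 2 ** (exp_bits - 1) - 1
    largest = (2 - 2.0 ** -mant_bits) * 2.0 ** (2 ** exp_bits - 2 - bias)
    smallest_normal = 2.0 ** (1 - bias)
    return largest, smallest_normal

for name, e, m in [("fp32", 8, 23), ("bf16", 8, 7), ("fp16", 5, 10)]:
    hi, lo = format_range(e, m)
    print(f"{name}: max {hi:.3e}, min normal {lo:.3e}")

def to_fp16(x):
    """Round-trip a Python float through IEEE half precision."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bf16(x):
    """Approximate bf16: keep the top 16 bits of the fp32 encoding (truncation)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

tiny_grad = 1e-8
print(to_fp16(tiny_grad))          # 0.0: underflows in fp16
print(to_bf16(tiny_grad) > 0)      # True: survives in bf16
print(to_fp16(tiny_grad * 1024))   # nonzero: loss scaling rescues it in fp16
```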
Solution
The goal is for K micro-batches to produce the same gradient as one batch K times larger, i.e. the AVERAGE gradient over all examples. Summing micro-batch gradients without division gives a result K× too large (a sum, not a mean). Dividing each micro-batch loss by K makes the accumulated gradients average correctly. With SGD, forgetting the division multiplies the effective learning rate by K, often causing divergence. Adam's per-parameter normalization largely cancels a constant gradient rescaling, so there the symptoms are subtler: gradient norms, and therefore any clipping threshold, are off by a factor of K.
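Because gradients are linear in the loss, dividing each micro-batch loss by K is the same as dividing its gradient by K. In PyTorch this is the familiar `(loss / K).backward()` inside the accumulation loop. The stdlib sketch below checks the claim on a toy one-weight regression with 4 micro-batches of 4 against a single batch of 16. The data, weight and helper names are illustrative assumptions.

```python
import random

def grad_mse(w, batch):
    """d/dw of the mean squared error mean((w*x - y)^2) over a batch of (x, y) pairs."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

rng = random.Random(0)
xs = [rng.uniform(-1, 1) for _ in range(16)]
data = [(x, 3.0 * x + rng.gauss(0, 0.1)) for x in xs]    # toy data: y is roughly 3x
w, K = 0.5, 4                                            # current weight; accumulation steps
micro_batches = [data[i:i + 4] for i in range(0, 16, 4)] # 4 micro-batches of 4

full = grad_mse(w, data)                                 # reference: one batch of 16

acc = 0.0
for mb in micro_batches:
    acc += grad_mse(w, mb) / K        # equivalent to backpropagating (loss / K)

no_div = sum(grad_mse(w, mb) for mb in micro_batches)    # forgot the 1/K

print(abs(acc - full) < 1e-12)        # True: accumulation reproduces the big-batch gradient
print(round(no_div / full, 6))        # 4.0: K times too large
```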
Solution
The critical batch size is the point below which gradient noise dominates (more samples per step help a lot) and above which returns diminish (the gradient estimate is already accurate, so larger batches mainly spend more compute per step for little extra progress). The critical batch size grows as the loss falls, because the true gradient shrinks relative to the per-example noise. Large models, which train to lower losses, can therefore use larger batches and massive data parallelism efficiently. Below the critical size you are noise-limited; above it you are curvature-limited.
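A toy model makes the diminishing returns concrete. Assume each per-example gradient is the true gradient G plus independent noise with total standard deviation σ. The batch-mean's noise then shrinks as σ/√B, and noise and signal are equal at B = (σ/|G|)². The values of G and σ below are illustrative, not measurements, and `noise_to_signal` is a hypothetical helper.

```python
import math

def noise_to_signal(batch_size, grad_norm, noise_std):
    """Relative error of a mini-batch mean gradient in the toy i.i.d.-noise model."""
    return noise_std / (math.sqrt(batch_size) * grad_norm)

G, sigma = 0.5, 8.0                  # illustrative values
b_crit = (sigma / G) ** 2            # batch size where noise equals signal
print(b_crit)                        # 256.0
for B in (16, 64, 256, 1024, 4096):
    print(B, round(noise_to_signal(B, G, sigma), 3))
```

Going from B = 16 to 256 cuts the relative noise from 4.0 to 1.0. The next 16× increase in compute (to 4096) only brings an already-minor error down to 0.25. Since B_crit = (σ/|G|)², the threshold also rises as the true gradient shrinks over training.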
Solution
The implemented schedule rises linearly to lr_max over the warmup window, then follows a cosine down to lr_min. Longer warmup delays the peak and gives a gentler start (more stable but slightly slower early progress); shorter warmup reaches full lr sooner (faster but riskier). The three curves visualize this trade-off.
Solution
Iterating parameters and routing 2-D weight matrices to the decay group and 1-D tensors (LayerNorm gains, biases) to the no-decay group reproduces the rule of Exercise 6. Printing the groups confirms norms and biases are excluded from weight decay — the standard, correct optimizer configuration.
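A framework-agnostic sketch of the routing rule follows. It takes (name, shape) pairs; in PyTorch these would come from `model.named_parameters()`, using `p.ndim` for the rank and passing the tensors themselves, not their names, into each group. The output mirrors the list-of-dicts form (`"params"`, `"weight_decay"`) that PyTorch optimizers accept as parameter groups. The parameter names and shapes are hypothetical.

```python
def split_param_groups(named_shapes, weight_decay=0.1):
    """Route parameters into decay / no-decay groups by tensor rank."""
    decay, no_decay = [], []
    for name, shape in named_shapes:
        if len(shape) >= 2:      # weight matrices (embeddings are 2-D too, so they land here)
            decay.append(name)
        else:                    # 1-D: LayerNorm gains/biases and linear biases
            no_decay.append(name)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# Hypothetical parameters of one Transformer block (Linear weights are (out, in)).
params = [
    ("attn.qkv.weight", (2304, 768)),
    ("attn.qkv.bias", (2304,)),
    ("ln1.weight", (768,)),
    ("ln1.bias", (768,)),
    ("mlp.fc.weight", (3072, 768)),
]
for group in split_param_groups(params):
    print(group)
```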
Solution
Compute the total norm across all parameter gradients, and if it exceeds max_norm, scale every gradient by max_norm/total_norm (Exercise 7). The result matches PyTorch's clip_grad_norm_ on random gradients, confirming the implementation.
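A stdlib sketch of global-norm clipping follows, treating each "parameter gradient" as a flat list of floats. It also checks the Exercise 7 claim that the clipped gradient has norm max_norm and unchanged direction (cosine similarity 1). Note that PyTorch's `torch.nn.utils.clip_grad_norm_` adds a small epsilon (1e-6) to the total norm before dividing, so a comparison against it should use a tolerance. The helper names are illustrative.

```python
import math

def clip_grad_norm(grads, max_norm):
    """Scale all gradient tensors by one shared factor so their global L2 norm <= max_norm.

    Returns (clipped grads, pre-clipping global norm).
    """
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    if total_norm <= max_norm:
        return grads, total_norm
    scale = max_norm / total_norm
    return [[g * scale for g in tensor] for tensor in grads], total_norm

def flat(grads):
    return [g for tensor in grads for g in tensor]

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

grads = [[3.0, 0.0], [0.0, 4.0, 12.0]]           # two "parameters"; global norm 13
clipped, norm_before = clip_grad_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for g in flat(clipped)))
print(norm_before)                                # 13.0
print(round(norm_after, 12), round(cosine(flat(grads), flat(clipped)), 12))  # 1.0 1.0
```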
Solution
Without warmup, the early large Adam steps (unreliable second moments, Exercise 3) cause a loss spike or divergence; with warmup the loss descends smoothly. The two curves demonstrate empirically why warmup is standard for Transformer training.
Solution
fp32 uses the most memory and is slowest; fp16 and bf16 roughly halve activation memory and speed up throughput on supported hardware. fp16 needs a GradScaler to avoid underflow; bf16 does not (Exercise 8). Final loss is essentially the same across all three, confirming mixed precision is a free speed/memory win when done correctly.
Solution
Accumulating gradients over 4 micro-batches of 4 (with the 1/accum loss scaling of Exercise 9) and then stepping produces gradients nearly identical to a single batch of 16. This lets you simulate large batches under limited memory — verified by the matching gradients.
Solution
Activation (gradient) checkpointing lowers peak memory markedly while increasing wall-clock time by roughly a third (the extra forward recomputation), reproducing the compute/memory trade-off of Chapter 11's Exercise 10 in a real training run.
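The "roughly a third" figure falls out of a back-of-the-envelope cost model. It assumes the common rule of thumb that a backward pass costs about twice a forward pass, so a step costs about 3 forward units and recomputing activations adds about 1 more. The memory side uses a simplified model of checkpointing every k layers, where stored checkpoints plus one recomputed segment give about n/k + k layer activations. The depth of 48 and both helper names are illustrative assumptions.

```python
import math

def step_cost(checkpointing, fwd=1.0, bwd_ratio=2.0):
    """Relative compute per training step, in forward-pass units (backward ~ 2x forward assumed)."""
    cost = fwd + bwd_ratio * fwd
    if checkpointing:
        cost += fwd          # each segment's forward pass is recomputed once during backward
    return cost

def peak_activations(n_layers, segment=None):
    """Layer activations held at once in a simplified model of checkpointing every `segment` layers."""
    if segment is None:
        return n_layers                                   # store everything
    return math.ceil(n_layers / segment) + segment        # checkpoints + one recomputed segment

print(f"extra compute: {step_cost(True) / step_cost(False) - 1:.0%}")   # extra compute: 33%
n = 48
k = round(math.sqrt(n))        # a segment length near sqrt(n) minimizes n/k + k
print(peak_activations(n), peak_activations(n, k))                      # 48 14
```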
Solution
Logging and plotting these five signals reveals training health: the loss and its perplexity (= e^{loss}); the gradient norm, which should be stable (spikes signal instability); the learning rate, which should follow the schedule; and the update-to-weight ratio (‖lr·update‖/‖weight‖), which should sit around 1e−3. A ratio far above that means training is unstable, and far below means it has stalled. The dashboard is the practitioner's instrument panel.
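The update-to-weight ratio is cheap to compute per parameter tensor. The sketch below uses flattened Python lists. The ±1-order-of-magnitude bands in `classify` are a hypothetical reading of the "around 1e−3" rule of thumb, not established thresholds. The toy weights and update are illustrative.

```python
import math

def l2_norm(values):
    return math.sqrt(sum(v * v for v in values))

def update_to_weight_ratio(weights, update, lr):
    """||lr * update|| / ||weights|| for one (flattened) parameter tensor."""
    return lr * l2_norm(update) / l2_norm(weights)

def classify(ratio, low=1e-4, high=1e-2):
    """Hypothetical bands: one order of magnitude either side of ~1e-3."""
    if ratio > high:
        return "too high: updates may be destabilizing"
    if ratio < low:
        return "too low: training may be stalled"
    return "healthy"

weights = [0.05] * 1000        # toy weight tensor
update = [1.0] * 1000          # Adam-like update direction with entries of magnitude ~1
ratio = update_to_weight_ratio(weights, update, lr=5e-5)
print(f"{ratio:.1e} -> {classify(ratio)}")   # 1.0e-03 -> healthy
```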
Solution
Saving model weights, optimizer moments, scheduler step, and RNG state — then reloading and continuing — should produce a continuation bit-identical to an uninterrupted run. Verifying this confirms the checkpoint captures all training state, essential for long runs that span restarts.
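The resume-equivalence test can be sketched without a framework. Below, a noisy momentum-SGD loop keeps all of its state (weight, momentum buffer, step counter, RNG state) in one dict. That dict is pickled mid-run, reloaded and continued. In PyTorch the analogous pieces would be `model.state_dict()`, `optimizer.state_dict()` (which holds Adam's moments), the scheduler's `state_dict()`, and `torch.get_rng_state()`; bit-identical GPU runs may additionally need deterministic kernels. The toy objective and helper names are illustrative.

```python
import pickle
import random

def train(state, n_steps, lr=0.1, beta=0.9):
    """Noisy momentum SGD on f(w) = w^2; every piece of training state lives in `state`."""
    rng = random.Random()
    rng.setstate(state["rng"])
    w, mom = state["w"], state["momentum"]
    for _ in range(n_steps):
        grad = 2 * w + rng.gauss(0, 0.1)     # stand-in for mini-batch noise
        mom = beta * mom + grad
        w -= lr * mom
        state["step"] += 1
    state.update(w=w, momentum=mom, rng=rng.getstate())
    return state

def fresh_state(seed=0):
    return {"w": 1.0, "momentum": 0.0, "step": 0, "rng": random.Random(seed).getstate()}

full = train(fresh_state(), 10)                  # uninterrupted run

blob = pickle.dumps(train(fresh_state(), 5))     # run 5 steps, checkpoint everything
resumed = train(pickle.loads(blob), 5)           # "restart" and run 5 more
print(full["w"] == resumed["w"], full["step"] == resumed["step"])   # True True

broken = pickle.loads(blob)
broken["rng"] = random.Random(0).getstate()      # RNG state not saved: reseeded instead
diverged = train(broken, 5)
print(diverged["w"] == full["w"])                # the continuation no longer matches
```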
Solution
Combining the schedule, optimizer groups, clipping, mixed precision, accumulation, monitoring, and checkpointing into one loop and training a ~10M-param model yields coherent samples. The post-mortem should note observed instabilities (early loss spikes, grad-norm bursts) and the fixes (warmup, clipping, lr tuning) — the integrated payoff of the whole chapter.