Training Stability & Monitoring
Chapter 13 built the Transformer, and its core has been remarkably durable: stacked blocks of attention and feed-forward, residual connections, normalization. Yet GPT-2, LLaMA-3, and the latest frontier models differ in dozens of details. This chapter surveys those variations — the architectural choices that, layered on the stable core, distinguish each generation and each family of models.
What Stays, What Changes
| Unchanged since 2017 | Refined over time |
|---|---|
| Self-attention as the mixing operation | Positional encoding (learned → RoPE) |
| Feed-forward per position | Activation (ReLU → GELU → SwiGLU) |
| Residual connections | Normalization (LayerNorm → RMSNorm) |
| Stacked identical blocks | Norm placement (Post-LN → Pre-LN) |
| Softmax attention weights | Attention efficiency (MHA → GQA, Flash) |
| Token embeddings | Sparsity (dense → Mixture of Experts) |
The single most consequential architectural choice is the overall structure, determined by the attention masking and whether there is a separate encoder. Chapter 13 introduced these; here we treat them as a design space and examine why the field converged on decoder-only for generative LLMs.
| Family | Attention | Pretraining | Examples |
|---|---|---|---|
| Encoder-only | Bidirectional | Masked LM (MLM) | BERT, RoBERTa, DeBERTa |
| Decoder-only | Causal | Next-token | GPT, LLaMA, Claude, Mistral |
| Encoder-decoder | Bidir. enc + causal dec | Span corruption | T5, BART, Flan-T5 |
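Since the families differ mainly in which positions may attend to which, a short sketch of the three masks helps. In this minimal example (the helper name is hypothetical), `True` means "may attend". Encoder-only models use only the bidirectional mask, decoder-only models use only the causal one, and encoder-decoder models use all three.

```python
import torch

def family_masks(T_in, T_out):
    """Boolean attention masks (True = may attend) for the three families."""
    # Encoder self-attention: every input token sees every other input token
    bidirectional = torch.ones(T_in, T_in, dtype=torch.bool)
    # Decoder self-attention: each token sees only itself and earlier tokens
    causal = torch.ones(T_out, T_out, dtype=torch.bool).tril()
    # Cross-attention (encoder-decoder): each output token sees the whole input
    cross = torch.ones(T_out, T_in, dtype=torch.bool)
    return bidirectional, causal, cross

bi, causal, cross = family_masks(T_in=4, T_out=3)
print(causal.int())
```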
Encoder-Only: Understanding
Encoder-only models (BERT; Devlin et al., 2019) use bidirectional attention, so every token sees every other, and train by masking random tokens and predicting them (masked language modeling). They excel at understanding tasks: classification, named-entity recognition, retrieval embeddings. They are not natural autoregressive generators: during training every position already sees the tokens after it, so the model never learns to predict the next token from left context alone.
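To make "masking random tokens" concrete, here is a minimal sketch of BERT-style input corruption. It follows the recipe from the BERT paper: select about 15% of positions; of those, replace 80% with a mask token, 10% with a random token, and leave 10% unchanged. The function name and the `mask_id` argument are hypothetical. Unselected positions get the label `-100`, which is the default `ignore_index` of `F.cross_entropy`, so they contribute no loss. Only the ~15% selected positions provide a training signal.

```python
import torch

def bert_style_mask(tokens, mask_id, vocab_size, p=0.15, generator=None):
    """Select ~p of positions as MLM targets and corrupt them 80/10/10."""
    labels = tokens.clone()
    selected = torch.rand(tokens.shape, generator=generator) < p
    labels[~selected] = -100                       # ignored by F.cross_entropy
    inputs = tokens.clone()
    roll = torch.rand(tokens.shape, generator=generator)
    inputs[selected & (roll < 0.8)] = mask_id      # 80%: replace with [MASK]
    swap = selected & (roll >= 0.8) & (roll < 0.9) # 10%: replace with a random token
    random_tokens = torch.randint(vocab_size, tokens.shape, generator=generator)
    inputs[swap] = random_tokens[swap]
    return inputs, labels                          # remaining 10%: left unchanged

g = torch.Generator().manual_seed(0)
tokens = torch.randint(5, 1000, (20,), generator=g)
inputs, labels = bert_style_mask(tokens, mask_id=0, vocab_size=1000, generator=g)
```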
Decoder-Only: Generation
Decoder-only models use causal attention and next-token prediction. As Chapter 13 argued, they won for general-purpose LLMs because one model and one objective serve every task via prompting, every token provides a training signal, and the architecture scales cleanly. The entire generative frontier — GPT-4, LLaMA, Claude, Gemini — is decoder-only.
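The claim that every token provides a training signal is visible in the loss itself. In this minimal sketch (the function name is hypothetical), position t is trained to predict token t + 1. A sequence of length T therefore yields T − 1 targets, compared with the ~15% of positions that carry a loss under masked language modeling.

```python
import torch
import torch.nn.functional as F

def next_token_loss(logits, tokens):
    """Causal LM loss. logits: (B, T, V) from a model fed tokens: (B, T)."""
    pred = logits[:, :-1]    # predictions made at positions 0 .. T-2
    target = tokens[:, 1:]   # the token that actually came next at each position
    # Every one of the B * (T - 1) positions contributes to the loss
    return F.cross_entropy(pred.reshape(-1, pred.size(-1)), target.reshape(-1))

logits = torch.zeros(2, 6, 50)              # uniform predictions over 50 tokens
tokens = torch.randint(50, (2, 6))
print(next_token_loss(logits, tokens))      # equals log(50) for uniform logits
```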
Encoder-Decoder: Seq-to-Seq
Encoder-decoder models (T5; Raffel et al., 2020) keep a bidirectional encoder and a causal decoder joined by cross-attention. They are natural for tasks with a clear input-to-output mapping: translation, summarization. T5's 'span corruption' objective masks contiguous spans rather than single tokens. They remain competitive for focused seq-to-seq tasks but have been eclipsed by decoder-only models for general use.
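Span corruption is easiest to see on a sentence. In the sketch below, each dropped span is replaced in the encoder input by a sentinel token. The decoder target lists each sentinel followed by the tokens it hid, and a final sentinel closes the target. The `<extra_id_N>` names follow the released T5 vocabulary; the function is hypothetical. In T5 the spans are sampled at random (by default about 15% of tokens, mean span length 3). Here they are passed in explicitly so the example is deterministic.

```python
def span_corrupt(tokens, spans):
    """tokens: list of str; spans: sorted, non-overlapping (start, end) pairs to drop."""
    inputs, targets, prev = [], [], 0
    for i, (start, end) in enumerate(spans):
        sentinel = f"<extra_id_{i}>"
        inputs += tokens[prev:start] + [sentinel]  # keep the text, mark the gap
        targets += [sentinel] + tokens[start:end]  # decoder reconstructs the gap
        prev = end
    inputs += tokens[prev:]
    targets += [f"<extra_id_{len(spans)}>"]        # final sentinel ends the target
    return inputs, targets

words = "Thank you for inviting me to your party last week .".split()
inp, tgt = span_corrupt(words, spans=[(2, 4), (8, 9)])
print(" ".join(inp))  # Thank you <extra_id_0> me to your party <extra_id_1> week .
print(" ".join(tgt))  # <extra_id_0> for inviting <extra_id_1> last <extra_id_2>
```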
Chapter 13 introduced positional encoding and previewed RoPE. Here we examine the variants in depth, because the choice profoundly affects a model's ability to handle long context — a central concern for modern LLMs. The trend has moved decisively from absolute positions toward relative schemes that extrapolate better.
| Method | Mechanism | Long-context extrapolation |
|---|---|---|
| Sinusoidal | Add fixed sin/cos vectors | Poor beyond training length |
| Learned absolute | Add a learned vector per position | None — fails past max length |
| RoPE | Rotate Q,K by position angle | Good; extendable via scaling |
| ALiBi | Add distance penalty to scores | Strong; trains short, tests long |
| Relative (T5) | Learned bias per relative distance | Good within bucketed range |
Rotary Position Embedding (RoPE)
RoPE (Su et al., 2021) is the dominant choice in modern LLMs. Instead of adding a position vector, it ROTATES the query and key vectors by an angle proportional to their absolute position. The magic: when you take the dot product of a rotated query and a rotated key, the result depends only on their RELATIVE position — absolute position cancels out. This gives relative-position awareness for free, with no added parameters.
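The claim that absolute position cancels can be checked numerically on a single 2-D pair of dimensions. In this sketch, `theta` and the two vectors are arbitrary values chosen only for illustration. The score for offset 3 is the same whether the pair sits at positions (0, 3), (10, 13) or (500, 503), but it changes when the offset flips sign.

```python
import math
import torch

def rotate(v, angle):
    """Apply the 2-D rotation R(angle) to v = (v0, v1)."""
    c, s = math.cos(angle), math.sin(angle)
    return torch.stack([c * v[0] - s * v[1], s * v[0] + c * v[1]])

theta = 0.1                                        # assumed frequency for this pair
q = torch.tensor([1.0, 2.0], dtype=torch.float64)
k = torch.tensor([0.5, -1.0], dtype=torch.float64)

def rope_score(m, n):
    """Dot product of q rotated for position m and k rotated for position n."""
    return torch.dot(rotate(q, m * theta), rotate(k, n * theta)).item()

# Same offset at very different absolute positions gives the same score
print(rope_score(0, 3), rope_score(10, 13), rope_score(500, 503))
```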
Rotate q at position m and k at position n by their position angles:

$$\tilde{q}_m = R(m\theta)\,q, \qquad \tilde{k}_n = R(n\theta)\,k$$

Because rotations satisfy R(a)ᵀR(b) = R(b − a), the dot product depends only on the offset between the two positions:

$$\tilde{q}_m \cdot \tilde{k}_n = q^\top R(m\theta)^\top R(n\theta)\,k = q^\top R\big((n-m)\theta\big)\,k$$

For d-dimensional vectors, R is block-diagonal: each pair of dimensions is rotated with its own frequency θ_i, and the argument holds pair by pair.

```python
import torch

def build_rope(seq_len, dim, base=10000):
    """Precompute cos/sin tables for RoPE. dim must be even."""
    # Frequencies: lower dims rotate fast, higher dims slow
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    pos = torch.arange(seq_len).float()
    angles = torch.outer(pos, inv_freq)  # (seq_len, dim/2)
    return angles.cos(), angles.sin()

def apply_rope(x, cos, sin):
    """Rotate x (..., seq, dim) by the precomputed angles."""
    x1, x2 = x[..., ::2], x[..., 1::2]  # even/odd dims form the 2-D pairs
    # Rotate each (x1, x2) pair by its angle
    rot1 = x1 * cos - x2 * sin
    rot2 = x1 * sin + x2 * cos
    return torch.stack([rot1, rot2], dim=-1).flatten(-2)  # re-interleave the pairs

# Applied to Q and K (NOT V) just before the attention dot product.
# No learned parameters; position enters purely through rotation.
```

ALiBi: Attention with Linear Biases
ALiBi (Press et al., 2021) takes an even simpler approach: it adds NO positional encoding to the embeddings at all. Instead, it adds a distance-proportional penalty directly to the attention scores — the farther back a key is, the more its score is reduced. This biases each query toward nearby tokens and, remarkably, lets a model trained on short sequences generalize to much longer ones at inference.
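A minimal sketch of how the penalty can be precomputed as a per-head bias and added to the attention scores. It assumes the geometric slope schedule from the ALiBi paper for a power-of-two head count: 2^(-8/n), 2^(-16/n), ..., 2^(-8). The paper uses a different rule for other head counts. Both function names are hypothetical.

```python
import torch

def alibi_slopes(n_heads):
    """Per-head slopes m_h = 2^(-8(h+1)/n_heads), assuming n_heads is a power of two."""
    start = 2 ** (-8 / n_heads)
    return torch.tensor([start ** (h + 1) for h in range(n_heads)])

def alibi_bias(n_heads, T):
    """(n_heads, T, T) bias: -m_h * (i - j) for keys j <= i, -inf for future keys."""
    i = torch.arange(T)[:, None]                               # query positions
    j = torch.arange(T)[None, :]                               # key positions
    distance = (i - j).float()                                 # how far back key j is
    bias = -alibi_slopes(n_heads)[:, None, None] * distance    # linear penalty per head
    return bias.masked_fill(j > i, float("-inf"))              # causal: no future keys

# Usage: scores = q @ k.transpose(-2, -1) / d**0.5 + alibi_bias(n_heads, T)
print(alibi_slopes(8))   # 1/2, 1/4, ..., 1/256
```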
$$\text{score}_{ij} = q_i \cdot k_j - m\,(i - j), \qquad j \le i$$

Here m is a fixed per-head slope (a smaller m gives a longer effective range), and i − j is how far back key j sits from query i. No position embeddings are used; the bias is the positional signal.

Chapters 10 and 13 introduced normalization and activations; modern LLMs use specific refined choices that improve stability and quality at scale. We consolidate the modern recipe here.
RMSNorm Over LayerNorm
RMSNorm (Chapter 10) drops LayerNorm's mean-centering, normalizing only by the root-mean-square. It is cheaper (no mean computation, no bias) and empirically matches LayerNorm's quality. LLaMA, Gemma, and most recent models use it. The savings are small per call but add up across two norms in every layer and trillions of tokens.
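A minimal RMSNorm sketch makes the difference from LayerNorm concrete. The gain is initialized to ones, and `eps` is an assumed value. On rows that already have zero mean, LayerNorm's mean subtraction is a no-op and its variance equals the mean square, so the two normalizations coincide. That is exactly the work RMSNorm skips.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """y = x / sqrt(mean(x^2) + eps) * g: no mean subtraction, no bias."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))   # learned gain g

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight

x = torch.randn(4, 16)
x = x - x.mean(dim=-1, keepdim=True)                  # force zero-mean rows
print(torch.allclose(RMSNorm(16)(x), F.layer_norm(x, (16,), eps=1e-6), atol=1e-5))  # True
```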
SwiGLU Over GELU
The SwiGLU feed-forward (Chapter 10) uses a gated activation with three weight matrices instead of two. To keep the parameter count matched, the hidden dimension shrinks from 4d to (8/3)d, because 3 · d · (8/3)d = 2 · d · 4d = 8d². It consistently improves perplexity at equal parameters and is the standard FFN in LLaMA and PaLM.
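The parameter matching can be checked directly. This is a minimal bias-free SwiGLU sketch; the attribute names `w_gate`, `w_up`, `w_down` are hypothetical, and Swish with β = 1 is PyTorch's `F.silu`. With d = 96 (chosen so that (8/3)d is an integer), it has exactly as many parameters as a two-matrix 4d FFN.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    """Gated FFN: W_down( silu(x W_gate) * (x W_up) ), no biases."""
    def __init__(self, d, hidden=None):
        super().__init__()
        hidden = hidden or int(8 * d / 3)   # param-matched to a 4d two-matrix FFN
        self.w_gate = nn.Linear(d, hidden, bias=False)
        self.w_up = nn.Linear(d, hidden, bias=False)
        self.w_down = nn.Linear(hidden, d, bias=False)

    def forward(self, x):
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

def n_params(module):
    return sum(p.numel() for p in module.parameters())

d = 96
swiglu = SwiGLUFFN(d)
dense = nn.Sequential(nn.Linear(d, 4 * d, bias=False), nn.GELU(), nn.Linear(4 * d, d, bias=False))
print(n_params(swiglu), n_params(dense))   # 73728 73728
```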
Pre-LN and QK-Norm
Pre-LN (normalize before each sublayer; Chapter 13) leaves the residual stream as an untouched identity path from input to output, which is key to training deep models stably. A newer refinement, QK-normalization, normalizes the queries and keys before the attention dot product. This stabilizes training of very large models by keeping the attention logits bounded, preventing the logit explosion that can otherwise occur.
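Why QK-norm bounds the logits can be shown in a few lines. This sketch assumes parameter-free RMS normalization of q and k; published variants use LayerNorm or RMSNorm, usually with a learned gain. After normalization each vector has norm at most √d, so every scaled logit lies in [−√d, √d] no matter how large q and k grow. Scaling q and k by 1000 multiplies the plain logits by 10⁶ but leaves the QK-normed ones essentially unchanged.

```python
import torch

def rms_normalize(x, eps=1e-6):
    """Parameter-free RMS normalization over the last dimension."""
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

def qk_norm_scores(q, k):
    """Attention logits with QK-norm. q, k: (T, d). Each logit is bounded by sqrt(d)."""
    q, k = rms_normalize(q), rms_normalize(k)
    return q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5

torch.manual_seed(0)
q, k = torch.randn(5, 64), torch.randn(5, 64)
big_plain = (1000 * q) @ (1000 * k).T / 64 ** 0.5      # plain logits explode
big_normed = qk_norm_scores(1000 * q, 1000 * k)        # QK-normed logits do not
print(big_plain.abs().max().item(), big_normed.abs().max().item())
```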
| Component | Old default | Modern default |
|---|---|---|
| Normalization | LayerNorm | RMSNorm |
| Norm placement | Post-LN | Pre-LN (+ sometimes QK-norm) |
| FFN activation | ReLU / GELU | SwiGLU |
| FFN hidden size | 4d | (8/3)d for SwiGLU (param-matched) |
| Positional | Learned absolute | RoPE |
| Bias terms | Present | Often removed (LLaMA drops most biases) |
The attention-efficiency variants in the next sections are all responses to one problem: the KV-cache. Recall from Chapter 13 that autoregressive generation caches the keys and values of past tokens to avoid recomputing them. This cache is fast, but it is large — and at long context and large batch sizes, it becomes the dominant memory cost and the bottleneck on inference throughput.
The Size of the KV-Cache
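$$\text{KV-cache size} = 2 \cdot n_\text{layers} \cdot n_\text{heads} \cdot d_\text{head} \cdot T \cdot B \cdot \text{bytes per element}$$

The factor 2 counts keys and values, T is the context length, and B is the batch size.

Example: a 70B model with standard multi-head attention (80 layers, 64 heads, head dimension 128), 8k context (T = 8192), batch 32, bf16 (2 bytes):

$$2 \times 80 \times 64 \times 128 \times 8192 \times 32 \times 2 \approx 6.87 \times 10^{11}\ \text{bytes} \approx 687\ \text{GB}$$

That is roughly 687 GB of KV-cache for a single batch, about five times the ~140 GB occupied by the bf16 weights themselves (70B parameters × 2 bytes). This is why long-context, high-throughput inference is hard: the cache, not the parameters, fills the GPU memory. The number of attention heads is a direct multiplier, which points to the fix: reduce the number of distinct keys and values by sharing them across query heads.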
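The arithmetic is easy to script. In this sketch (the function name is hypothetical), the head multiplier is expressed as the number of heads that keep their own K/V. The same formula then also covers the grouped-query case discussed next: with 8 K/V heads instead of 64, the cache for the same configuration drops by 8×.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch, bytes_per_elem=2):
    """KV-cache size; the leading 2 counts keys AND values.
    n_kv_heads = number of heads storing their own K/V (= n_heads for standard MHA)."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem

mha = kv_cache_bytes(80, 64, 128, 8192, 32)   # 64 K/V heads: standard multi-head
gqa = kv_cache_bytes(80, 8, 128, 8192, 32)    # 8 K/V heads shared by 64 query heads
print(f"MHA: {mha / 1e9:.0f} GB, 8 KV heads: {gqa / 1e9:.0f} GB")  # MHA: 687 GB, 8 KV heads: 86 GB
```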
Multi-Query Attention (MQA; Shazeer, 2019) and Grouped-Query Attention (GQA; Ainslie et al., 2023) shrink the KV-cache by sharing keys and values across query heads. Standard multi-head attention gives every head its own K and V; MQA gives ALL heads a single shared K and V; GQA is the middle ground, sharing K and V within small groups of query heads.
The Spectrum
| Variant | K/V heads | KV-cache | Quality |
|---|---|---|---|
| Multi-head (MHA) | = query heads (e.g. 64) | Largest | Best |
| Grouped-query (GQA) | Groups (e.g. 8) | Reduced ~8× | ≈ MHA |
| Multi-query (MQA) | 1 shared | Smallest | Slight drop |
GQA hits the sweet spot and is now standard in LLaMA-2 70B, LLaMA-3, Mistral, and most recent models. With 64 query heads and 8 KV groups, the KV-cache shrinks 8× with negligible quality loss — a large inference win for almost no cost. Here is the head-sharing layout:
Grouped-Query Attention: 8 query heads, 2 KV groups

| Query head | Q0 | Q1 | Q2 | Q3 | Q4 | Q5 | Q6 | Q7 |
|---|---|---|---|---|---|---|---|---|
| Shared K/V | KV-A | KV-A | KV-A | KV-A | KV-B | KV-B | KV-B | KV-B |
```python
import torch
import torch.nn.functional as F

def grouped_query_attention(Q, K, V, n_groups, causal=True):
    """Q: (B, n_q_heads, T, d)   K, V: (B, n_groups, T, d)."""
    B, n_q, T, d = Q.shape
    heads_per_group = n_q // n_groups
    # Repeat each K/V group so every query head in the group sees it
    K = K.repeat_interleave(heads_per_group, dim=1)  # (B, n_q, T, d)
    V = V.repeat_interleave(heads_per_group, dim=1)
    # Standard scaled dot-product attention, but K/V were shared within groups
    scores = (Q @ K.transpose(-2, -1)) / d**0.5
    if causal:
        # Decoder-only models: query i must not attend to keys j > i
        future = torch.ones(T, T, dtype=torch.bool, device=Q.device).triu(1)
        scores = scores.masked_fill(future, float("-inf"))
    return F.softmax(scores, dim=-1) @ V

# n_groups = n_q -> standard MHA (no sharing)
# n_groups = 1   -> MQA (all heads share one K/V)
# n_groups = 8   -> GQA (LLaMA-3 uses 8 KV heads)
# The KV-cache stores only n_groups K/V heads instead of n_q,
# shrinking it by n_q / n_groups, e.g. 64 / 8 = 8x smaller.
```

FlashAttention (Dao et al., 2022) is one of the most impactful systems contributions to LLMs. It computes EXACT attention, with no approximation, but never materializes the full T×T attention score matrix in slow GPU memory. By restructuring the computation to keep data in fast on-chip SRAM, it dramatically reduces memory traffic and speeds up both training and inference.
The Key Idea: Tiling and Online Softmax
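Standard attention writes the entire T×T score matrix to GPU high-bandwidth memory (HBM), then reads it back for the softmax and the value multiplication, which generates a huge amount of slow memory traffic. FlashAttention instead processes attention in tiles that fit in fast SRAM, computing a running 'online softmax' that never needs the full matrix in HBM. The result is identical, but the score matrix never makes the round trip through HBM. Extra memory drops from O(T²) to O(T), and the number of HBM accesses falls by a large factor: the FlashAttention paper's analysis gives Θ(T²d²/M) accesses for head dimension d and SRAM size M, versus Θ(Td + T²) for standard attention.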
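The online softmax is what makes tiling exact. The minimal pure-PyTorch sketch below uses a hypothetical function name. It walks over K and V one block at a time and keeps a running row maximum `m`, a running denominator `l`, and a running unnormalized output `acc`. Whenever the maximum grows, it rescales the earlier partial sums. It illustrates only the arithmetic. The real FlashAttention fuses this loop into a single GPU kernel, tiles the queries as well, and keeps the tiles in SRAM, which Python cannot express.

```python
import torch

def online_softmax_attention(q, K, V, block=4):
    """Exact softmax(q K^T / sqrt(d)) V, visiting K and V one block at a time.
    q: (Tq, d); K, V: (T, d). Only (Tq, block) score tiles ever exist."""
    d = q.shape[-1]
    m = torch.full((q.shape[0], 1), float("-inf"), dtype=q.dtype)  # running row max
    l = torch.zeros(q.shape[0], 1, dtype=q.dtype)                   # running denominator
    acc = torch.zeros(q.shape[0], V.shape[-1], dtype=q.dtype)       # running numerator
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = q @ Kb.T / d**0.5                                       # scores for this tile
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - m_new)                                    # stabilized exponentials
        scale = torch.exp(m - m_new)                                # rescale old partial sums
        l = l * scale + p.sum(dim=-1, keepdim=True)
        acc = acc * scale + p @ Vb
        m = m_new
    return acc / l

torch.manual_seed(0)
q = torch.randn(3, 8, dtype=torch.float64)
K, V = torch.randn(10, 8, dtype=torch.float64), torch.randn(10, 5, dtype=torch.float64)
reference = torch.softmax(q @ K.T / 8**0.5, dim=-1) @ V
print(torch.allclose(online_softmax_attention(q, K, V), reference))  # True
```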
| Approach | T×T score matrix | Extra memory | HBM accesses |
|---|---|---|---|
| Standard | Written to HBM, then read back for softmax and the value product | O(T²) | Θ(Td + T²) |
| FlashAttention | Q, K, V tiled into SRAM-sized blocks; computed block-by-block with online softmax; never stored | O(T) | Θ(T²d²/M), SRAM size M |

Same exact result, far less slow-memory traffic, and in practice 2-4× faster.

Why It Also Saves Memory
Because FlashAttention never materializes the T×T matrix, its extra memory is linear in sequence length rather than quadratic. This is what made long-context training practical. With standard attention, a 32k-token sequence would need a 32k×32k ≈ 1.07-billion-entry score matrix per head (about 2.1 GB in bf16). FlashAttention keeps only SRAM-sized tiles plus per-row softmax statistics, which grow linearly with T. Long context owes much of its feasibility to this single kernel.
FlashAttention makes exact attention faster but does not change its O(T²) compute. For very long contexts, even that is too expensive, so a family of variants APPROXIMATES attention by having each token attend to only a subset of positions, trading some expressiveness for sub-quadratic cost.
| Pattern | Each token attends to | Cost |
|---|---|---|
| Sliding window | A fixed window of w nearby tokens | O(T·w) linear |
| Dilated / strided | Every k-th token (gaps) | Sub-quadratic |
| Global + local | A few global tokens + local window | O(T·w + T·g) |
| Block-sparse | Selected blocks of the matrix | Configurable |
| Random (BigBird) | Window + global + random tokens | O(T) linear |
Sliding-Window Attention
The simplest and most widely used sparse pattern is the sliding window (used in Longformer and Mistral): each token attends only to a window of w nearby tokens (in the causal case, itself and the w − 1 before it), not the whole history. This makes attention linear in sequence length. Crucially, stacking layers gives an effective receptive field that grows with depth, just as in a CNN, so information can still propagate across the full sequence indirectly through the layer stack.
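The growth of the receptive field with depth can be measured by composing per-layer masks. In this sketch (helper names hypothetical), stacking L causal windows of size w lets the last token draw on information exactly L · (w − 1) positions back, since each window includes the token itself. For windows as large as Mistral's, that is indistinguishable from L × w.

```python
import torch

def sliding_mask(T, w):
    """True where query i may attend key j: causal and at most w - 1 tokens back."""
    i = torch.arange(T)[:, None]
    j = torch.arange(T)[None, :]
    return (j <= i) & (i - j < w)

def reach_after_layers(T, w, n_layers):
    """How many tokens back the last position can draw information from
    after n_layers stacked sliding-window layers."""
    step = sliding_mask(T, w).float()
    reach = torch.eye(T)
    for _ in range(n_layers):
        # Compose one more layer of window attention (boolean matrix product)
        reach = ((reach @ step) > 0).float()
    first_seen = int(reach[-1].nonzero().min())
    return (T - 1) - first_seen

for n in (1, 2, 5):
    print(n, reach_after_layers(64, w=4, n_layers=n))  # 3, 6, 15 = n * (w - 1)
```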
Each layer attends to a window of w nearby tokens, so after L layers the effective receptive field is about L × w tokens. For Mistral 7B, w = 4096 and L = 32, giving an effective reach of 32 × 4096 = 131,072 ≈ 131k tokens, while each attention operation stays O(T · 4096) rather than O(T²).

```python
import torch

def sliding_window_mask(seq_len, window):
    """Causal mask where each token sees itself and the window - 1 tokens before it."""
    i = torch.arange(seq_len)[:, None]  # query positions
    j = torch.arange(seq_len)[None, :]  # key positions
    # Attend to j where: j <= i (causal) AND i - j < window
    return (j <= i) & (i - j < window)

m = sliding_window_mask(6, window=3)
print(m.int())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 1, 1, 1, 0, 0],   <- token 3 no longer sees token 0 (outside window)
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1]], dtype=torch.int32)
```

The most significant architectural departure from the dense Transformer is the Mixture of Experts (MoE). It addresses a fundamental tension: more parameters mean a more capable model, but also more compute per token. MoE breaks this link by activating only a small subset of parameters for each token. This chapter previews the idea; Chapter 32 covers it in full.
The Core Idea
In a dense model, every token passes through every parameter. In an MoE, the feed-forward layer is replaced by many parallel 'expert' FFNs plus a small 'router' network. For each token, the router selects only the top-k experts (typically 1 or 2 of many), so each token uses only a fraction of the total parameters. The model has a huge parameter count but a small ACTIVE parameter count per token.
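A minimal top-k MoE layer shows the router at work. The class is hypothetical. The gating follows the Mixtral style: a softmax over the k highest router logits. For brevity the experts here are plain GELU FFNs, whereas Mixtral's experts are SwiGLU. Each expert processes only the tokens routed to it, so per-token compute scales with k, not with the number of experts.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKMoE(nn.Module):
    """E expert FFNs, a linear router, and top-k routing per token."""
    def __init__(self, d, hidden, n_experts=8, k=2):
        super().__init__()
        self.k = k
        self.router = nn.Linear(d, n_experts, bias=False)
        # Plain GELU FFN experts for brevity (an assumption; Mixtral uses SwiGLU)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d, hidden), nn.GELU(), nn.Linear(hidden, d))
            for _ in range(n_experts)
        ])

    def forward(self, x):                              # x: (N, d), one row per token
        logits = self.router(x)                        # (N, E) router scores
        top_w, top_idx = logits.topk(self.k, dim=-1)   # keep the k best experts
        top_w = F.softmax(top_w, dim=-1)               # renormalize over those k
        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):
            token_ids, slot = (top_idx == e).nonzero(as_tuple=True)
            if token_ids.numel() == 0:
                continue                               # no token routed to this expert
            # Only the routed tokens pass through this expert
            weight = top_w[token_ids, slot].unsqueeze(-1)
            out[token_ids] += weight * expert(x[token_ids])
        return out

torch.manual_seed(0)
moe = TopKMoE(d=16, hidden=32, n_experts=4, k=2)
print(moe(torch.randn(10, 16)).shape)                  # torch.Size([10, 16])
```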
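In a dense FFN, every token uses all N_ffn parameters. In an MoE layer with E experts, the router picks the top-k experts per token, so

$$\text{total expert params} = E \cdot N_\text{expert}\ \text{(huge)}, \qquad \text{active expert params} = k \cdot N_\text{expert}\ \text{(small, fixed per token)}$$

Mixtral 8×7B has about 47B total parameters but only about 13B active per token. The total is not 8 × 7B = 56B because only the FFN layers are replicated per expert; attention and embedding weights are shared.

The Challenges (Previewed)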
We can now assemble the refinements into the anatomy of a contemporary decoder-only LLM — the 'LLaMA-3 recipe' that most open models follow. Compare this to the Chapter 13 baseline to see how the refinements layer on the stable core.
A modern decoder-only block (LLaMA-3 style), listed from output (top) down to input (bottom):

| Layer | Detail |
|---|---|
| + residual | x + FFN_out |
| SwiGLU FFN | d → (8/3)d → d, no bias |
| RMSNorm | pre-norm |
| + residual | x + Attn_out |
| Grouped-Query Attention | RoPE + GQA + FlashAttn |
| RMSNorm | pre-norm |
| input x | (B, T, d) |
| Choice | Chapter 13 baseline | Modern (LLaMA-3) |
|---|---|---|
| Positional | Sinusoidal/learned | RoPE (with scaling for long context) |
| Normalization | LayerNorm | RMSNorm, pre-norm |
| FFN | GELU, 4d | SwiGLU, (8/3)d, no bias |
| Attention | Multi-head | Grouped-query + FlashAttention |
| Biases | Present | Removed |
| Vocab | ~50k | 128k (tiktoken) |
| Context | ~2k | 8k–128k+ (RoPE scaling) |
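To tie the recipe together, here is a minimal sketch of one block in the style of the table above: pre-norm RMSNorm, RoPE on Q and K, grouped-query attention, and a bias-free SwiGLU FFN. It is not Meta's implementation. It omits the KV-cache, RoPE scaling, and LLaMA's rounding of the FFN hidden size, and all class and attribute names are hypothetical. It calls PyTorch's `F.scaled_dot_product_attention` (PyTorch 2.0+), which can dispatch to a FlashAttention-style fused kernel on supported hardware.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

def rope(x, base=10000.0):
    """Interleaved RoPE over the last dim of x: (B, heads, T, head_dim)."""
    T, hd = x.shape[-2], x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, hd, 2, device=x.device).float() / hd))
    angles = torch.outer(torch.arange(T, device=x.device).float(), inv_freq)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2)

class DecoderBlock(nn.Module):
    """Pre-norm block: RMSNorm -> RoPE + GQA -> residual, RMSNorm -> SwiGLU -> residual."""

    def __init__(self, d, n_heads, n_kv_heads, ffn_hidden=None):
        super().__init__()
        self.n_heads, self.n_kv, self.hd = n_heads, n_kv_heads, d // n_heads
        self.attn_norm, self.ffn_norm = RMSNorm(d), RMSNorm(d)
        # No bias terms anywhere, following the modern recipe
        self.wq = nn.Linear(d, n_heads * self.hd, bias=False)
        self.wk = nn.Linear(d, n_kv_heads * self.hd, bias=False)  # fewer K/V heads (GQA)
        self.wv = nn.Linear(d, n_kv_heads * self.hd, bias=False)
        self.wo = nn.Linear(n_heads * self.hd, d, bias=False)
        hidden = ffn_hidden or int(8 * d / 3)
        self.w_gate = nn.Linear(d, hidden, bias=False)
        self.w_up = nn.Linear(d, hidden, bias=False)
        self.w_down = nn.Linear(hidden, d, bias=False)

    def attention(self, x):
        B, T, _ = x.shape
        q = self.wq(x).view(B, T, self.n_heads, self.hd).transpose(1, 2)
        k = self.wk(x).view(B, T, self.n_kv, self.hd).transpose(1, 2)
        v = self.wv(x).view(B, T, self.n_kv, self.hd).transpose(1, 2)
        q, k = rope(q), rope(k)                   # position enters via Q and K only
        rep = self.n_heads // self.n_kv           # query heads per K/V group
        k = k.repeat_interleave(rep, dim=1)
        v = v.repeat_interleave(rep, dim=1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.wo(out.transpose(1, 2).reshape(B, T, -1))

    def forward(self, x):
        x = x + self.attention(self.attn_norm(x))  # pre-norm attention + residual
        h = self.ffn_norm(x)
        return x + self.w_down(F.silu(self.w_gate(h)) * self.w_up(h))  # pre-norm SwiGLU + residual

torch.manual_seed(0)
block = DecoderBlock(d=64, n_heads=8, n_kv_heads=2)
x = torch.randn(2, 10, 64)                         # (B, T, d)
print(block(x).shape)                              # torch.Size([2, 10, 64])
```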
Variants Quick-Reference
| Variant | Solves | Used in |
|---|---|---|
| Decoder-only | General-purpose generation | GPT, LLaMA, Claude |
| RoPE | Relative position, long context | LLaMA, Mistral, most LLMs |
| ALiBi | Length extrapolation | BLOOM, some long-context models |
| RMSNorm / SwiGLU | Cheaper, better stability/quality | LLaMA, Gemma, PaLM |
| GQA | Smaller KV-cache, fast inference | LLaMA-2/3 70B, Mistral |
| FlashAttention | Faster exact attention | Essentially universal |
| Sliding window | Linear-cost long context | Longformer, Mistral |
| Mixture of Experts | Capacity without compute | Mixtral (full detail Ch. 32) |
Exercises
Exercises 1–10 are pen-and-paper or derivations; 11–20 require code.
Further reading: “RoFormer” (Su et al., 2021) for RoPE and “Train Short, Test Long” (Press et al., 2021) for ALiBi. “Root Mean Square Layer Normalization” (Zhang & Sennrich, 2019) and “GLU Variants Improve Transformer” (Shazeer, 2020) for RMSNorm and SwiGLU. “Fast Transformer Decoding” (Shazeer, 2019, MQA) and “GQA” (Ainslie et al., 2023). “FlashAttention” and “FlashAttention-2” (Dao et al., 2022, 2023). “Longformer” (Beltagy et al., 2020) and “BigBird” (Zaheer et al., 2020) for sparse attention. The LLaMA, Mistral, and Mixtral technical reports for the modern recipe in practice.
Next → Chapter 20: Efficient Training
You now know the architectural variants that make models capable and inference-efficient. Chapter 20 turns to making the TRAINING itself efficient: low-precision number formats (fp8 and beyond), parameter-efficient fine-tuning (LoRA and adapters), the optimized kernels and compilers that raise hardware utilization, and the techniques — quantization-aware training, activation recomputation, fused operations — that squeeze more useful work out of every GPU-hour. These are the methods that determine how much model you can train for a given budget.