Distributed Training
Chapter 15's training loop assumed a single GPU. Frontier models shatter that assumption. A modern LLM has billions of parameters and trains on trillions of tokens. Its training state does not fit in one device's memory, and one device could not process that much data in any reasonable time. This chapter builds the techniques that spread a single training run across thousands of GPUs while keeping it mathematically equivalent to the single-device version.
The Memory Budget of Training
To see why one GPU is not enough, count the memory. For a model with N parameters trained in mixed precision with AdamW, each parameter needs storage for several quantities at once:
| Quantity | Bytes/param | Why |
|---|---|---|
| bf16 parameters | 2 | The weights used in forward/backward |
| bf16 gradients | 2 | Computed in the backward pass |
| fp32 master weights | 4 | High-precision copy for updates |
| fp32 Adam momentum (m) | 4 | First moment estimate |
| fp32 Adam variance (v) | 4 | Second moment estimate |
| Total (model+optimizer) | 16 | Per parameter, before activations |
Memory(model + optimizer) ≈ 16 · N bytes
7B model: 16 × 7e9 = 112 GB (an 80GB GPU cannot hold it)
70B model: 16 × 70e9 = 1,120 GB (needs ~14+ GPUs just for state)
+ activations, which grow with batch × sequence length
Two Things to Distribute
Distribution solves two distinct problems, and the parallelism strategies map onto them. DATA parallelism addresses throughput — processing trillions of tokens in reasonable time by having many GPUs work on different data. MODEL parallelism (tensor and pipeline) addresses capacity — fitting a model too large for one GPU by splitting the model itself across devices. Real systems combine both.
Before the parallelism strategies, we need the primitives they use to coordinate. Distributed training is built on a small set of collective communication operations, implemented in libraries like NCCL (NVIDIA) and exposed by PyTorch's torch.distributed. These five collectives, plus point-to-point send/receive (which pipeline parallelism uses between stages), are enough to understand every parallelism strategy in this chapter.
| Operation | What it does |
|---|---|
| Broadcast | Send one GPU's tensor to all GPUs |
| Reduce | Combine (sum/mean) tensors from all GPUs onto one GPU |
| All-Reduce | Reduce, then broadcast result — all GPUs get the sum |
| All-Gather | Each GPU collects the shards from all GPUs (concatenate) |
| Reduce-Scatter | Reduce, then each GPU keeps only its shard of the result |
All-reduce is the workhorse of data parallelism. It sums gradients across all GPUs, and dividing the sum by P leaves every GPU with the identical averaged gradient. The crucial efficiency fact is that all-reduce can be implemented as a reduce-scatter followed by an all-gather, passing chunks around a ring of GPUs. In this 'ring all-reduce', the amount of data each GPU sends is nearly independent of the number of GPUs, so bandwidth cost stays flat as the cluster grows. Only latency grows, because the ring takes 2(P−1) steps.
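The reduce-scatter + all-gather decomposition is easy to simulate in a single process. In this sketch, plain Python lists stand in for the P GPUs' buffers. `ring_all_reduce` is a hypothetical helper, not a library API. It also counts how many elements each rank sends, so you can check the 2(P−1)/P × (data size) figure directly.

```python
def ring_all_reduce(buffers):
    """Simulate ring all-reduce over P equal-length lists, one per 'GPU'.

    Returns (reduced_buffers, elements_sent_per_rank).
    """
    P, n = len(buffers), len(buffers[0])
    assert n % P == 0, "sketch assumes the buffer splits into P equal chunks"
    c = n // P
    data = [list(b) for b in buffers]  # copy so the inputs stay untouched
    sent = [0] * P

    def chunk(r, k):
        return data[r][k * c:(k + 1) * c]  # slicing returns a copy

    # Phase 1, reduce-scatter: in step s, rank r sends chunk (r - s) % P to its
    # right neighbor, which adds it into its own copy. After P-1 steps,
    # rank r holds the fully summed chunk (r + 1) % P.
    for s in range(P - 1):
        msgs = [(r, (r - s) % P, chunk(r, (r - s) % P)) for r in range(P)]
        for r, k, payload in msgs:  # all sends in one step happen "at once"
            dst = (r + 1) % P
            for i, v in enumerate(payload):
                data[dst][k * c + i] += v
            sent[r] += c

    # Phase 2, all-gather: each rank forwards its most recently completed
    # chunk, overwriting the neighbor's stale copy. After P-1 steps all agree.
    for s in range(P - 1):
        msgs = [(r, (r + 1 - s) % P, chunk(r, (r + 1 - s) % P)) for r in range(P)]
        for r, k, payload in msgs:
            dst = (r + 1) % P
            data[dst][k * c:(k + 1) * c] = payload
            sent[r] += c
    return data, sent


if __name__ == "__main__":
    P, n = 4, 8
    bufs = [[float(r * 10 + i) for i in range(n)] for r in range(P)]
    out, sent = ring_all_reduce(bufs)
    print(out[0])  # elementwise sum, identical on every rank
    print(sent)    # each rank sent 2*(P-1)/P*n = 12 elements
```

Each rank sends 2(P−1) chunks of size n/P. Growing P makes the chunks smaller and the ring longer, and the two effects cancel in the total volume.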
Naive all-reduce: each GPU sends its data to all others → O(P) per GPU
Ring all-reduce: reduce-scatter + all-gather around a ring
each GPU sends ≈ 2(P-1)/P × (data size) ≈ constant (2× data)
⇒ per-GPU traffic independent of P, the number of GPUs
import os, torch; import torch.distributed as dist
# Initialize the process group (one process per GPU, e.g. launched with torchrun)
dist.init_process_group(backend='nccl')
rank = dist.get_rank(); world = dist.get_world_size()
torch.cuda.set_device(int(os.environ['LOCAL_RANK']))  # bind this process to its own GPU
x = torch.ones(4, device='cuda') * rank  # each GPU has different data
# All-reduce: every GPU ends up with the SUM across all GPUs
dist.all_reduce(x, op=dist.ReduceOp.SUM)
# With 4 GPUs (ranks 0,1,2,3): x becomes [6,6,6,6] on every GPU
# All-gather: collect every GPU's tensor into a list on every GPU
gathered = [torch.zeros(4, device='cuda') for _ in range(world)]
dist.all_gather(gathered, x)  # every GPU now holds all `world` tensors
# Reduce-scatter: sum across GPUs, but each keeps only its slice.
# The input is a list of `world` equal-sized chunks, one destined for each rank.
y = torch.ones(2 * world, device='cuda') * rank
out = torch.zeros(2, device='cuda')
dist.reduce_scatter(out, list(y.chunk(world)))  # with 4 GPUs: out = [6,6] on every GPU
dist.destroy_process_group()
Data parallelism is the foundation. The idea is simple: replicate the entire model on every GPU, give each a different slice of the batch, and average the gradients. When the shards are the same size and the loss is a per-example mean, the average of the shard gradients equals the full-batch gradient. The result therefore matches training on one giant batch on one device (up to floating-point rounding), while P GPUs share the work.
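The equivalence claim can be checked numerically with a toy model. This plain-Python sketch assumes a one-feature linear regression with mean-squared-error loss and equal-sized shards. `grad_mse`, `data_parallel_step` and `single_device_step` are hypothetical helpers written for illustration.

```python
def grad_mse(w, b, xs, ys):
    """Gradient of the mean squared error of y_hat = w * x + b over a batch."""
    n = len(xs)
    errs = [w * x + b - y for x, y in zip(xs, ys)]
    gw = sum(2 * e * x for e, x in zip(errs, xs)) / n
    gb = sum(2 * e for e in errs) / n
    return gw, gb


def data_parallel_step(replicas, xs, ys, lr):
    """One data-parallel SGD step; `replicas` holds one (w, b) pair per rank."""
    P = len(replicas)
    shard = len(xs) // P  # assumes equal-sized shards
    local = [grad_mse(w, b, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
             for r, (w, b) in enumerate(replicas)]
    gw = sum(g[0] for g in local) / P  # "all-reduce": sum, then divide by P
    gb = sum(g[1] for g in local) / P
    return [(w - lr * gw, b - lr * gb) for w, b in replicas]


def single_device_step(params, xs, ys, lr):
    """The same SGD step on the whole batch at once."""
    w, b = params
    gw, gb = grad_mse(w, b, xs, ys)
    return (w - lr * gw, b - lr * gb)


if __name__ == "__main__":
    xs = [float(i) for i in range(8)]
    ys = [3.0 * x + 1.0 for x in xs]
    replicas, single = [(0.0, 0.0)] * 4, (0.0, 0.0)
    for _ in range(50):
        replicas = data_parallel_step(replicas, xs, ys, lr=0.01)
        single = single_device_step(single, xs, ys, lr=0.01)
    print(replicas[0], single)  # equal up to floating-point rounding
```

Because every replica starts from the same values and applies the same averaged gradient, the replicas stay bitwise identical. Only the order of the floating-point sums differs from the single-device run.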
The Data-Parallel Algorithm
# Every GPU holds a full copy of the model
replicate model on all P GPUs
each step:
    split the global batch into P micro-batches (one per GPU)
    each GPU: forward + backward on its micro-batch
    all-reduce gradients across GPUs (sum, then divide by P)
    each GPU: optimizer.step() with the SAME averaged gradient
# all replicas stay identical because they apply identical updates
Here is how data parallelism lays out across 4 GPUs: each holds the complete model, and they differ only in which batch shard they process. The gradients are synchronized so all GPUs apply the same update.
Device Grid: Data parallelism: full model replicated, batch sharded
| | GPU 0 | GPU 1 | GPU 2 | GPU 3 |
|---|---|---|---|---|
| Model | full | full | full | full |
| Batch | shard 0 | shard 1 | shard 2 | shard 3 |
import os, torch; import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
dist.init_process_group('nccl')
local_rank = int(os.environ['LOCAL_RANK']); torch.cuda.set_device(local_rank)
# GPT, lm_loss and sharded_dataloader are placeholders for the model, the
# loss, and a per-rank data loader (e.g. one built with a DistributedSampler)
model = GPT(...).cuda()
model = DDP(model, device_ids=[local_rank])  # wraps model; hooks gradients
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
for batch in sharded_dataloader:  # each rank gets a different shard
    loss = lm_loss(model, batch.cuda())
    loss.backward()  # DDP all-reduces (averages) grads automatically
    opt.step(); opt.zero_grad()
# DDP overlaps gradient all-reduce WITH the backward pass: gradients are
# grouped into buckets, and as soon as a bucket's gradients are ready its
# all-reduce starts while earlier layers are still computing.
# This hides most of the communication latency.
Plain data parallelism is wasteful: every GPU stores a full, identical copy of the model, gradients, and optimizer state, so 64 GPUs hold 64 copies of the same state. ZeRO (Zero Redundancy Optimizer; Rajbhandari et al., 2020) eliminates this redundancy by SHARDING the state across GPUs, so that each GPU holds only its slice, while preserving the simplicity and mathematics of data parallelism.
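Sharding changes where the state lives, not the math. The single-process simulation below shows this with a sharded optimizer step, roughly ZeRO-2: gradients are reduce-scattered, optimizer state is sharded, and parameters are all-gathered afterwards. Assumptions: Python lists stand in for P ranks, plain SGD with momentum replaces AdamW to keep the code short, and `ddp_step` and `zero_step` are hypothetical names.

```python
def ddp_step(params, mom, grads_per_rank, lr=0.1, beta=0.9):
    """Replicated baseline: every rank holds the full momentum vector and
    applies the same averaged gradient (shown once, since replicas agree)."""
    P, n = len(grads_per_rank), len(params)
    g = [sum(gr[i] for gr in grads_per_rank) / P for i in range(n)]  # all-reduce
    mom = [beta * m + gi for m, gi in zip(mom, g)]
    params = [p - lr * m for p, m in zip(params, mom)]
    return params, mom


def zero_step(params, mom_shards, grads_per_rank, lr=0.1, beta=0.9):
    """Sharded step: rank r stores only slice r of the optimizer state.

    1. reduce-scatter: rank r receives the averaged gradient for its slice only
    2. rank r updates its slice of the momentum and of the parameters
    3. all-gather: every rank receives the full updated parameter vector
    """
    P = len(grads_per_rank)
    c = len(params) // P  # assumes len(params) is divisible by P
    new_mom, new_slices = [], []
    for r in range(P):  # each iteration is the work of one rank
        lo, hi = r * c, (r + 1) * c
        g = [sum(gr[i] for gr in grads_per_rank) / P for i in range(lo, hi)]
        m = [beta * mi + gi for mi, gi in zip(mom_shards[r], g)]
        new_mom.append(m)
        new_slices.append([p - lr * mi for p, mi in zip(params[lo:hi], m)])
    full_params = [v for s in new_slices for v in s]  # all-gather
    return full_params, new_mom


if __name__ == "__main__":
    P, n = 4, 8
    params_ddp, mom_ddp = [0.5 * i for i in range(n)], [0.0] * n
    params_zero, mom_zero = list(params_ddp), [[0.0] * (n // P) for _ in range(P)]
    for step in range(5):
        grads = [[float((r + 1) * (i + step) % 7) - 3.0 for i in range(n)] for r in range(P)]
        params_ddp, mom_ddp = ddp_step(params_ddp, mom_ddp, grads)
        params_zero, mom_zero = zero_step(params_zero, mom_zero, grads)
    print(params_ddp == params_zero)  # same parameters
    print([len(m) for m in mom_zero])  # but each rank stores only n/P momentum values
```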
The Three ZeRO Stages
ZeRO comes in three progressively more aggressive stages, each sharding more of the per-parameter state. Recall from Section 18.1 that the 16 bytes/param split into optimizer state (12), gradients (2), and parameters (2). ZeRO shards them in that order:
| Stage | Shards | Bytes/param per GPU (DDP: 16) | Extra comm. vs DDP |
|---|---|---|---|
| ZeRO-1 | Optimizer state | 4 + 12/P | None (same volume) |
| ZeRO-2 | + Gradients | 2 + 14/P | None (reduce-scatter replaces all-reduce) |
| ZeRO-3 | + Parameters | 16/P (full shard) | ~1.5× (all-gather params in forward and backward) |
ZeRO-1 shards only the optimizer state (the 12 redundant bytes), giving most of the memory savings for almost no extra communication. ZeRO-2 additionally shards gradients. ZeRO-3 shards the parameters themselves, so each GPU holds only 1/P of the model — the full per-parameter memory drops from 16 bytes to 16/P, enabling truly enormous models, at the cost of all-gathering parameters on demand during forward and backward.
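Each stage's footprint comes from moving one more term of the 2 + 2 + 12 split from "replicated" to "divided by P". A small calculator makes it easy to try other model and cluster sizes. It is a sketch: `bytes_per_gpu` is a hypothetical helper, and activations are ignored.

```python
# Bytes per parameter, from the memory-budget table: bf16 params, bf16 grads,
# and fp32 optimizer state (master weights + Adam m + Adam v).
PARAM_BYTES, GRAD_BYTES, OPT_BYTES = 2.0, 2.0, 12.0


def bytes_per_gpu(n_params, p, stage):
    """Model + optimizer bytes per GPU; stage 0 is plain DDP, 1-3 are ZeRO."""
    if stage == 0:    # everything replicated
        per_param = PARAM_BYTES + GRAD_BYTES + OPT_BYTES
    elif stage == 1:  # shard optimizer state
        per_param = PARAM_BYTES + GRAD_BYTES + OPT_BYTES / p
    elif stage == 2:  # + shard gradients
        per_param = PARAM_BYTES + (GRAD_BYTES + OPT_BYTES) / p
    elif stage == 3:  # + shard parameters
        per_param = (PARAM_BYTES + GRAD_BYTES + OPT_BYTES) / p
    else:
        raise ValueError("stage must be 0, 1, 2 or 3")
    return per_param * n_params


if __name__ == "__main__":
    for stage, name in enumerate(["DDP", "ZeRO-1", "ZeRO-2", "ZeRO-3"]):
        print(f"{name:7s} {bytes_per_gpu(7e9, 64, stage) / 1e9:7.2f} GB")
```

For N = 7B and P = 64 this prints 112.00, 29.31, 15.53 and 1.75 GB. ZeRO-1 and ZeRO-2 stop at the replicated 4N or 2N floor, while ZeRO-3 keeps shrinking as P grows.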
DDP: 16N bytes (no sharding)
ZeRO-1: 4N + 12N/P (optimizer sharded)
ZeRO-2: 2N + 14N/P (+ gradients sharded)
ZeRO-3: 16N/P (everything sharded)
For N=7B, P=64: DDP=112GB, ZeRO-3 ≈ 1.75GB per GPU
Fully Sharded Data Parallel (FSDP) is PyTorch's native implementation of ZeRO-3-style parameter sharding. It is the standard PyTorch tool for training models whose state does not fit on one GPU. Understanding its mechanics makes the ZeRO ideas concrete.
How FSDP Works
# Parameters are sharded across GPUs; each holds 1/P
forward(layer):
    all-gather the layer's full parameters from all GPUs
    compute the forward pass
    free the gathered parameters (keep only own shard)
backward(layer):
    all-gather the layer's full parameters again
    compute gradients
    reduce-scatter gradients (each GPU keeps its shard)
    free the gathered parameters
# optimizer updates only the local shard
import functools
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
# Assumes dist.init_process_group(...) and torch.cuda.set_device(...) have run.
# GPT and TransformerBlock are placeholders for the model and its block class.
model = GPT(...).cuda()
# Wrap each Transformer block as a separate FSDP unit -- so only ONE
# block's parameters are gathered at a time, minimizing peak memory
policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},
)
model = FSDP(
    model,
    auto_wrap_policy=policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3 style
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
)
# Sharding strategies trade memory for communication:
# FULL_SHARD (ZeRO-3): shard params+grads+optimizer -- least memory
# SHARD_GRAD_OP (ZeRO-2): shard grads+optimizer only
# NO_SHARD (= DDP): no sharding
# HYBRID_SHARD: full-shard within a node, replicate across nodes
Data parallelism and ZeRO split the batch across GPUs, replicating or sharding the model state. Tensor parallelism (Shoeybi et al., 2019, 'Megatron-LM') takes a different cut: it splits the computation WITHIN each layer across GPUs. A single matrix multiplication is partitioned so that each GPU computes part of it, and the results are combined. This is essential when even a single layer, in weights or activations, is too large for one GPU to handle efficiently.
Splitting a Linear Layer
Consider a linear layer Y = XW. There are two ways to split W across GPUs. Column parallelism splits W by columns, so each GPU produces a slice of the output. Row parallelism splits W by rows (and X by the matching columns), so each GPU produces a partial sum that must be all-reduced. Megatron pairs the two so that an FFN or attention block needs only one all-reduce in the forward pass (and one in the backward pass).
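Both splits can be checked numerically. This plain-Python sketch simulates k GPUs in one process using small integer matrices, so the comparisons are exact. `matmul`, `split_cols`, `split_rows`, `column_parallel` and `row_parallel` are hypothetical helpers written for illustration.

```python
def matmul(A, B):
    """Plain-Python product of matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def split_cols(M, k):
    """Split M into k equal column blocks."""
    c = len(M[0]) // k
    return [[row[i * c:(i + 1) * c] for row in M] for i in range(k)]


def split_rows(M, k):
    """Split M into k equal row blocks."""
    r = len(M) // k
    return [M[i * r:(i + 1) * r] for i in range(k)]


def column_parallel(X, W, k):
    """GPU i computes X @ W_i; concatenating the slices (all-gather) gives Y."""
    parts = [matmul(X, W_i) for W_i in split_cols(W, k)]
    return [sum((p[row] for p in parts), []) for row in range(len(X))]


def row_parallel(X, W, k):
    """GPU i computes X_i @ W_i; summing the partials (all-reduce) gives Y."""
    partial = [matmul(X_i, W_i) for X_i, W_i in zip(split_cols(X, k), split_rows(W, k))]
    return [[sum(p[i][j] for p in partial) for j in range(len(W[0]))]
            for i in range(len(X))]


if __name__ == "__main__":
    X = [[1, 2, 3, 4], [5, 6, 7, 8]]
    W = [[1, 0, 2, 1], [0, 1, 1, 3], [2, 1, 0, 1], [1, 1, 1, 0]]
    Y = matmul(X, W)
    print(column_parallel(X, W, 2) == Y, row_parallel(X, W, 2) == Y)
```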
Column split: W = [W₁ | W₂] Y = X[W₁|W₂] = [XW₁ | XW₂]
each GPU computes one output slice; concatenate (all-gather)
Row split: W = [W₁; W₂] Y = [X₁|X₂][W₁;W₂] = X₁W₁ + X₂W₂
each GPU computes a partial sum; combine (all-reduce)
Megatron's insight: in the FFN (two linear layers, d→4d→d), make the first layer column-parallel and the second row-parallel. Each GPU then holds a column slice of the hidden activation. GELU is elementwise, so it can be applied to that slice locally, and the slice is exactly the input the row-parallel layer needs. Each GPU therefore computes its path through both layers independently, and only ONE all-reduce is needed at the end. The same trick applies to attention by splitting across heads. Here is how the FFN weights split across 2 GPUs:
Device Grid: Tensor-parallel FFN across 2 GPUs (Megatron)
| | GPU 0 | GPU 1 |
|---|---|---|
| W1 (d→4d) | cols 0..2d | cols 2d..4d |
| GELU | local | local |
| W2 (4d→d) | rows 0..2d | rows 2d..4d |
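The grid can be checked end to end with a single-process simulation. The sketch assumes exact (erf-based) GELU and small hand-picked weights. `megatron_ffn` and the other helpers are hypothetical. Each loop iteration plays one GPU, and the final sum stands in for the single all-reduce.

```python
import math


def gelu(v):
    """Exact GELU: 0.5 * v * (1 + erf(v / sqrt(2)))."""
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def matmul(A, B):
    """Plain-Python product of matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def ffn(X, W1, W2):
    """Unsharded reference: GELU(X @ W1) @ W2."""
    H = [[gelu(h) for h in row] for row in matmul(X, W1)]
    return matmul(H, W2)


def megatron_ffn(X, W1, W2, tp):
    """Simulate `tp` GPUs: W1 split by columns, W2 by the matching rows."""
    c = len(W1[0]) // tp  # assumes the hidden width divides evenly
    partials = []
    for i in range(tp):  # one iteration per simulated GPU
        W1_i = [row[i * c:(i + 1) * c] for row in W1]  # column slice of W1
        W2_i = W2[i * c:(i + 1) * c]                    # matching row slice of W2
        H_i = [[gelu(h) for h in row] for row in matmul(X, W1_i)]  # local GELU
        partials.append(matmul(H_i, W2_i))              # partial output sum
    # The one all-reduce: add the partial outputs from every GPU
    return [[sum(p[r][j] for p in partials) for j in range(len(W2[0]))]
            for r in range(len(X))]


if __name__ == "__main__":
    X = [[0.5, -1.0], [1.5, 0.25]]                                   # d = 2
    W1 = [[0.1 * (i - 3) for i in range(8)], [0.2 * (4 - i) for i in range(8)]]
    W2 = [[0.3 * ((i + j) % 3 - 1) for j in range(2)] for i in range(8)]
    print(ffn(X, W1, W2))
    print(megatron_ffn(X, W1, W2, tp=2))  # matches up to rounding
```

The order matters. If the first layer were row-parallel, each GPU would hold a partial sum of the hidden activation, and GELU(a) + GELU(b) ≠ GELU(a + b). Splitting the first layer that way would force an all-reduce before the nonlinearity.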
import torch; import torch.nn as nn; import torch.distributed as dist
class ColumnParallelLinear(nn.Module):
    """Y = X @ W, with W split by columns across tp_size GPUs."""
    def __init__(self, d_in, d_out, tp_size):
        super().__init__()
        # local slice only: (d_in, d_out / tp_size)
        self.W = nn.Parameter(torch.randn(d_in, d_out // tp_size, device='cuda') * d_in ** -0.5)
    def forward(self, x):
        return x @ self.W  # (B,T,d_out/tp) local output slice
class RowParallelLinear(nn.Module):
    """Y = X @ W, with W split by rows; output all-reduced."""
    def __init__(self, d_in, d_out, tp_size):
        super().__init__()
        # local slice only: (d_in / tp_size, d_out)
        self.W = nn.Parameter(torch.randn(d_in // tp_size, d_out, device='cuda') * d_in ** -0.5)
    def forward(self, x):  # x is already the local slice (B,T,d_in/tp)
        out = x @ self.W  # partial sum (B,T,d_out)
        dist.all_reduce(out)  # sum partial results -> full output
        return out
# FFN = RowParallel(GELU(ColumnParallel(x)))
# Column produces local slices, row consumes them and all-reduces ONCE.
# This single all-reduce per FFN/attention block is Megatron's key efficiency.
# Forward-only sketch: dist.all_reduce is not tracked by autograd, so training
# needs an autograd-aware all-reduce (Megatron wraps it in a custom autograd Function).
Pipeline parallelism (Huang et al., 2019, 'GPipe') splits the model by LAYERS: GPU 0 holds the first few layers, GPU 1 the next few, and so on. Activations flow forward through the stages and gradients flow backward. Because each GPU holds only a contiguous block of layers, the model can be far larger than any single device. Because stages communicate only at their boundaries, using point-to-point sends of activations and gradients, the communication is infrequent and tolerant of slower inter-node links.
The Pipeline Bubble
Pipeline parallelism has a fundamental inefficiency: the bubble. While GPU 0 processes the first micro-batch, the downstream GPUs sit idle waiting for activations. At the end, GPU 0 is idle while the last stages finish. This idle time, the bubble, wastes compute. The fix is micro-batching: split the batch into many micro-batches and feed them through in a staggered pipeline, so all stages stay busy most of the time.
Device Grid: Pipeline parallelism: model layers split across 4 GPUs
| | GPU 0 | GPU 1 | GPU 2 | GPU 3 |
|---|---|---|---|---|
| Layers | 1–8 | 9–16 | 17–24 | 25–32 |
P pipeline stages, m micro-batches:
bubble fraction = (P - 1) / (m + P - 1)
P=4, m=1: bubble = 3/4 = 75% wasted!
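P=4, m=32: bubble = 3/35 ≈ 8.6% (more micro-batches → smaller bubble)
The bubble shrinks as you increase the number of micro-batches m relative to the number of stages P, which is why pipeline parallelism uses many micro-batches. Schedules differ in what else they improve. GPipe's simple fill-drain schedule runs all forwards and then all backwards, so it must hold every micro-batch's activations. PipeDream's 1F1B (one-forward-one-backward) schedule has the same bubble but bounds activation memory. Interleaved schedules (Narayanan et al., 2021) assign each GPU several smaller stages, which shrinks the bubble further.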
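You can check the bubble formula by laying out a fill-drain (GPipe-style) schedule slot by slot. This sketch assumes every forward and every backward step takes one unit of time, although a real backward usually costs more. `gpipe_timeline` and `bubble_fraction` are hypothetical helpers.

```python
from fractions import Fraction


def gpipe_timeline(P, m):
    """Fill-drain schedule on a grid of unit time slots ('F', 'B', or '.' idle).

    Forward of micro-batch j on stage s runs at slot s + j. Once all forwards
    have drained, backwards run in reverse stage order (last stage first).
    """
    t_fwd = m + P - 1  # slots needed for the forward phase
    grid = [["."] * (2 * t_fwd) for _ in range(P)]
    for s in range(P):
        for j in range(m):
            grid[s][s + j] = "F"
            grid[s][t_fwd + (P - 1 - s) + j] = "B"
    return grid


def bubble_fraction(grid):
    """Fraction of all stage-slots that are idle."""
    total = sum(len(row) for row in grid)
    idle = sum(row.count(".") for row in grid)
    return Fraction(idle, total)


if __name__ == "__main__":
    for row in gpipe_timeline(4, 4):
        print("".join(row))
    print(bubble_fraction(gpipe_timeline(4, 32)))  # equals (P-1)/(m+P-1) = 3/35
```

The printed grid shows the staircase fill at the start and the drain at the end. Those triangles are the bubble, and their size depends only on P while the busy region grows with m.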
# 1F1B: once the pipeline is full, alternate one forward and one
# backward per step, keeping all stages busy and bounding memory.
# recv_from_prev, send_to_next, recv_grad_from_next and send_grad_to_prev are
# placeholder point-to-point helpers (stage 0 reads inputs from the data
# loader; the last stage computes the loss instead of sending activations).
def pipeline_1f1b(stage, micro_batches, n_stages, rank):
    m = len(micro_batches)
    warmup = min(n_stages - rank - 1, m)  # forwards before first backward
    # Warmup phase: fill the pipeline
    for _ in range(warmup):
        act = stage.forward(recv_from_prev())
        send_to_next(act)
    # Steady state: 1 forward + 1 backward per iteration
    for _ in range(m - warmup):
        act = stage.forward(recv_from_prev()); send_to_next(act)
        grad = stage.backward(recv_grad_from_next()); send_grad_to_prev(grad)
    # Cooldown: drain remaining backwards
    for _ in range(warmup):
        grad = stage.backward(recv_grad_from_next()); send_grad_to_prev(grad)
# 1F1B keeps at most warmup + 1 (<= n_stages) micro-batches' activations in
# memory per stage, vs GPipe which holds ALL m micro-batches' activations at once.
No single parallelism strategy suffices for the largest models. The state of the art combines all three (data, tensor, and pipeline parallelism) into 3D parallelism (Narayanan et al., 2021, 'Megatron-LM' at scale). Each strategy handles the dimension it is best suited to, and together they map a model onto thousands of GPUs efficiently.
The Three Dimensions
| Dimension | Splits | Communication | Hardware |
|---|---|---|---|
| Tensor (TP) | Within each layer | Per-layer all-reduce | Within node (NVLink) |
| Pipeline (PP) | Across layer groups | Stage boundaries | Across nodes |
| Data (DP) | Across the batch | Gradient all-reduce | Outermost |
The total GPU count is the product of the three degrees: total = TP × PP × DP. A typical large-model configuration might use TP=8 (one node), PP=8 (eight nodes form one pipeline), and DP=16 (sixteen such pipelines in parallel) — 8×8×16 = 1,024 GPUs training one model. ZeRO can be layered on the data-parallel dimension to shard its optimizer state too.
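A framework must decide which ranks form each TP, PP and DP group. A natural layout puts TP innermost, so that a tensor-parallel group is eight consecutive ranks on one node. The sketch below assumes that ordering (TP, then PP, then DP). It also assumes ranks 8k to 8k+7 live on node k. `rank_to_coords` and `comm_groups` are hypothetical helpers. In a real job, each resulting rank list would be passed to torch.distributed.new_group.

```python
def rank_to_coords(rank, tp, pp, dp):
    """Map a global rank to (dp, pp, tp) coordinates with TP innermost."""
    assert 0 <= rank < tp * pp * dp
    return rank // (tp * pp), (rank // tp) % pp, rank % tp


def comm_groups(tp, pp, dp, axis):
    """Ranks that communicate along one axis ('tp', 'pp' or 'dp').

    Two ranks share a group when all of their *other* coordinates match.
    """
    groups = {}
    for r in range(tp * pp * dp):
        d, p, t = rank_to_coords(r, tp, pp, dp)
        key = {"tp": (d, p), "pp": (d, t), "dp": (p, t)}[axis]
        groups.setdefault(key, []).append(r)
    return list(groups.values())


if __name__ == "__main__":
    TP, PP, DP, GPUS_PER_NODE = 8, 8, 16, 8
    for axis in ("tp", "pp", "dp"):
        g = comm_groups(TP, PP, DP, axis)
        nodes = sorted({r // GPUS_PER_NODE for r in g[0]})
        print(axis, len(g), "groups of", len(g[0]), "ranks; first group on nodes", nodes)
```

With this ordering, every TP group stays on one node's NVLink. Each PP group spans eight consecutive nodes, and each DP group links one GPU in the same position across all sixteen pipelines, matching the hardware column of the table above.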
Total GPUs = TP × PP × DP
Example (1024 GPUs, 128 nodes × 8 GPUs):
TP = 8 each node holds one tensor-parallel group
PP = 8 8 nodes form one pipeline (the full model)
DP = 16 16 pipelines train in data-parallel
⇒ 8 × 8 × 16 = 1024 GPUs
With all the strategies in hand, how do you choose? The decision depends on model size, cluster size, and interconnect. Here is a practical decision guide that captures current best practice.
| Situation | Recommended strategy |
|---|---|
| Model fits on one GPU | Plain DDP (data parallel) |
| Model fits, optimizer state too big | ZeRO-1 / ZeRO-2 (or FSDP SHARD_GRAD_OP) |
| Model does not fit on one GPU | FSDP / ZeRO-3 (shard everything) |
| Model too big even for one node | + Tensor parallelism within nodes |
| Model spans many nodes | + Pipeline parallelism across nodes |
| Frontier scale (100B+) | Full 3D parallelism + ZeRO on DP dim |
The Modern Default
For most practitioners training models up to tens of billions of parameters on a single multi-GPU node or a small cluster, FSDP (ZeRO-3) is the pragmatic default: it is built into PyTorch, requires no model surgery, and shards everything. Tensor and pipeline parallelism enter only at the largest scales, where they require deliberate model partitioning via frameworks like Megatron-LM or DeepSpeed.
# Launch with torchrun on every node (multi-node runs also need rendezvous
# flags such as --rdzv_endpoint):
#   torchrun --nnodes=128 --nproc_per_node=8 train.py
import os, torch; import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
# GPT, block_policy (an auto-wrap policy like the transformer_auto_wrap_policy
# example above), lm_loss and distributed_loader are placeholders.
def main():
    dist.init_process_group('nccl')
    rank = dist.get_rank(); local = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local)
    model = GPT(...).cuda()
    model = FSDP(model, auto_wrap_policy=block_policy,
                 sharding_strategy=ShardingStrategy.HYBRID_SHARD)  # shard in-node, replicate cross-node
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)  # create AFTER wrapping
    for step, batch in enumerate(distributed_loader(rank)):
        with torch.autocast('cuda', dtype=torch.bfloat16):
            loss = lm_loss(model, batch.cuda())
        loss.backward()  # FSDP handles reduce-scatter
        model.clip_grad_norm_(1.0)  # FSDP's sharding-aware gradient clipping
        opt.step(); opt.zero_grad()
        if rank == 0 and step % 100 == 0:
            print(f"step {step}: loss {loss.item():.3f}")  # rank 0's local loss
    dist.destroy_process_group()
if __name__ == '__main__':
    main()
A frontier training run uses thousands of GPUs for weeks or months. At that scale, hardware failures are not exceptional; they are routine. A GPU fails, a network link flaps, a node crashes. The infrastructure must detect the failure, recover, and continue without losing the run. This operational reality is as important as the parallelism math.
The Scale of Failure
With thousands of GPUs running continuously, the mean time between failures is measured in hours, not months. Meta's LLaMA-3 training (Llama Team, 2024) reported hundreds of interruptions over its run, the majority from hardware faults. A single failed GPU can stall the entire synchronized run, so the system must handle failures gracefully.
| Challenge | Mitigation |
|---|---|
| GPU/node failures | Frequent checkpointing; elastic restart from last checkpoint |
| Stragglers | Detect and replace slow GPUs (a synchronous step runs at the speed of the slowest rank) |
| Network congestion | Topology-aware placement; overlap comm. with compute |
| Silent data corruption | Checksums, redundant computation, monitoring for loss anomalies |
| Checkpoint cost | Asynchronous and sharded checkpointing to avoid stalling training |
| Memory pressure and fragmentation | Activation checkpointing, careful allocator tuning |
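The checkpoint-and-restart mitigation can be sketched in a few lines. This toy version assumes a single process, a JSON file as the checkpoint, and a one-parameter model minimizing (w - 3)^2. `save_checkpoint`, `load_checkpoint`, `train` and `SimulatedFailure` are hypothetical names. Production systems save sharded, often asynchronous checkpoints, for example with PyTorch's torch.distributed.checkpoint module. The restart idea is the same: resume from the last complete checkpoint and redo the lost steps.

```python
import json
import os
import tempfile


def save_checkpoint(path, state):
    """Write to a temp file, then atomically rename, so a crash mid-write
    never leaves a half-written checkpoint behind."""
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump(state, f)
    os.replace(tmp, path)


def load_checkpoint(path):
    """Resume from the last checkpoint, or start fresh if there is none."""
    if not os.path.exists(path):
        return {"step": 0, "w": 0.0}
    with open(path) as f:
        return json.load(f)


class SimulatedFailure(Exception):
    pass


def train(path, total_steps, ckpt_every, fail_at=None):
    """Toy training loop: gradient descent on (w - 3)^2, checkpointing periodically."""
    state = load_checkpoint(path)
    while state["step"] < total_steps:
        if fail_at is not None and state["step"] == fail_at:
            raise SimulatedFailure(f"node lost at step {state['step']}")
        w = state["w"]
        state = {"step": state["step"] + 1, "w": w - 0.1 * 2 * (w - 3.0)}
        if state["step"] % ckpt_every == 0:
            save_checkpoint(path, state)
    return state


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        path = os.path.join(d, "ckpt.json")
        try:
            train(path, total_steps=100, ckpt_every=10, fail_at=37)
        except SimulatedFailure as e:
            print("crash:", e, "| last checkpoint at step", load_checkpoint(path)["step"])
        print("resumed:", train(path, total_steps=100, ckpt_every=10))
```

The run loses the 7 steps after its last checkpoint and recomputes them, which is the trade-off behind checkpoint frequency: more frequent saves cost I/O but waste less work per failure.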
Parallelism Quick-Reference
| Strategy | Splits | Comm. frequency | Scope |
|---|---|---|---|
| Data (DDP) | Batch | Per step (grad) | Any |
| ZeRO / FSDP | Optimizer/grad/params | Per layer | Any |
| Tensor | Within layer | Per layer | In node |
| Pipeline | Across layers | Stage boundary | Across nodes |
| 3D | All three | Mixed | Whole cluster |
Exercises
Exercises 1–10 are pen-and-paper or derivations; 11–20 require code.
Further reading: “ZeRO: Memory Optimizations Toward Training Trillion Parameter Models” (Rajbhandari et al., 2020). “Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism” (Shoeybi et al., 2019) and “Efficient Large-Scale Language Model Training on GPU Clusters” (Narayanan et al., 2021) for tensor and 3D parallelism. “GPipe” (Huang et al., 2019) and “PipeDream” (Narayanan et al., 2019) for pipeline schedules. The PyTorch FSDP paper and documentation (Zhao et al., 2023). The DeepSpeed and Megatron-LM codebases. The LLaMA-3 technical report (2024) for a candid account of training infrastructure at scale.
Next → Chapter 19: Architecture Variants
You can now spread a training run across thousands of GPUs. Chapter 19 returns to the model itself, surveying the architectural variants that make large models more capable and more efficient: the modern positional encodings (RoPE, ALiBi) for longer context, normalization and activation refinements (RMSNorm, SwiGLU), attention efficiency variants (multi-query and grouped-query attention) that shrink the memory bottleneck, and the design choices that distinguish GPT, LLaMA, and the other model families. These are the refinements that, layered on the stable Transformer core, define each generation of frontier models.