06. Multi-GPU Training
Multi-GPU training shortens wall-clock training time, but it introduces complexity that can break single-GPU code in subtle ways. Understanding the common failure modes up front saves debugging time later.
Choosing a Strategy
| Strategy | Use When | Not When |
|---|---|---|
| Data Parallelism | Fast GPUs, small models | Large models that don't fit on one GPU |
| Model Parallelism | Very large models | Small models (overhead dominates) |
| FSDP | Large models, good interconnects | Slow interconnects (e.g., older AWS instances) |
Setting Up Distributed Training
import os

import torch
import torch.distributed as dist


def setup_distributed():
    """Initialize the process group for distributed training."""
    # torchrun sets LOCAL_RANK, RANK, and WORLD_SIZE for every process.
    local_rank = int(os.environ["LOCAL_RANK"])
    # Bind this process to its own GPU before creating the process group.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(
        backend="nccl",  # NCCL is the standard backend for NVIDIA GPUs
        init_method="env://",
        world_size=int(os.environ["WORLD_SIZE"]),
        rank=int(os.environ["RANK"]),
    )
    return local_rank


def cleanup_distributed():
    """Clean up the process group."""
    if dist.is_initialized():
        dist.destroy_process_group()
Launching Distributed Jobs
# Single node, 4 GPUs
torchrun --nproc_per_node=4 train.py
# Multi-node (2 nodes, 4 GPUs each)
# Run this command on every node; use --node_rank=1 on the second node.
torchrun \
    --nproc_per_node=4 \
    --nnodes=2 \
    --node_rank=0 \
    --master_addr="10.0.0.1" \
    --master_port=29500 \
    train.py
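To make the multi-node launch concrete, the sketch below enumerates the processes a 2-node, 4-GPU-per-node job starts. The helper name `expected_layout` is hypothetical, and the contiguous rank numbering (`node_rank * nproc_per_node + local_rank`) is an assumption about the typical layout for a static launch like the one above; elastic rendezvous may assign ranks differently, so treat the environment variables set at runtime as the source of truth.

```python
def expected_layout(nnodes, nproc_per_node):
    """Enumerate the processes of a static launch (illustrative helper, not a torchrun API)."""
    world_size = nnodes * nproc_per_node  # total number of processes in the job
    layout = []
    for node_rank in range(nnodes):
        for local_rank in range(nproc_per_node):
            layout.append({
                "node_rank": node_rank,
                "local_rank": local_rank,
                # Assumed contiguous numbering; verify against RANK at runtime.
                "rank": node_rank * nproc_per_node + local_rank,
                "world_size": world_size,
            })
    return layout


if __name__ == "__main__":
    for proc in expected_layout(nnodes=2, nproc_per_node=4):
        print(proc)
```

Under these assumptions the job has 8 processes in total, and `LOCAL_RANK` restarts at 0 on each node while `RANK` is unique across the whole job.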
Common Failure: Batch Size Confusion
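In distributed data-parallel training, each process loads its own batch, so the effective (global) batch size is batch_size * world_size, where world_size is the total number of GPU processes across all nodes. A config with a batch size of 32 trained on 4 GPUs therefore uses an effective batch of 128. Scale the learning rate accordingly. Linear scaling is a common starting point, though very large batches often need warmup or further tuning: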
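# Scale the learning rate by the global batch size.
# Use the world size (all processes across all nodes), not
# torch.cuda.device_count(), which only counts GPUs on the local node.
world_size = dist.get_world_size()
effective_batch_size = config.batch_size * world_size
scaled_lr = config.base_lr * (effective_batch_size / config.base_batch_size)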
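As a worked instance of the arithmetic above, here is a minimal, framework-free sketch of linear scaling. The function name `linear_scaled_lr` and the example values (a base learning rate of 0.1 tuned at batch size 32) are illustrative assumptions, not values from any particular config.

```python
def linear_scaled_lr(base_lr, base_batch_size, per_gpu_batch_size, world_size):
    """Scale base_lr in proportion to the global batch size (linear scaling heuristic)."""
    if base_lr <= 0 or base_batch_size <= 0 or per_gpu_batch_size <= 0 or world_size <= 0:
        raise ValueError("all arguments must be positive")
    # Global batch = per-process batch times the number of processes.
    effective_batch_size = per_gpu_batch_size * world_size
    return base_lr * effective_batch_size / base_batch_size


if __name__ == "__main__":
    # Hypothetical setup: lr tuned on 1 GPU at batch 32, now run on 4 GPUs.
    lr = linear_scaled_lr(base_lr=0.1, base_batch_size=32,
                          per_gpu_batch_size=32, world_size=4)
    print(f"effective batch: {32 * 4}, scaled lr: {lr:.3f}")
```

With these numbers the global batch is 128 and the learning rate grows by a factor of 4. Keeping the per-GPU batch at 32 while holding the learning rate fixed would silently change the optimization behavior, which is the confusion this section warns about.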
Local verification checkpoint
Run the smallest example from this chapter in a local workspace and record the package version, runtime, data path, and observed output. If the result depends on model size, vector count, CPU/GPU backend, or available memory, note that constraint beside the exercise so the lesson remains reproducible.
Run torchrun --nproc_per_node=2 on a minimal training script that prints GPU rank. Verify all ranks execute and print their rank correctly.
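For the exercise above, a minimal sketch of a rank-printing script could look like the following. It reads only the `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` environment variables that torchrun sets, so it needs no GPU; the function name `describe_process` and the fallback defaults (used when the script is run without torchrun) are assumptions made for illustration.

```python
import os


def describe_process(env):
    """Format this process's rank information from torchrun-style environment variables."""
    # Defaults let the script also run as a plain single process.
    rank = int(env.get("RANK", "0"))
    local_rank = int(env.get("LOCAL_RANK", "0"))
    world_size = int(env.get("WORLD_SIZE", "1"))
    return f"rank {rank}/{world_size} (local rank {local_rank})"


if __name__ == "__main__":
    print(describe_process(os.environ))
```

Launched with two processes, each rank should print exactly one line, in no guaranteed order. A missing rank or a hang usually points to a launch or networking problem rather than to the training code.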