18. Training Pipeline Project
Chapter 18 of 18 · 25 min
This final chapter integrates everything into a production-grade training pipeline. The goal is not a toy example—it's a system you would actually deploy.
Project Structure
training_pipeline/
├── config/
│ └── default.yaml
├── data/
│ ├── raw/
│ └── processed/
├── models/
├── logs/
├── scripts/
│ ├── preprocess.py
│ ├── train.py
│ ├── evaluate.py
│ └── export.py
├── src/
│ ├── data/
│ │ ├── dataset.py
│ │ └── transforms.py
│ ├── model/
│ │ └── model.py
│ ├── training/
│ │ ├── trainer.py
│ │ ├── optimizers.py
│ │ └── schedulers.py
│ └── utils/
│ ├── checkpoint.py
│ ├── logging.py
│ └── distributed.py
├── tests/
│ ├── test_data.py
│ ├── test_model.py
│ └── test_training.py
├── Snakefile
├── requirements.txt
└── README.md
Complete Configuration
# config/default.yaml
# Every top-level section is read by scripts/train.py as config["<section>"].
model:
  name: "resnet50"
  pretrained: true
  num_classes: 1000
data:
  train_dir: "data/raw/train"
  val_dir: "data/raw/val"
  processed_dir: "data/processed"
  image_size: 224  # Required: train.py passes this to get_transforms()
  batch_size: 256  # Per process: each rank builds its own DataLoader
  num_workers: 8
  pin_memory: true
training:
  epochs: 100
  mixed_precision: true
  gradient_clip_norm: 1.0
  lr: 0.001
  weight_decay: 0.01
  warmup_epochs: 5
distributed:
  backend: "nccl"
  world_size: null  # Set by torchrun
checkpointing:
  save_dir: "models"
  save_every: 10
  keep_last_n: 3
logging:
  experiment_name: "resnet50-training"
  log_every: 50
Full Training Script
#!/usr/bin/env python3
"""Production training script with full observability."""
import argparse
import os
from pathlib import Path

import torch
import torch.distributed as dist
import yaml
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.profiler import profile, ProfilerActivity
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from src.data.dataset import ImageDataset
from src.data.transforms import get_transforms
from src.model.model import build_model
from src.training.trainer import Trainer
from src.training.optimizers import build_optimizer
from src.training.schedulers import build_scheduler
from src.utils.checkpoint import CheckpointManager
from src.utils.logging import setup_logging
def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace with two attributes:
        ``config`` (Path to the YAML config, default ``config/default.yaml``)
        and ``local_rank`` (int, or None when the flag is not given).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=Path, default=Path("config/default.yaml"))
    # Accept both spellings: older launchers pass --local_rank, newer ones
    # pass --local-rank. Either way the value lands in args.local_rank.
    parser.add_argument("--local-rank", "--local_rank", type=int, default=None)
    return parser.parse_args()
def setup_distributed(backend="nccl"):
    """Initialize the default process group when a launcher has set WORLD_SIZE.

    Args:
        backend: Backend name forwarded to ``dist.init_process_group``.
            Defaults to ``"nccl"``.

    Returns:
        True if a process group was initialized; False when the
        ``WORLD_SIZE`` environment variable is absent, in which case nothing
        is initialized and no CUDA device is selected.
    """
    if "WORLD_SIZE" not in os.environ:
        return False
    # Read LOCAL_RANK with a default of 0, the same way main() does, so a
    # launcher that sets WORLD_SIZE alone does not crash with a KeyError.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    # Bind this process to its GPU before the process group is created.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend=backend)
    return True
def main():
    """Run training end to end from the YAML file given by ``--config``.

    Loads the config, initializes distributed training when launched under
    torchrun, builds the model, data loaders, optimizer, scheduler, checkpoint
    manager and (optionally) the AMP gradient scaler, then hands everything to
    ``Trainer.train()``. The process group is torn down even if training
    raises.
    """
    args = parse_args()
    with open(args.config) as f:
        config = yaml.safe_load(f)

    is_distributed = setup_distributed()
    try:
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        # LOCAL_RANK is 0 on every node, so in a multi-node run it would
        # select one process per node. The global RANK selects exactly one.
        is_main = int(os.environ.get("RANK", 0)) == 0
        if is_main:
            setup_logging(config["logging"]["experiment_name"])

        # Build model
        model = build_model(config["model"])
        model = model.cuda()
        if is_distributed:
            model = DDP(model, device_ids=[local_rank])

        # Build data
        data_cfg = config["data"]
        # Fall back to 224 when data.image_size is absent from the config.
        # NOTE(review): 224 is the conventional ResNet-50 input size; confirm.
        image_size = data_cfg.get("image_size", 224)
        train_transform = get_transforms("train", image_size)
        val_transform = get_transforms("val", image_size)
        train_dataset = ImageDataset(data_cfg["train_dir"], train_transform)
        val_dataset = ImageDataset(data_cfg["val_dir"], val_transform)

        # Under DDP each rank must see a distinct shard. Without a sampler
        # every rank would iterate the full dataset in the same order.
        # NOTE(review): assumes ImageDataset is map-style (defines __len__).
        # NOTE(review): Trainer should call train_loader.sampler.set_epoch(epoch)
        # each epoch, or the shuffle order repeats across epochs.
        train_sampler = (
            DistributedSampler(train_dataset, shuffle=True) if is_distributed else None
        )
        train_loader = DataLoader(
            train_dataset,
            batch_size=data_cfg["batch_size"],
            # DataLoader rejects shuffle=True together with a sampler.
            shuffle=train_sampler is None,
            sampler=train_sampler,
            num_workers=data_cfg["num_workers"],
            pin_memory=data_cfg["pin_memory"],
            drop_last=True,
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=data_cfg["batch_size"],
            shuffle=False,
            num_workers=data_cfg["num_workers"],
            pin_memory=data_cfg["pin_memory"],
        )

        # Build training components
        optimizer = build_optimizer(model, config["training"])
        scheduler = build_scheduler(optimizer, config["training"], len(train_loader))
        checkpoint_manager = CheckpointManager(config["checkpointing"]["save_dir"])

        # Mixed precision
        scaler = (
            torch.cuda.amp.GradScaler()
            if config["training"]["mixed_precision"]
            else None
        )

        # Train
        trainer = Trainer(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            scaler=scaler,
            checkpoint_manager=checkpoint_manager,
            config=config,
        )
        trainer.train()
    finally:
        # Tear the process group down on failure as well as on success.
        if is_distributed:
            dist.destroy_process_group()


if __name__ == "__main__":
    main()
Launching the Pipeline
# Run from the project root. The script lives in scripts/, and PYTHONPATH=.
# makes the top-level `src` package importable from there.

# Single node, 4 GPUs
PYTHONPATH=. torchrun \
    --nproc_per_node=4 \
    --nnodes=1 \
    scripts/train.py \
    --config config/default.yaml

# Multi-node: run once per node with that node's NODE_RANK;
# MASTER_ADDR must be the same reachable host on every node.
PYTHONPATH=. torchrun \
    --nproc_per_node=4 \
    --nnodes=2 \
    --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --master_port=29500 \
    scripts/train.py \
    --config config/default.yaml
Running Tests
# Run from the project root so the tests/ paths resolve; -v lists each test.
# Data integrity
pytest tests/test_data.py -v
# Model forward pass
pytest tests/test_model.py -v
# Training loop (-k runs only tests whose name matches "test_train_step")
pytest tests/test_training.py -v -k "test_train_step"
Success Criteria
The pipeline is production-ready when:
- It runs on 4+ GPUs with near-linear scaling
- It recovers from crashes without data loss
- All experiments are logged to MLflow or wandb
- Checkpoints load correctly after code changes
- Tests pass in CI
EXERCISE
Apply this project structure to your current training work. Run the full pipeline on multiple GPUs. Verify that checkpoint recovery works, that experiments are logged, and that the tests pass. Fix everything that breaks.