16. Checkpointing
Chapter 16 of 18 · 20 min
Training a model for 72 hours without checkpointing is a career-limiting decision. Proper checkpointing enables recovery from failures, evaluation of intermediate checkpoints, and ensemble creation.
Checkpoint Structure
import torch
import os
from pathlib import Path
def save_checkpoint(model, optimizer, scheduler, epoch, val_loss, path):
    """Save the training state needed to resume a run.

    Args:
        model: Object exposing ``state_dict()``. If it has a ``cfg``
            attribute, ``vars(model.cfg)`` is stored under ``"config"``;
            otherwise ``"config"`` is ``None``.
        optimizer: Object exposing ``state_dict()``.
        scheduler: Object exposing ``state_dict()``, or ``None``.
        epoch: Epoch number recorded in the checkpoint.
        val_loss: Validation loss recorded in the checkpoint.
        path: Destination file, as a ``str`` or ``os.PathLike``
            (e.g. ``pathlib.Path``).
    """
    cfg = getattr(model, "cfg", None)
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict() if scheduler else None,
        "val_loss": val_loss,
        # Snapshot the config attributes so the checkpoint is self-describing.
        "config": dict(vars(cfg)) if cfg is not None else None,
    }

    # Atomic write: serialize to a sibling temp file, then rename it over the
    # destination, so an interrupted save never leaves a truncated checkpoint
    # at `path`. os.fspath() lets callers pass either str or Path; plain
    # `path + ".tmp"` raises TypeError for Path objects.
    target = os.fspath(path)
    tmp_path = target + ".tmp"
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, target)
    print(f"Saved checkpoint to {path}")
def load_checkpoint(path, model, optimizer=None, scheduler=None):
    """Restore training state from a checkpoint file.

    Args:
        path: Checkpoint file to read.
        model: Object whose ``load_state_dict`` receives the saved weights.
        optimizer: Optional; restored when the checkpoint holds optimizer
            state.
        scheduler: Optional; restored when the checkpoint holds scheduler
            state.

    Returns:
        Tuple ``(epoch, val_loss)`` as recorded in the checkpoint.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # Tensors are mapped to CPU here regardless of the device they were saved
    # from.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])

    optimizer_state = checkpoint.get("optimizer_state_dict")
    if optimizer is not None and optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    # The key can be present with a None value (a run saved without a
    # scheduler), so test the value rather than key membership.
    scheduler_state = checkpoint.get("scheduler_state_dict")
    if scheduler is not None and scheduler_state is not None:
        scheduler.load_state_dict(scheduler_state)

    return checkpoint["epoch"], checkpoint["val_loss"]
Checkpoint Strategy
class CheckpointManager:
    """Write per-epoch checkpoints and retain only the lowest-loss ones.

    At most ``max_checkpoints`` per-epoch files are kept on disk, chosen by
    validation loss. A separate ``best.pt`` in the same directory always holds
    a copy of the best checkpoint saved through this instance.

    Tracking lives in memory only: a freshly constructed manager does not know
    about files written by an earlier process.
    """

    def __init__(self, checkpoint_dir, max_checkpoints=5):
        """Create ``checkpoint_dir`` if needed and start with no tracked files."""
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.max_checkpoints = max_checkpoints
        # (val_loss, path) pairs, kept sorted by ascending val_loss.
        self.checkpoints = []

    def save_if_best(self, model, optimizer, epoch, val_loss):
        """Save this epoch's checkpoint, refresh ``best.pt``, prune the worst.

        A checkpoint is written for every call; ``best.pt`` is replaced only
        when ``val_loss`` is lower than every loss tracked so far.

        Returns:
            True if this checkpoint became the new best, else False.
        """
        import shutil

        is_best = not self.checkpoints or val_loss < min(
            loss for loss, _ in self.checkpoints
        )

        path = self.checkpoint_dir / f"checkpoint_epoch_{epoch}.pt"
        # Pass a plain string so the helper's temp-file naming works whether
        # or not it accepts Path objects.
        save_checkpoint(model, optimizer, None, epoch, val_loss, str(path))

        if is_best:
            # Copy the file already on disk instead of deserializing and
            # re-serializing it; copy to a temp name and rename so best.pt is
            # never left half-written. Done before pruning so best.pt
            # survives even when max_checkpoints is 0.
            best_path = self.checkpoint_dir / "best.pt"
            tmp_best = self.checkpoint_dir / "best.pt.tmp"
            shutil.copyfile(path, tmp_best)
            os.replace(tmp_best, best_path)

        # Re-saving the same epoch overwrites the file; drop the stale entry
        # so the path is tracked (and later deleted) only once.
        self.checkpoints = [entry for entry in self.checkpoints if entry[1] != path]
        self.checkpoints.append((val_loss, path))
        self.checkpoints.sort(key=lambda entry: entry[0])

        # Prune from the tail: the list is ascending by loss, so the last
        # entry is the worst checkpoint (popping index 0 would delete the best).
        while len(self.checkpoints) > self.max_checkpoints:
            _, worst_path = self.checkpoints.pop()
            if worst_path.exists():
                worst_path.unlink()

        return is_best
Distributed Checkpointing
# FSDP-compatible checkpointing
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
def save_fsdp_checkpoint(model, optimizer, path):
    """Save a full (unsharded) model and optimizer checkpoint from an FSDP model.

    Must be called on every rank: gathering the sharded parameters is a
    collective operation. Only rank 0 writes the file.

    Args:
        model: The FSDP-wrapped module.
        optimizer: Optimizer built on ``model``'s parameters, or ``None`` to
            skip optimizer state.
        path: Destination file (``str`` or ``os.PathLike``).
    """
    import torch.distributed as dist
    from torch.distributed.fsdp import (
        FullOptimStateDictConfig,
        FullStateDictConfig,
        StateDictType,
    )

    # The FSDP wrapper has no save_checkpoint() method of its own; the full
    # state is gathered through the state_dict_type context. rank0_only +
    # offload_to_cpu materializes the unsharded tensors only in rank 0's host
    # memory.
    with FSDP.state_dict_type(
        model,
        StateDictType.FULL_STATE_DICT,
        FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        model_state = model.state_dict()
        optim_state = (
            FSDP.optim_state_dict(model, optimizer)
            if optimizer is not None
            else None
        )

    if dist.get_rank() == 0:
        # Atomic write: temp file first, then rename over the destination.
        target = os.fspath(path)
        tmp_path = target + ".tmp"
        torch.save(
            {
                "model_state_dict": model_state,
                "optimizer_state_dict": optim_state,
            },
            tmp_path,
        )
        os.replace(tmp_path, target)


def load_fsdp_checkpoint(model, path):
    """Load a full-state checkpoint written by ``save_fsdp_checkpoint``.

    Must be called on every rank; each rank reads the whole file and FSDP
    keeps only its own shard of the parameters.

    Args:
        model: The FSDP-wrapped module to restore into.
        path: Checkpoint file to read.

    Returns:
        The loaded checkpoint dict, so the caller can also restore optimizer
        state (e.g. via ``FSDP.optim_state_dict_to_load``).
    """
    from torch.distributed.fsdp import StateDictType

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(path, map_location="cpu")
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
        model.load_state_dict(checkpoint["model_state_dict"])
    return checkpoint
Local verification checkpoint
Run the smallest example from this chapter in a local workspace and record the package version, runtime, data path, and observed output. If the result depends on model size, vector count, CPU/GPU backend, or available memory, note that constraint beside the exercise so the lesson remains reproducible.
EXERCISE
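Simulate a training interruption mid-epoch by stopping the script. Verify that you can resume from the last checkpoint and reproduce an identical loss trajectory. Note that an exact match also requires restoring the random-number-generator state and the data-loader position, which the checkpoint structure shown above does not store.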