10. Custom Training Loop

Chapter 10 of 18 · 20 min

The training loop is where engineering decisions compound. A poorly structured loop hides bugs, makes profiling difficult, and can turn small mistakes into production incidents.

Minimal Loop

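import torch
import torch.nn.functional as F


# validate() and save_checkpoint() are helper functions assumed to be defined elsewhere.
def training_loop(model, train_loader, val_loader, optimizer, scheduler, config):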
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    
    best_val_loss = float("inf")
    
    for epoch in range(config.epochs):
        # Training phase
        model.train()
        train_loss = 0
        for step, batch in enumerate(train_loader):
            inputs = batch["input"].to(device)
            targets = batch["target"].to(device)
            
            optimizer.zero_grad(set_to_none=True)  # Default since PyTorch 2.0; cheaper than zero-filling
            
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, targets)
            loss.backward()
            
            # Gradient clipping before optimizer step
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            train_loss += loss.item()
            
            # Log every 100 steps
            if step % 100 == 0:
                print(f"Epoch {epoch} Step {step}: loss={loss.item():.4f}")
        
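        # Validation phase
        val_loss = validate(model, val_loader, device)
        avg_train_loss = train_loss / max(len(train_loader), 1)
        print(f"Epoch {epoch}: train_loss={avg_train_loss:.4f} val_loss={val_loss:.4f}")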
        
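        # Step once per epoch on the validation loss. Passing a metric is specific
        # to ReduceLROnPlateau; most other schedulers take scheduler.step() with no
        # argument, and per-step schedulers (e.g. OneCycleLR) belong inside the
        # batch loop, right after optimizer.step().
        scheduler.step(val_loss)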
        
        # Checkpointing
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(model, optimizer, epoch, val_loss, "best.pt")
        
        # Periodic checkpointing
        if epoch % 10 == 0:
            save_checkpoint(model, optimizer, epoch, val_loss, f"epoch_{epoch}.pt")
    
    return model
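
The loop above calls validate and save_checkpoint without defining them. The following is a minimal sketch of both, not the chapter's own implementation. The batch keys match the loop, but averaging the loss per sample and the checkpoint dictionary keys (model_state, optimizer_state, and so on) are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def validate(model, val_loader, device):
    """Return the mean per-sample cross-entropy over val_loader."""
    model.eval()  # disable dropout, use running batch-norm statistics
    total_loss, total_samples = 0.0, 0
    for batch in val_loader:
        inputs = batch["input"].to(device)
        targets = batch["target"].to(device)
        # Sum instead of mean so a smaller final batch is weighted correctly.
        loss = F.cross_entropy(model(inputs), targets, reduction="sum")
        total_loss += loss.item()
        total_samples += targets.size(0)
    return total_loss / max(total_samples, 1)


def save_checkpoint(model, optimizer, epoch, val_loss, path):
    """Save enough state to resume training (key names are illustrative)."""
    torch.save(
        {
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "epoch": epoch,
            "val_loss": val_loss,
        },
        path,
    )
```

validate leaves the model in eval mode. The loop's model.train() call at the top of each epoch switches it back.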

Gradient Clipping

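Exploding gradients are common in RNNs, in transformers trained with high learning rates, and in mixed-precision training. Clip the global gradient norm after loss.backward() and before optimizer.step(). If you use a gradient scaler for mixed precision, call its unscale_(optimizer) method first so the threshold applies to the true gradient values: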

torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    max_norm=1.0,  # Global norm, not per-parameter
    norm_type=2.0
)
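
To make "global norm" concrete, here is the same arithmetic done by hand in plain Python. It is a sketch that mirrors the documented behavior of clip_grad_norm_: all gradients are treated as one vector and scaled by a single shared factor. The function name clip_by_global_norm and the toy gradient values are made up for illustration.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale all gradients by one factor so their combined L2 norm is at most max_norm.

    grads is a list of flat lists, one per parameter. Returns the clipped
    gradients and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    # One shared factor, capped at 1.0 so small gradients are left untouched.
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in tensor] for tensor in grads]
    return clipped, total_norm


# Two "parameters": combined norm is sqrt(3^2 + 4^2 + 12^2) = 13.
grads = [[3.0, 4.0], [12.0]]
clipped, total_norm = clip_by_global_norm(grads, max_norm=1.0)
clipped_norm = math.sqrt(sum(g * g for tensor in clipped for g in tensor))
print(total_norm, clipped_norm)
```

Because every gradient is multiplied by the same factor, the direction of the update is preserved. Clipping each parameter separately would change the relative sizes of the gradients.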

set_to_none for Performance

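optimizer.zero_grad(set_to_none=True)  # Can be ~10% faster than zero-filling; default since PyTorch 2.0

This sets each parameter's .grad to None instead of filling it with zeros, which reduces memory traffic. The actual speedup depends on the model and hardware. Code that reads .grad directly (manual gradient manipulation, logging, custom hooks) must handle None, and torch.optim optimizers skip the update for a parameter whose gradient is None rather than stepping with a zero gradient. Test your model.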
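
A quick way to see the difference is to inspect .grad after the reset. The sketch below uses a toy nn.Linear model. grad_norm is a hypothetical helper, not a PyTorch API, and shows one way to write gradient-reading code that tolerates None.

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Populate gradients, then reset them.
model(torch.randn(3, 4)).sum().backward()
optimizer.zero_grad(set_to_none=True)
grads_are_none = all(p.grad is None for p in model.parameters())


def grad_norm(parameters):
    """Global L2 gradient norm that skips parameters whose .grad is None."""
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return 0.0
    return torch.sqrt(sum(g.pow(2).sum() for g in grads)).item()


norm_after_reset = grad_norm(model.parameters())
```

Without the None check, the list comprehension would fail on g.pow(2) as soon as a parameter had no gradient.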

Local verification checkpoint

Run the smallest example from this chapter in a local workspace and record the package version, runtime, data path, and observed output. If the result depends on model size, vector count, CPU/GPU backend, or available memory, note that constraint beside the exercise so the lesson remains reproducible.

EXERCISE

Implement the minimal training loop above with your model. Add torch.cuda.synchronize() before and after the forward pass to measure GPU time vs. total time.
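
As a starting point for the exercise, here is a sketch of a timing helper. CUDA kernels launch asynchronously, so without synchronization the timer measures only how long it took to queue the work. timed_forward is a hypothetical name, and the helper skips synchronization on CPU so it runs on either backend.

```python
import time

import torch


def timed_forward(model, inputs):
    """Return (outputs, seconds) for one forward pass."""
    use_cuda = inputs.is_cuda
    if use_cuda:
        torch.cuda.synchronize()  # drain kernels queued by earlier steps
    start = time.perf_counter()
    outputs = model(inputs)
    if use_cuda:
        torch.cuda.synchronize()  # wait until the forward kernels finish
    return outputs, time.perf_counter() - start
```

Comparing the sum of these forward times with the wall-clock time of the whole step gives a rough view of how much time goes to data loading, the backward pass, and the optimizer.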