14. Training Loop

Chapter 14 of 24 · 20 min
EXERCISE: Build a Complete Training Loop

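Extend the training loop so that it:

- validates after every epoch,
- saves a checkpoint whenever the validation loss improves, and
- stops early once the validation loss has stopped improving.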

def train_with_validation(
    model, train_loader, val_loader, optimizer, scheduler,
    device, epochs=3, checkpoint_dir="./checkpoints",
    patience=2, min_delta=0.0,
):
    """Train for up to ``epochs`` epochs with checkpointing and early stopping.

    The model is evaluated on ``val_loader`` after every epoch. A checkpoint
    is written via ``save_checkpoint`` and the patience counter resets
    whenever the validation loss beats the best loss seen so far by more
    than ``min_delta``. Training stops early if the loss fails to improve
    for ``patience`` consecutive epochs.

    Args:
        model: The model being trained.
        train_loader: Training batches, passed to ``train_epoch``.
        val_loader: Validation batches, passed to ``evaluate``.
        optimizer: Optimizer, passed to ``train_epoch`` and ``save_checkpoint``.
        scheduler: Learning-rate scheduler, passed to ``train_epoch``.
        device: Device, passed through to ``train_epoch`` and ``evaluate``.
        epochs: Maximum number of epochs to run.
        checkpoint_dir: Directory passed to ``save_checkpoint``.
        patience: Number of consecutive non-improving epochs tolerated
            before stopping. Defaults to 2, the previously hard-coded value.
        min_delta: Minimum decrease in validation loss that counts as an
            improvement. Defaults to 0.0, i.e. any strict decrease counts.

    Returns:
        The best (lowest) validation loss observed, or ``float("inf")`` if
        no epoch was run (``epochs <= 0``).

    Note:
        On return the model holds the weights from the *last* epoch that
        ran, not necessarily the best one. Reload the most recent checkpoint
        if you need the best weights.
    """
    best_val_loss = float("inf")
    # Consecutive epochs without a sufficient improvement.
    epochs_without_improvement = 0

    for epoch in range(epochs):
        train_loss = train_epoch(model, train_loader, optimizer, scheduler, device)
        val_loss = evaluate(model, val_loader, device)

        print(f"Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")

        # With min_delta=0.0 this is a plain strict-decrease check.
        # inf - 0.0 == inf, so the first finite loss always counts as an
        # improvement.
        if val_loss < best_val_loss - min_delta:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            save_checkpoint(model, optimizer, epoch, checkpoint_dir)
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print("Early stopping triggered")
                break

    return best_val_loss

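Implement the `evaluate()` and `save_checkpoint()` functions to complete this pattern.

- `evaluate(model, val_loader, device)` must return the validation loss as a single number. The loop compares it with `<` and formats it with `:.4f`.
- `save_checkpoint(model, optimizer, epoch, checkpoint_dir)` should save enough state to resume training or reload the best model later. For example, save the model and optimizer state dicts together with the epoch number.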