14. Experiment Tracking with MLflow

Chapter 14 of 18 · 15 min

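Experiment tracking turns "it worked on my machine" into reproducible science: every run records its parameters, metrics, and artifacts, so results can be compared and reproduced later. MLflow is a widely used open-source tool for this kind of tracking.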

MLflow Setup

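import os

import mlflow
import mlflow.pytorch
import torch
from mlflow.tracking import MlflowClient

# Tracking URI: use MLFLOW_TRACKING_URI if it is set (e.g. a remote server),
# otherwise a local ./mlruns directory. An explicit set_tracking_uri() call
# takes precedence over the environment variable.
mlflow.set_tracking_uri(os.environ.get("MLFLOW_TRACKING_URI", "file:./mlruns"))
mlflow.set_experiment("my-experiment")

def train_with_mlflow(config):
    # Assumes build_model, train_epoch, validate, train_loader, and val_loader
    # are defined elsewhere in your project; they are not MLflow APIs.
    with mlflow.start_run(run_name=config.run_name):
        # Log configuration once per run
        mlflow.log_params({
            "lr": config.lr,
            "batch_size": config.batch_size,
            "model": config.model_name,
            "epochs": config.epochs,
        })

        model = build_model(config)
        optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
        best_val_loss = float("inf")

        for epoch in range(config.epochs):
            train_loss = train_epoch(model, train_loader, optimizer)
            val_loss = validate(model, val_loader)

            # Log per-epoch metrics; `step` lets MLflow plot them as curves
            mlflow.log_metrics({
                "train_loss": train_loss,
                "val_loss": val_loss,
                "lr": optimizer.param_groups[0]["lr"],
            }, step=epoch)

            # Save and log a checkpoint only when validation loss improves
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), "best_model.pt")
                mlflow.log_artifact("best_model.pt")

        # Record the best value as its own metric, then log the final model
        mlflow.log_metric("best_val_loss", best_val_loss)
        mlflow.pytorch.log_model(model, "model")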
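
train_with_mlflow() only needs an object exposing run_name, lr, batch_size, model_name, and epochs. One way to provide it is a small frozen dataclass. The class name TrainConfig, its default values, and the to_params() helper below are hypothetical, chosen for illustration; the keys of to_params() mirror the dict passed to mlflow.log_params() above.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class TrainConfig:
    """Hypothetical config object with the attributes train_with_mlflow() reads."""

    run_name: str
    model_name: str
    # Illustrative defaults only; tune them for your model and data.
    lr: float = 3e-4
    batch_size: int = 32
    epochs: int = 10

    def to_params(self) -> dict:
        """Return the flat dict that train_with_mlflow() passes to mlflow.log_params()."""
        return {
            "lr": self.lr,
            "batch_size": self.batch_size,
            "model": self.model_name,
            "epochs": self.epochs,
        }


baseline = TrainConfig(run_name="baseline", model_name="resnet18")
# Frozen dataclasses are copied, not mutated, when you derive a variant.
higher_lr = replace(baseline, run_name="baseline-lr1e-3", lr=1e-3)
```

Freezing the config keeps the logged parameters and the values actually used in training from drifting apart during a run; each variant gets its own object and its own run name.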

Logging to a Remote Server

# Start the MLflow tracking server. A PostgreSQL backend store needs a
# driver such as psycopg2, and S3 artifact storage needs boto3.
mlflow server --backend-store-uri postgresql://user:pass@host:5432/mlflow \
              --default-artifact-root s3://my-bucket/mlflow/ \
              --host 0.0.0.0 --port 5000

# Point clients at the server
export MLFLOW_TRACKING_URI=http://mlflow-server:5000

Querying Experiments

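from mlflow.tracking import MlflowClient

client = MlflowClient()

# Get the best run for an experiment
experiment = client.get_experiment_by_name("my-experiment")
if experiment is None:
    raise ValueError("Experiment 'my-experiment' does not exist")

# metrics.val_loss refers to the latest value logged in each run
runs = client.search_runs(experiment_ids=[experiment.experiment_id],
                          order_by=["metrics.val_loss ASC"],
                          max_results=1)
if not runs:
    raise ValueError("No runs found in 'my-experiment'")

best_run = runs[0]
print(f"Best run: {best_run.info.run_id}")
print(f"val_loss (latest logged): {best_run.data.metrics['val_loss']}")
print(f"Config: {best_run.data.params}")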
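
Ordering by metrics.val_loss compares each run's latest logged value, so a run that overfit in its final epochs can rank below a run whose best epoch was actually worse. A minimal sketch that ranks runs by the minimum of their full val_loss history follows. It assumes `client` behaves like MlflowClient (search_runs() and get_metric_history()); the function name best_run_by_min_metric is hypothetical.

```python
import math


def best_run_by_min_metric(client, experiment_id, metric="val_loss"):
    """Return (run_id, value) for the run whose lowest logged `metric` is smallest.

    `client` is assumed to behave like mlflow.tracking.MlflowClient:
    search_runs() yields runs exposing .info.run_id, and
    get_metric_history(run_id, key) returns entries exposing .value.
    Returns (None, math.inf) if no run logged the metric.
    """
    best_id, best_value = None, math.inf
    for run in client.search_runs(experiment_ids=[experiment_id]):
        history = client.get_metric_history(run.info.run_id, metric)
        if not history:
            continue  # this run never logged the metric
        run_min = min(entry.value for entry in history)
        if run_min < best_value:
            best_id, best_value = run.info.run_id, run_min
    return best_id, best_value
```

Usage: `run_id, value = best_run_by_min_metric(MlflowClient(), experiment.experiment_id)`. This makes one history request per run and only sees the runs returned by a single search_runs call, so for large experiments it is usually cheaper to log the run's minimum as a separate metric at the end of training and order by that metric instead.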

Local verification checkpoint

Run the smallest example from this chapter in a local workspace and record the package version, runtime, data path, and observed output. If the result depends on model size, vector count, CPU/GPU backend, or available memory, note that constraint beside the exercise so the lesson remains reproducible.
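
The details this checkpoint asks for can be collected programmatically. The sketch below is one way to do it, assuming Python 3.8+ for importlib.metadata; the function name environment_tags and the default package list are hypothetical. Missing packages are reported rather than raising, so the helper also works on machines without PyTorch.

```python
import importlib.metadata
import platform


def environment_tags(packages=("mlflow", "torch"), data_path=None):
    """Collect runtime details for the verification checkpoint.

    Every value is a string, and missing packages are reported as
    "not installed" instead of raising PackageNotFoundError.
    """
    tags = {
        "python_version": platform.python_version(),
        "platform": platform.platform(),
    }
    for name in packages:
        try:
            tags[f"{name}_version"] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            tags[f"{name}_version"] = "not installed"
    if data_path is not None:
        tags["data_path"] = str(data_path)
    return tags


if __name__ == "__main__":
    for key, value in environment_tags().items():
        print(f"{key}: {value}")
```

Inside an active run, `mlflow.set_tags(environment_tags(data_path="data/train"))` attaches these values to the run itself (the path is a placeholder), so the record travels with the metrics instead of living in a separate note.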

EXERCISE

Train 3 different configurations and log each to MLflow. Query the experiment programmatically to find the best configuration.
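
A possible starting point for the exercise: build three configurations that differ only in learning rate, each with a unique run name so they are easy to tell apart in the MLflow UI. The base settings and learning rates are hypothetical placeholders, and SimpleNamespace is used only because it gives the attribute access (config.lr, config.run_name) that train_with_mlflow() expects.

```python
from types import SimpleNamespace

# Hypothetical shared settings; replace with values that fit your model and data.
BASE = {"model_name": "baseline-mlp", "batch_size": 64, "epochs": 5}

# Three learning rates to compare (illustrative, not recommendations).
LEARNING_RATES = (1e-2, 1e-3, 1e-4)


def make_configs(base=BASE, learning_rates=LEARNING_RATES):
    """Return one config per learning rate, each with a unique run name."""
    return [
        SimpleNamespace(**base, lr=lr, run_name=f"{base['model_name']}-lr{lr:g}")
        for lr in learning_rates
    ]
```

Usage: `for cfg in make_configs(): train_with_mlflow(cfg)`, then run the query code to pick the winner. Varying a single hyperparameter keeps the comparison interpretable; to compare models or batch sizes instead, change which key the loop varies.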