17. Pipeline Orchestration
Training pipelines rarely run as single scripts. Real workflows involve data preprocessing, distributed training, evaluation, model registration, and notification. Pipeline orchestration manages the dependencies between these stages, so that each one runs only after the stages it depends on have succeeded.
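The dependency management described above amounts to ordering stages so that each runs after its prerequisites. The following is a minimal sketch using Python's standard-library graphlib (Python 3.9+). The stage names mirror the list above, but the edges between them are an assumed example, not a prescribed pipeline.

```python
from graphlib import TopologicalSorter

# Hypothetical stage graph: each key maps to the set of stages it depends on.
STAGES = {
    "preprocess": set(),
    "train": {"preprocess"},
    "evaluate": {"train"},
    "register": {"evaluate"},
    "notify": {"register"},
}


def execution_order(stages):
    """Return one valid execution order.

    Raises graphlib.CycleError if the dependencies contain a cycle.
    """
    return list(TopologicalSorter(stages).static_order())


if __name__ == "__main__":
    print(execution_order(STAGES))
```

Orchestrators such as Snakemake and Airflow build a graph like this for you and add scheduling, retries, and logging on top of it.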
Simple Orchestration with Snakemake
# Snakefile
rule all:
    input: "models/final_model.pt", "reports/metrics.json"

rule preprocess:
    output: "data/processed/train.pt", "data/processed/val.pt"
    shell: "python scripts/preprocess.py --output data/processed"

rule train:
    input: "data/processed/train.pt", "data/processed/val.pt"
    output: "models/best.pt"
    # Declare GPU usage so that --resources gpu=4 can limit concurrent GPU jobs.
    resources:
        gpu=4
    shell: "torchrun --nproc_per_node=4 train.py --epochs 100"

rule evaluate:
    input: "models/best.pt", "data/processed/val.pt"
    output: "reports/metrics.json"
    shell: "python scripts/evaluate.py --model models/best.pt --output reports/metrics.json"

rule export:
    input: "models/best.pt"
    output: "models/final_model.pt"
    shell: "python scripts/export.py --input models/best.pt --output models/final_model.pt"

# Run the workflow from the shell:
snakemake --jobs 4 --resources gpu=4
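Snakemake decides which rules to run by comparing declared outputs against declared inputs. The sketch below models only the simplest part of that decision, file existence and modification times, in plain Python. The function name needs_rerun is hypothetical, and recent Snakemake versions consider additional triggers, so treat this as an illustration of the idea rather than Snakemake's actual logic.

```python
from pathlib import Path


def needs_rerun(inputs, outputs):
    """Return True if any output is missing or older than the newest input."""
    inputs = [Path(p) for p in inputs]
    outputs = [Path(p) for p in outputs]

    # A missing output always forces the step to run.
    if any(not p.exists() for p in outputs):
        return True

    # A step without inputs is up to date once its outputs exist.
    if not inputs:
        return False

    newest_input = max(p.stat().st_mtime for p in inputs)
    oldest_output = min(p.stat().st_mtime for p in outputs)
    return newest_input > oldest_output
```

Under this model, touching data/processed/train.pt would mark models/best.pt as stale, and every rule downstream of it would rerun on the next invocation.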
Airflow for Complex Workflows
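# Written against Airflow 2.x. DockerOperator requires the
# apache-airflow-providers-docker package.
from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.docker.operators.docker import DockerOperator


def preprocess_data():
    """Placeholder: replace with the project's preprocessing logic."""


def evaluate_model():
    """Placeholder: replace with the project's evaluation logic."""


default_args = {
    'owner': 'ml-team',
    'depends_on_past': False,
    'start_date': datetime(2024, 1, 1),
    'retries': 2,                         # up to 3 attempts per task
    'retry_delay': timedelta(minutes=5),
}

# catchup=False stops Airflow from backfilling one run per day since start_date.
# Airflow 2.4+ prefers schedule= over schedule_interval=.
dag = DAG(
    'training_pipeline',
    default_args=default_args,
    schedule_interval='@daily',
    catchup=False,
)

t1 = PythonOperator(
    task_id='preprocess',
    python_callable=preprocess_data,
    dag=dag,
)

# NVIDIA_VISIBLE_DEVICES only exposes GPUs if the Docker daemon
# uses the NVIDIA container runtime.
t2 = DockerOperator(
    task_id='train',
    image='ml-training:latest',
    command='torchrun --nproc_per_node=4 /app/train.py',
    environment={'NVIDIA_VISIBLE_DEVICES': 'all'},
    dag=dag,
)

t3 = PythonOperator(
    task_id='evaluate',
    python_callable=evaluate_model,
    dag=dag,
)

# preprocess -> train -> evaluate
t1 >> t2 >> t3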
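The retries and retry_delay settings in default_args tell Airflow to rerun a failed task before marking it failed. The plain-Python sketch below illustrates that behavior outside Airflow. The helper run_with_retries and its parameters are hypothetical names for illustration and are not part of the Airflow API.

```python
import time


def run_with_retries(task, retries=2, retry_delay_s=0.0, sleep=time.sleep):
    """Call task(); on an exception, retry up to `retries` more times.

    Returns (result, attempts_used). Re-raises the last exception once
    the retry budget is exhausted.
    """
    attempts = 0
    while True:
        attempts += 1
        try:
            return task(), attempts
        except Exception:
            if attempts > retries:
                raise
            # Wait before the next attempt, as retry_delay does in Airflow.
            sleep(retry_delay_s)
```

With retries=2, a task gets at most three attempts in total. Retries only help when the task is safe to run more than once, which is why the stages themselves should be idempotent.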
Failure Recovery
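# Idempotent preprocessing with checksums
import hashlib
from pathlib import Path

def preprocess_shard(shard_path: Path, output_path: Path) -> None:
    # Hash the input shard. The digest of the input that produced the current
    # output is kept in a sidecar file next to it; comparing the input hash with
    # a hash of the processed output would almost never match.
    input_hash = hashlib.md5(shard_path.read_bytes()).hexdigest()
    hash_path = output_path.with_name(output_path.name + ".md5")
    if output_path.exists() and hash_path.exists():
        if hash_path.read_text().strip() == input_hash:
            print(f"Skipping {shard_path} (already processed)")
            return
    # Process and write
    # After the output is written, record which input produced it.
    hash_path.write_text(input_hash)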
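Skipping finished shards is only safe if a crash cannot leave a half-written output that looks complete. One common way to guard against that is to write to a temporary file and rename it into place. The sketch below shows this pattern; the helper name atomic_write_bytes is hypothetical, and the atomicity of the rename assumes the temporary file and the destination are on the same filesystem.

```python
import os
import tempfile
from pathlib import Path


def atomic_write_bytes(output_path, data):
    """Write data to a temp file in the target directory, then rename it."""
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Create the temp file beside the destination so the rename stays
    # on one filesystem.
    fd, tmp_name = tempfile.mkstemp(dir=output_path.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
            f.flush()
            os.fsync(f.fileno())  # push the bytes to disk before renaming
        os.replace(tmp_name, output_path)
    except BaseException:
        # Clean up the partial temp file; the old output stays untouched.
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise
```

A preprocessing step that writes its shard through a helper like this leaves either the previous output or the complete new one on disk, so an existence check on the output path remains trustworthy after a failure.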
Local verification checkpoint
Run the smallest example from this chapter in a local workspace and record the package version, runtime, data path, and observed output. If the result depends on model size, vector count, CPU/GPU backend, or available memory, note that constraint beside the exercise so the lesson remains reproducible.
Implement your current training workflow as a Snakemake or Airflow DAG with at least 3 stages (preprocess → train → evaluate).