12. Training Arguments
Training arguments control every aspect of the optimization process. Proper configuration balances learning speed, convergence quality, and resource efficiency. The choice of learning rate, batch size, and scheduler type often matters more than modest model or data variations.
Learning rate selection follows different rules for LoRA than for full fine-tuning. LoRA adapters have far fewer trainable parameters and a different loss landscape than the full model. Typical LoRA learning rates range from 1e-4 to 3e-4, higher than the 1e-5 to 5e-5 common for full fine-tuning. Because only a small number of parameters are updated, the adapters need a larger step size to shift the model's behavior meaningfully.
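Batch size interacts with gradient accumulation when GPU memory limits direct batching. Smaller per-device batches with more accumulation steps produce essentially the same optimization as larger per-device batches with fewer accumulation steps, as long as the effective batch size is unchanged; they simply trade throughput for lower memory use. The learning rate needs to be rescaled only when the effective batch size itself changes. The effective (total) batch size, per_device_train_batch_size × gradient_accumulation_steps × number of GPUs, is what drives optimization dynamics.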
A learning-rate warmup helps adapters converge. Large updates at initialization can destabilize early training, so a linear warmup over the first few percent of training steps gradually raises the learning rate to its peak before decay begins. Cosine decay then smoothly reduces the rate over the rest of training.
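Weight decay regularization applies to LoRA parameters differently than to full models: it pulls the small adapter matrices toward zero, that is, back toward the unmodified base model. The value of 0.01 used in many configurations is a reasonable starting point for LoRA (Hugging Face TrainingArguments itself defaults to 0.0). Values between 0.0 and 0.1 are common, and weight decay can be disabled entirely for LoRA parameters if the adapters underfit.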
Logging and evaluation frequency trade off against training speed. Frequent evaluation provides better insight into training dynamics but slows overall training. For initial experiments, logging every 10-50 steps with evaluation every 500-1000 steps balances insight and efficiency.
Checkpointing strategy determines recovery options and storage usage. Saving once per epoch provides reasonable coverage for most training runs. The save_total_limit parameter caps disk usage by deleting older checkpoints so that at most N are kept.
Configure a full training arguments setup for a QLoRA training run. Implement learning rate search across a small range and compare results.
# training_arguments_demo.py
from dataclasses import dataclass, replace
from typing import List, Optional

from transformers import TrainingArguments
@dataclass
class TrainingConfig:
    """Production-ready training configuration for LoRA.

    Groups the LoRA hyperparameters with the optimizer, schedule, precision,
    checkpointing and memory settings, and builds a Hugging Face
    ``TrainingArguments`` from them via ``to_training_arguments``.

    ``model_name``, ``rank``, ``lora_alpha`` and ``lora_dropout`` are not
    consumed by ``to_training_arguments``; they are carried here so the
    model/adapter setup can read them from the same object.
    """

    # Model/Dataset paths
    model_name: str = "meta-llama/Llama-2-7b-hf"
    output_dir: str = "./output"

    # LoRA hyperparameters
    rank: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05

    # Core training hyperparameters
    learning_rate: float = 3e-4
    num_epochs: int = 3
    per_device_batch_size: int = 4
    gradient_accumulation_steps: int = 4
    max_grad_norm: float = 0.5
    weight_decay: float = 0.01

    # Learning rate schedule
    warmup_ratio: float = 0.03
    lr_scheduler_type: str = "cosine"

    # Precision and optimization
    bf16: bool = True
    fp16: bool = False
    optim: str = "paged_adamw_32bit"

    # Checkpointing and logging
    logging_steps: int = 10
    save_strategy: str = "epoch"
    save_total_limit: int = 2
    eval_strategy: str = "no"
    report_to: str = "none"

    # Efficiency
    gradient_checkpointing: bool = True
    # None selects the default {"use_reentrant": False}; any explicit dict,
    # including an empty one, is passed through unchanged.
    gradient_checkpointing_kwargs: Optional[dict] = None

    def to_training_arguments(self) -> TrainingArguments:
        """Convert to Hugging Face TrainingArguments.

        Returns:
            A ``TrainingArguments`` populated from this config's fields.
            ``gradient_checkpointing_kwargs`` falls back to
            ``{"use_reentrant": False}`` only when it is ``None``.
        """
        # Use an identity check rather than `or` so that an explicitly
        # supplied empty dict is respected instead of being overwritten.
        checkpointing_kwargs = (
            {"use_reentrant": False}
            if self.gradient_checkpointing_kwargs is None
            else self.gradient_checkpointing_kwargs
        )
        return TrainingArguments(
            output_dir=self.output_dir,
            num_train_epochs=self.num_epochs,
            per_device_train_batch_size=self.per_device_batch_size,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            learning_rate=self.learning_rate,
            max_grad_norm=self.max_grad_norm,
            weight_decay=self.weight_decay,
            warmup_ratio=self.warmup_ratio,
            lr_scheduler_type=self.lr_scheduler_type,
            bf16=self.bf16,
            fp16=self.fp16,
            optim=self.optim,
            logging_steps=self.logging_steps,
            save_strategy=self.save_strategy,
            save_total_limit=self.save_total_limit,
            eval_strategy=self.eval_strategy,
            report_to=self.report_to,
            gradient_checkpointing=self.gradient_checkpointing,
            gradient_checkpointing_kwargs=checkpointing_kwargs,
        )
def create_learning_rate_sweep(
    base_config: TrainingConfig,
    learning_rates: List[float]
) -> List[TrainingConfig]:
    """Create configs for a learning rate sweep.

    Each returned config is a copy of ``base_config`` in which only
    ``learning_rate`` and ``output_dir`` differ. This keeps the runs directly
    comparable, and each run writes to its own directory.

    Args:
        base_config: Configuration whose settings every sweep run shares.
        learning_rates: Learning rates to try; one config is made per entry.

    Returns:
        New configs in the same order as ``learning_rates``. ``base_config``
        itself is not modified. The copies are shallow, so a mutable field such
        as ``gradient_checkpointing_kwargs`` is shared with ``base_config``.
    """
    # dataclasses.replace copies every field, so settings such as
    # weight_decay, warmup_ratio, optim or the precision flags are no longer
    # silently reset to their defaults (the previous field-by-field rebuild
    # dropped them).
    return [
        replace(
            base_config,
            output_dir=f"{base_config.output_dir}_lr{lr}",
            learning_rate=lr,
        )
        for lr in learning_rates
    ]
# Memory-aware batch size calculator
def calculate_efficient_batch_size(
    model_size_b: float,
    gpu_memory_gb: float,
    seq_length: int = 512,
    use_gradient_checkpointing: bool = True,
    *,
    reserved_memory_mb: float = 2048,
    safety_margin: float = 0.8,
    max_recommended_batch_size: int = 16,
) -> dict:
    """Calculate a safe batch size given hardware constraints.

    Uses a coarse heuristic rather than measurement:
    - The per-sample cost is ``model_size_b * 100`` MB at a reference length
      of 512 tokens.
    - That cost is halved when gradient checkpointing is enabled.
    - It is then scaled linearly with ``seq_length``.

    Args:
        model_size_b: Model size in billions of parameters; must be > 0.
        gpu_memory_gb: GPU memory in GB (converted at 1024 MB per GB).
        seq_length: Tokens per sample; must be > 0.
        use_gradient_checkpointing: Whether gradient checkpointing is enabled.
        reserved_memory_mb: Fixed overhead in MB subtracted from GPU memory
            before sizing the batch.
        safety_margin: Fraction of the estimated maximum batch to keep.
        max_recommended_batch_size: Upper cap on the recommended batch size.

    Returns:
        A dict with these keys:
        - ``recommended_batch_size``: at most ``max_recommended_batch_size``.
        - ``max_batch_size``: never negative; 0 means not even one sample is
          estimated to fit.
        - ``memory_per_sample_mb``: the estimated per-sample cost.
        - ``note``: a reminder that the estimate is heuristic.

    Raises:
        ValueError: If ``model_size_b`` or ``seq_length`` is not positive
            (previously a ZeroDivisionError or a meaningless negative result).
    """
    if model_size_b <= 0:
        raise ValueError(f"model_size_b must be positive, got {model_size_b}")
    if seq_length <= 0:
        raise ValueError(f"seq_length must be positive, got {seq_length}")

    # Rough memory estimate per sample at the reference sequence length.
    base_memory_per_sample_mb = model_size_b * 100
    # Adjustment for gradient checkpointing
    if use_gradient_checkpointing:
        base_memory_per_sample_mb *= 0.5
    # Assumes roughly linear growth with sequence length; this is a heuristic
    # and may underestimate long sequences.
    reference_seq_len = 512
    seq_factor = seq_length / reference_seq_len
    effective_memory_per_sample = base_memory_per_sample_mb * seq_factor

    available_memory = gpu_memory_gb * 1024 - reserved_memory_mb
    # Clamp at 0: when the reserve exceeds GPU memory the raw formula goes
    # negative, which is not a usable batch size.
    max_batch_size = max(
        0, int(available_memory / effective_memory_per_sample * safety_margin)
    )
    return {
        "recommended_batch_size": min(max_batch_size, max_recommended_batch_size),
        "max_batch_size": max_batch_size,
        "memory_per_sample_mb": effective_memory_per_sample,
        "note": "Adjust based on actual OOM observations"
    }
# Example sweep configuration: one base config fanned out over four learning
# rates that span, and slightly exceed, the typical 1e-4 to 3e-4 LoRA range.
learning_rates = [1e-4, 2e-4, 3e-4, 5e-4]
sweep_config = TrainingConfig(
    model_name="meta-llama/Llama-2-7b-hf",
    output_dir="./lora_sweep",
    learning_rate=3e-4,
    rank=8,
)
sweep_configs = create_learning_rate_sweep(
    base_config=sweep_config,
    learning_rates=learning_rates,
)

# Report the output directory and learning rate of each generated run.
print("Learning rate sweep configurations created:")
for config in sweep_configs:
    print(f" {config.output_dir}: lr={config.learning_rate}")