06. Slow Inference
Diagnosing Bottlenecks
Slow inference usually traces back to one of four causes: GPU underutilization, CPU-GPU transfer overhead, batch sizes too small to keep the GPU busy, or a model too large for available VRAM, which forces part of it out to CPU RAM.
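```bash
# Monitor GPU utilization and VRAM use once per second during inference
nvidia-smi --query-gpu=utilization.gpu,memory.used,memory.total,temperature.gpu \
  --format=csv -l 1
```
`utilization.gpu` is the percentage of time at least one kernel was running, not how much of the GPU each kernel used. `utilization.memory` measures memory-bandwidth activity, not how much VRAM is allocated, so read VRAM use from `memory.used` against `memory.total`.
- Low GPU utilization (below ~50%): the GPU is starved, waiting on CPU-side work such as tokenization, on host-to-device transfers, or on batches too small to fill it.
- High GPU utilization (95%+) with VRAM to spare: the GPU itself is the bottleneck (compute- or bandwidth-bound). Reduce the work per token with quantization or a more efficient architecture.
- `memory.used` near `memory.total`: VRAM is exhausted. PyTorch usually raises a CUDA out-of-memory error rather than swapping, but with `device_map="auto"` the layers that don't fit are offloaded to CPU RAM, and the resulting transfers slow every forward pass (often visible as lower GPU utilization). Reduce model size (for example by quantizing) or batch size.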
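When scanning a long log captured with `-l 1`, the rules of thumb above can be encoded in a small helper. This is a sketch under stated assumptions: the first three query columns are `utilization.gpu`, `memory.used`, and `memory.total` in that order (extra columns are ignored); the 50% and 95% utilization thresholds come from the text, while treating VRAM as "near full" at 95% of `memory.total`, and checking memory before utilization, are assumptions. `parse_nvidia_smi_csv` and `diagnose` are hypothetical names.

```python
def parse_nvidia_smi_csv(text):
    """Parse `nvidia-smi --format=csv` rows whose first three columns are
    utilization.gpu, memory.used, memory.total (extra columns are ignored).

    Accepts output with or without the header line and with or without
    unit suffixes such as " %" and " MiB".
    """
    rows = []
    for line in text.strip().splitlines():
        fields = [f.strip() for f in line.split(",")]
        if len(fields) < 3:
            continue
        try:
            # "87 %" -> 87.0, "9000 MiB" -> 9000.0, "9000" -> 9000.0
            rows.append(tuple(float(f.split()[0]) for f in fields[:3]))
        except (ValueError, IndexError):
            continue  # header such as "utilization.gpu [%], ..." or "[N/A]" values
    return rows


def diagnose(util_gpu, mem_used, mem_total, mem_full_frac=0.95):
    """Map one reading to a rough diagnosis (all thresholds are heuristics)."""
    if mem_used / mem_total >= mem_full_frac:
        return "vram-full"     # model + KV cache barely fit; offloading likely
    if util_gpu < 50:
        return "gpu-starved"   # waiting on CPU work, transfers, or tiny batches
    if util_gpu >= 95:
        return "gpu-busy"      # kernels running nearly all the time
    return "inconclusive"


sample = """utilization.gpu [%], memory.used [MiB], memory.total [MiB], temperature.gpu
32 %, 9000 MiB, 24576 MiB, 61
98 %, 23800 MiB, 24576 MiB, 78
97 %, 15000 MiB, 24576 MiB, 74"""

for util, used, total in parse_nvidia_smi_csv(sample):
    print(diagnose(util, used, total))  # gpu-starved, vram-full, gpu-busy
```

The sample rows are invented for illustration, not measurements.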
Profiling with PyTorch
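```python
import time

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # gated repo: accept the license on Hugging Face first
    device_map="auto",
    torch_dtype=torch.float16,
)

# Fixed random 128-token prompt (Llama-2 vocab size is 32000) on the model's first device
input_ids = torch.randint(0, 32000, (1, 128), device=model.device)
attention_mask = torch.ones_like(input_ids)

# Warmup: the first calls include CUDA context setup and kernel selection
for _ in range(3):
    model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=20, do_sample=False)

# Time generation; synchronize so all queued GPU work is inside the measurement
torch.cuda.synchronize()
start = time.perf_counter()
output = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=100, do_sample=False)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start

new_tokens = output.shape[1] - input_ids.shape[1]  # can be < 100 if EOS is generated
print(f"Tokens generated: {new_tokens}")
print(f"Time: {elapsed:.2f}s")
print(f"Tokens/second: {new_tokens / elapsed:.1f}")
```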
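A single timed run can be skewed by background load or GPU clock changes, so a median over several runs is a steadier number to record. Below is a minimal sketch; `run_once` is a hypothetical callable you supply that performs one generation (including its `torch.cuda.synchronize()` calls) and returns the number of new tokens, and the default `repeats` and `warmup` counts are arbitrary choices.

```python
import statistics
import time


def median_tokens_per_second(run_once, repeats=5, warmup=1):
    """Call `run_once()` several times and return the median tokens/second.

    `run_once` (hypothetical) must block until generation has finished and
    return how many new tokens it produced.
    """
    for _ in range(warmup):
        run_once()  # discarded: absorbs one-time setup costs
    rates = []
    for _ in range(repeats):
        start = time.perf_counter()
        n_tokens = run_once()
        elapsed = time.perf_counter() - start
        rates.append(n_tokens / elapsed)
    return statistics.median(rates)
```

To use it with the script above, wrap the timed `model.generate(...)` call and the two `torch.cuda.synchronize()` calls in a function that returns `output.shape[1] - input_ids.shape[1]`, and pass that function as `run_once`.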
Common Fixes
| Bottleneck | Fix |
|---|---|
| Slow token generation | Batch requests to raise throughput; keep the KV cache enabled (`use_cache=True` is the default in Hugging Face `generate`) |
| CPU-GPU transfer | Pre-tokenize inputs, use pinned memory |
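| Slow attention on long contexts | Use FlashAttention, e.g. `attn_implementation="flash_attention_2"` in `from_pretrained` (FlashAttention-2 needs an Ampere-or-newer NVIDIA GPU and the `flash-attn` package) |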
| Memory-bound inference | Quantize (GPTQ, AWQ, GGUF) |
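Why quantization and batching help can be seen from a back-of-envelope estimate. When generating one sequence, each decode step reads essentially all model weights from VRAM, so memory bandwidth caps tokens/second at roughly bandwidth divided by weight bytes. Quantization shrinks the bytes; batching lets one weight read serve several sequences. This sketch assumes a hypothetical GPU with 1000 GB/s of bandwidth (look up your own card's figure), treats 4-bit weights as exactly 0.5 bytes per parameter, and ignores KV-cache traffic, activations, kernel overhead, and the compute limit that large batches eventually hit. `decode_upper_bound_tokens_per_s` is a hypothetical name.

```python
def decode_upper_bound_tokens_per_s(n_params, bytes_per_param, bandwidth_gb_s, batch_size=1):
    """Rough ceiling on decode throughput when reading weights dominates.

    Assumes every decode step streams all weights from VRAM once and that
    one read serves the whole batch. Ignores KV-cache reads, activations,
    launch overhead, and the point where the batch becomes compute-bound.
    """
    weight_bytes = n_params * bytes_per_param
    steps_per_s = bandwidth_gb_s * 1e9 / weight_bytes
    return steps_per_s * batch_size  # tokens/s summed over the batch


# 7B parameters on a hypothetical 1000 GB/s GPU
fp16 = decode_upper_bound_tokens_per_s(7e9, 2.0, 1000)
int4 = decode_upper_bound_tokens_per_s(7e9, 0.5, 1000)  # ignores quantization scale overhead
fp16_batch8 = decode_upper_bound_tokens_per_s(7e9, 2.0, 1000, batch_size=8)
print(f"fp16: {fp16:.0f} tok/s, 4-bit: {int4:.0f} tok/s, fp16 batch 8: {fp16_batch8:.0f} tok/s")
# fp16: 71 tok/s, 4-bit: 286 tok/s, fp16 batch 8: 571 tok/s
```

These are ceilings, not predictions; a measured tokens/second well below the fp16 figure points to one of the other bottlenecks in the table.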
Local verification checkpoint
Run the smallest example from this chapter in a local workspace and record the package version, runtime, data path, and observed output. If the result depends on model size, vector count, CPU/GPU backend, or available memory, note that constraint beside the exercise so the lesson remains reproducible.
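Recording the runtime details can be scripted so nothing is forgotten. A sketch using the standard library's `importlib.metadata`; the default package list (`torch`, `transformers`, `accelerate`) is an assumption matching this chapter's examples, GPU details are read through `torch.cuda` only if PyTorch is installed, and `environment_record` is a hypothetical name.

```python
import importlib.metadata
import platform


def environment_record(packages=("torch", "transformers", "accelerate")):
    """Collect versions and hardware details worth noting beside a result."""
    record = {"python": platform.python_version(), "platform": platform.platform()}
    for name in packages:
        try:
            record[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            record[name] = "not installed"
    try:
        import torch  # optional: only used to describe the GPU backend

        if torch.cuda.is_available():
            props = torch.cuda.get_device_properties(0)
            record["gpu"] = props.name
            record["gpu_memory_gib"] = round(props.total_memory / 2**30, 1)
        else:
            record["gpu"] = "none (CPU only)"
    except ImportError:
        record["gpu"] = "unknown (torch not installed)"
    return record


for key, value in environment_record().items():
    print(f"{key}: {value}")
```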
Profile inference latency with your current setup by measuring tokens/second for a fixed prompt. Hugging Face `generate` uses the KV cache by default (`use_cache=True`), so run once with `use_cache=False` and once with `use_cache=True`, then document the speedup.
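What the `use_cache` comparison measures can be anticipated by counting the token positions the model must process. This is a simplified sketch: it assumes that without the cache every step re-runs the whole sequence so far, while with the cache there is one prefill pass over the prompt and then one position per step. It counts positions only and ignores the fixed per-pass cost of reading the weights, which is why a measured speedup is usually much smaller than the ratio printed here. The 128-token prompt and 100 new tokens match the profiling script; `positions_processed` is a hypothetical name.

```python
def positions_processed(prompt_len, new_tokens, use_cache):
    """Count token positions fed through the model to produce `new_tokens` tokens."""
    if new_tokens <= 0:
        return 0
    if use_cache:
        # Prefill over the prompt yields token 1; each later token needs one position.
        return prompt_len + (new_tokens - 1)
    # Step t re-processes the prompt plus the t tokens generated so far.
    return sum(prompt_len + t for t in range(new_tokens))


with_cache = positions_processed(128, 100, use_cache=True)
without_cache = positions_processed(128, 100, use_cache=False)
print(with_cache, without_cache, round(without_cache / with_cache, 1))  # 227 17750 78.2
```

The gap grows roughly quadratically with the number of generated tokens, so longer generations show a larger benefit from the cache.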