How to set up GPU memory optimization for inference
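Prerequisites: an NVIDIA GPU with a recent driver, a Hugging Face model to optimize, and the `nvidia-smi` utility (installed with the NVIDIA driver).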
What this does
Running large language models on limited VRAM requires aggressive memory optimization. This guide covers quantization with bitsandbytes, KV-cache tuning, and dynamic batch size scheduling to maximize throughput within available GPU memory.
Steps
Step 1: Verify GPU availability
# Show GPU name, memory usage, driver version, and the "CUDA Version" header field.
# Compute capability is not in the default view; on recent drivers use:
#   nvidia-smi --query-gpu=name,compute_cap --format=csv
nvidia-smi
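Confirm the GPU name, total VRAM (e.g., 24 GB), and current memory utilization. Note the driver version and the "CUDA Version" shown in the header. That value is the highest CUDA version the driver supports, not necessarily the toolkit you have installed. `nvidia-smi` does not show compute capability by default. Query it with `nvidia-smi --query-gpu=name,compute_cap --format=csv` (recent drivers) or `python -c "import torch; print(torch.cuda.get_device_capability())"`.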
Step 2: Install optimization dependencies
# PyTorch wheel built against CUDA 12.1 (cu121); the driver's "CUDA Version" in nvidia-smi must be 12.1 or newer.
pip install torch --index-url https://download.pytorch.org/whl/cu121
# bitsandbytes: INT8 / NF4 quantization; accelerate: model-loading utilities used by device_map="auto" in Step 4.
pip install bitsandbytes accelerate
# transformers: model, tokenizer, and quantization-config classes.
pip install transformers
bitsandbytes provides the INT8 and 4-bit (NF4) quantization kernels. accelerate provides memory-efficient model loading and is required for `device_map="auto"`. transformers provides the model, tokenizer, and `BitsAndBytesConfig` classes.
Step 3: Build the optimization manager
import torch
class GPUOptimizationManager:
    """Report and manage CUDA memory for inference workloads.

    Raises:
        RuntimeError: at construction time if no CUDA device is available.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if self.device.type != "cuda":
            raise RuntimeError("CUDA GPU not available. This guide requires an NVIDIA GPU.")

    def get_memory_stats(self) -> dict:
        """Return current GPU memory usage in decimal gigabytes (1 GB = 1e9 bytes).

        Keys:
            allocated_gb: memory occupied by live tensors in this process.
            reserved_gb: memory held by PyTorch's caching allocator (>= allocated).
            total_gb: total memory of the device.
            free_gb: memory this process can still use for new tensors. This is
                the device's free memory as reported by the driver plus cached
                blocks this process has reserved but not allocated.
        """
        # mem_get_info needs a concrete index on some PyTorch versions;
        # torch.device("cuda") carries index None.
        index = self.device.index if self.device.index is not None else torch.cuda.current_device()
        allocated = torch.cuda.memory_allocated(index)
        reserved = torch.cuda.memory_reserved(index)
        # The driver's view includes memory used by other processes and the CUDA
        # context, which `total - allocated` silently ignored (overestimating free).
        device_free, total = torch.cuda.mem_get_info(index)
        # Reserved-but-unallocated blocks are reusable by this process without
        # asking the driver for more memory.
        usable_free = device_free + (reserved - allocated)
        gb = 1e9
        return {
            "allocated_gb": round(allocated / gb, 2),
            "reserved_gb": round(reserved / gb, 2),
            "total_gb": round(total / gb, 2),
            "free_gb": round(usable_free / gb, 2),
        }

    def clear_cache(self):
        """Return cached-but-unused GPU memory to the driver and print the new stats.

        Memory still referenced by live tensors is not freed; drop those
        references before calling this.
        """
        import gc

        # Collect unreachable Python objects first so tensors kept alive only by
        # reference cycles become free blocks that empty_cache can release.
        gc.collect()
        torch.cuda.empty_cache()
        print(f"[GPU] Cache cleared. Stats: {self.get_memory_stats()}")
Step 4: Implement quantization
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
def load_quantized_model(model_name: str, bits: int = 4):
    """Load a causal LM and its tokenizer, optionally quantized with bitsandbytes.

    Args:
        model_name: Hugging Face Hub id or local path passed to ``from_pretrained``.
        bits: Weight precision to load.
            4: NF4 weights with double quantization and float16 compute. This
               stores weights in roughly a quarter of their float16 size.
            8: LLM.int8() weights.
            16: unquantized float16 weights.
            32: unquantized weights in the library's default dtype.

    Returns:
        ``(model, tokenizer)``. The model is placed with ``device_map="auto"``.
        If the tokenizer has no pad token, its EOS token is used as the pad token
        so batched ``tokenizer(..., padding=True)`` calls work.

    Raises:
        ValueError: if ``bits`` is not one of 4, 8, 16, 32.
    """
    # Validate first: a typo such as bits=3 used to fall through and silently
    # load an unquantized float32 model that may not fit in VRAM.
    if bits not in (4, 8, 16, 32):
        raise ValueError(f"bits must be one of 4, 8, 16 or 32, got {bits!r}")

    load_kwargs = {}
    if bits == 4:
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    elif bits == 8:
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_has_fp16_weight=False,
        )
    elif bits == 16:
        # Without an explicit dtype, from_pretrained loads float32 weights,
        # i.e. twice the memory the caller asked for.
        load_kwargs["torch_dtype"] = torch.float16

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        **load_kwargs,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer
Step 5: Implement KV-cache tuning and dynamic batch scheduling
from transformers import GenerationConfig
def configure_kv_cache_optimization(model, max_batch_size: int = 8):
    """Enable KV caching and set default sampling parameters on ``model``.

    The model's existing generation config is copied and updated rather than
    replaced. Checkpoint-specific settings such as ``eos_token_id`` and
    ``pad_token_id`` therefore survive; a fresh ``GenerationConfig`` would drop
    them, and generation might no longer stop at end-of-sequence. This function
    does not enable paged attention.

    Args:
        model: Object with a ``generation_config`` attribute, such as a
            transformers ``PreTrainedModel``. If the attribute is missing or
            ``None``, a new ``GenerationConfig`` is created.
        max_batch_size: Only reported in the log line. It does not limit
            batching.

    Returns:
        The updated generation config, which is also assigned to
        ``model.generation_config``.
    """
    import copy

    existing = getattr(model, "generation_config", None)
    # Copy so the caller's original config object is left untouched.
    generation_config = copy.deepcopy(existing) if existing is not None else GenerationConfig()
    for field, value in (("use_cache", True), ("do_sample", True), ("temperature", 0.7)):
        setattr(generation_config, field, value)
    model.generation_config = generation_config
    print(f"[KV-Cache] Enabled. Max batch size: {max_batch_size}")
    return generation_config
class DynamicBatchScheduler:
    """Pick inference batch sizes that fit in the GPU memory currently free.

    ``model_memory_per_token_bytes`` is the caller's estimate of the extra GPU
    memory one token of one sequence costs during generation (KV cache plus
    activations). The default of 2000 bytes is a placeholder.
    NOTE(review): for real transformer models the KV cache alone is usually far
    larger per token. For standard multi-head attention it is roughly
    2 * num_layers * hidden_size * bytes_per_element. Pass a measured or
    computed value.
    """

    # Defaults shared by all instances; __init__ may override them per instance.
    max_batch_cap: int = 8  # upper bound on any batch size returned
    safety_margin_gb: float = 1.0  # headroom kept free for fragmentation / temp buffers

    def __init__(
        self,
        model,
        model_memory_per_token_bytes: float = 2000,
        *,
        max_batch_cap: int = 8,
        safety_margin_gb: float = 1.0,
    ):
        """Create a scheduler backed by a GPUOptimizationManager.

        Raises:
            ValueError: if ``max_batch_cap`` is below 1.
            RuntimeError: if no CUDA device is available (from the manager).
        """
        if max_batch_cap < 1:
            raise ValueError(f"max_batch_cap must be >= 1, got {max_batch_cap!r}")
        self.model = model
        self.model_memory_per_token = model_memory_per_token_bytes
        self.max_batch_cap = max_batch_cap
        self.safety_margin_gb = safety_margin_gb
        self.manager = GPUOptimizationManager()

    def compute_max_batch_size(self, sequence_length: int = 512) -> int:
        """Return how many sequences of ``sequence_length`` tokens fit in free VRAM.

        The result is capped at ``max_batch_cap``. It is 0 when the budget
        (free memory minus the safety margin) cannot hold even one sequence.

        Raises:
            ValueError: if ``sequence_length`` or the per-token size is not positive.
        """
        if sequence_length <= 0:
            raise ValueError(f"sequence_length must be positive, got {sequence_length!r}")
        if self.model_memory_per_token <= 0:
            raise ValueError(
                f"model_memory_per_token_bytes must be positive, got {self.model_memory_per_token!r}"
            )
        free_gb = self.manager.get_memory_stats()["free_gb"]
        budget_bytes = (free_gb - self.safety_margin_gb) * 1e9
        if budget_bytes <= 0:
            return 0
        bytes_per_sample = sequence_length * self.model_memory_per_token
        fits = int(budget_bytes / bytes_per_sample)
        if fits < 1:
            # Previously clamped up to 1, which schedules work that cannot fit
            # and risks OOM; 0 makes schedule() defer instead.
            return 0
        return min(fits, self.max_batch_cap)

    def schedule(self, pending_requests: int, sequence_length: int = 512) -> dict:
        """Decide whether to process a batch now or defer.

        Returns:
            ``{"action": "defer", "reason": "insufficient_memory", "pending": n}``
            when nothing fits. Otherwise ``{"action": "process", "batch_size": b,
            "deferred": n - b}``, where ``b`` is at most the computed maximum.
        """
        max_batch = self.compute_max_batch_size(sequence_length)
        if max_batch == 0:
            return {"action": "defer", "reason": "insufficient_memory", "pending": pending_requests}
        batch_size = min(pending_requests, max_batch)
        return {"action": "process", "batch_size": batch_size, "deferred": pending_requests - batch_size}
Step 6: Run optimized inference
def main():
    """Load phi-2 in 4-bit, report VRAM usage, and generate one sampled response."""
    # phi-2 has ~2.7B parameters, so it fits most GPUs; substitute a 7B
    # checkpoint to reproduce the memory targets in the Verification section.
    model, tokenizer = load_quantized_model("microsoft/phi-2", bits=4)
    scheduler = DynamicBatchScheduler(model)
    print(f"Memory stats: {scheduler.manager.get_memory_stats()}")

    prompt = "Explain quantum entanglement in simple terms."
    # device_map="auto" decides placement, so send inputs to the model's device
    # instead of hard-coding "cuda".
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    pad_token_id = (
        tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    )
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=True,  # without sampling, temperature is ignored (greedy decoding)
        temperature=0.7,
        pad_token_id=pad_token_id,
    )
    # generate() returns prompt + continuation; decode only the new tokens.
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    print(f"Response: {response}")
    print(f"Memory after inference: {scheduler.manager.get_memory_stats()}")


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        print(f"Optimization error: {e}")
        # Re-raise so the traceback and a non-zero exit status are not lost.
        raise
Verification
Run the script and verify:
- `torch.cuda.is_available()` returns `True`.
- The model loads successfully with `BitsAndBytesConfig` (4-bit or 8-bit).
- `get_memory_stats()` shows allocated memory significantly lower than the unquantized baseline.
  - Target for a 7B model: roughly 4 GB at 4-bit, instead of ~14+ GB at float16.
  - The 2.7B phi-2 example should land at roughly 1.5–2 GB.
- `compute_max_batch_size()` returns a positive integer given available VRAM.
- `schedule()` returns `"action": "process"` for reasonable request counts.
Common failures
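- **bitsandbytes installation issues.** `bitsandbytes` requires a supported GPU architecture and a PyTorch build whose CUDA version it supports. If loading fails or quantization appears to fall back to full precision, check that the package imports: `python -c "import bitsandbytes; print(bitsandbytes.__version__)"`. Then confirm the CUDA versions match.
- **CUDA version mismatch.** PyTorch wheels built for CUDA 12.1 (`cu121`) need a driver that supports CUDA 12.1 or newer. With an older driver (for example, one that only supports CUDA 11.8), `torch.cuda.is_available()` returns `False` or CUDA kernels fail to launch. Run `python -c "import torch; print(torch.version.cuda)"` and compare it with the "CUDA Version" in the `nvidia-smi` header.
- **Model partly offloaded to CPU / KV cache expectations.** With `device_map="auto"`, accelerate may place some layers on CPU (or disk) when it estimates the GPU lacks room, which makes generation much slower.
  - Inspect `model.hf_device_map` to see where each layer landed.
  - Pass `max_memory` explicitly (e.g. `max_memory={0: "20GiB", "cpu": "32GiB"}`) to control the split, or use `device_map={"": 0}` to force the whole model onto GPU 0.
  - `generate()` builds a fresh KV cache for each call. The cache is not reused across separate requests regardless of placement.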
- Version mismatch - The installed package or runtime differs from the command shown; check the version first and rerun the smallest verification command.
- Local environment drift - Another service, virtual environment, model, or path is being used; print the active binary path and configuration before changing the guide steps.
Related guides
- Implement Parallel Agent Execution - Use GPU memory optimization to enable running multiple model instances concurrently on the same GPU.
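- Implement Streaming Responses in AI APIs - Quantization and dynamic batching mainly reduce memory use, leaving room for more concurrent streams. bitsandbytes 4-/8-bit kernels can add per-token latency, so measure streaming latency before and after enabling them.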