07. Mixed Precision
Mixed precision inference operates different components of a model at different numerical precisions, optimizing the accuracy-speed-memory tradeoff by reserving high precision for sensitive operations while accelerating others through aggressive quantization.
Modern GPU architectures excel at mixed precision computation through tensor cores that perform matrix multiplications in float16 or bfloat16 while supporting accumulation in float32, preserving numerical stability without sacrificing throughput. Local AI inference extends this principle to int8 and lower precision where hardware supports it.
import torch
from torch.amp import autocast
class MixedPrecisionConfig:
    """Configure a mixed precision strategy for model layers.

    Patterns are checked in insertion order and the first match wins. A
    pattern matches a layer name when it occurs in the name as a plain
    substring, or when it matches the start of the name as a regular
    expression (``re.match``). Patterns that are not valid regex syntax are
    treated as plain substrings only.
    """

    SUPPORTED_PRECISIONS = ('fp16', 'bf16', 'int8', 'int4', 'fp32')

    def __init__(self, default='int8'):
        """
        Args:
            default: precision returned for layers that match no pattern.
                Defaults to 'int8'.

        Raises:
            ValueError: if ``default`` is not a supported precision.
        """
        self._check_precision(default)
        self.default = default
        self.layer_configs = {}

    def _check_precision(self, precision):
        """Raise ValueError for precision strings this config does not know."""
        if precision not in self.SUPPORTED_PRECISIONS:
            raise ValueError(
                f"unsupported precision {precision!r}; "
                f"expected one of {self.SUPPORTED_PRECISIONS}"
            )

    def set_precision(self, layer_pattern, precision):
        """
        Set precision for layers matching pattern.

        Args:
            layer_pattern: regex pattern or literal layer name (substring).
            precision: 'fp16', 'bf16', 'int8', 'int4', 'fp32'

        Raises:
            ValueError: if ``precision`` is not one of the supported values.
        """
        self._check_precision(precision)
        self.layer_configs[layer_pattern] = precision

    def get_precision(self, layer_name):
        """Return the configured precision for ``layer_name``.

        Falls back to ``self.default`` when no pattern matches.
        """
        import re
        for pattern, precision in self.layer_configs.items():
            # Substring test first: it is always safe, and for valid regexes
            # the `or` of the two tests is order-independent.
            if pattern in layer_name:
                return precision
            try:
                if re.match(pattern, layer_name):
                    return precision
            except re.error:
                # A literal name such as "embed[0" is not valid regex; it has
                # already been tried as a substring, so just move on.
                continue
        return self.default
def apply_mixed_precision(model, config):
    """Apply the per-layer precisions chosen by ``config`` to ``model`` in place.

    Only leaf modules (modules with no child modules) are transformed. A
    container's precision is therefore expressed through the names of its
    leaves, and the root module itself is never transformed.

    Precision handling:
      - 'int8': dynamic int8 quantization via
        ``torch.ao.quantization.quantize_dynamic``. Module types that it does
        not handle are returned as unmodified copies.
      - 'fp16' / 'bf16': the module is cast with ``Module.to(dtype=...)``.
      - anything else ('fp32', 'int4', ...): the module is left unchanged.

    Args:
        model: the ``torch.nn.Module`` to modify.
        config: object exposing ``get_precision(name) -> str``, e.g. a
            ``MixedPrecisionConfig``.

    Returns:
        The same ``model`` object, with leaf modules replaced or cast.
    """
    half_dtypes = {'fp16': torch.float16, 'bf16': torch.bfloat16}

    # Snapshot first: replacing modules while named_modules() is still
    # iterating would walk a mix of old and new module objects.
    leaves = [
        (name, module)
        for name, module in model.named_modules()
        if name and next(module.children(), None) is None
    ]

    for name, module in leaves:
        precision = config.get_precision(name)
        if precision == 'int8':
            # quantize_dynamic only swaps the *children* of the module it is
            # given, so wrap the leaf in a container to have it converted.
            new_module = torch.ao.quantization.quantize_dynamic(
                torch.nn.Sequential(module), dtype=torch.qint8
            )[0]
        elif precision in half_dtypes:
            new_module = module.to(dtype=half_dtypes[precision])
        else:
            continue

        # rpartition yields parent '' for top-level children; get_submodule('')
        # returns the model itself, so those layers are replaced too.
        parent_name, _, child_name = name.rpartition('.')
        setattr(model.get_submodule(parent_name), child_name, new_module)
    return model
Identifying which layers require higher precision involves sensitivity analysis. Gradient-based methods measure how much each layer contributes to final loss gradients, marking sensitive layers for preservation. Forward-based methods compare outputs against full-precision baselines, quantifying deviation introduced by quantization per layer.
def _fake_quantize_int8(weight):
    """Round-trip ``weight`` through symmetric per-tensor int8 quantization."""
    max_abs = weight.abs().max()
    if max_abs == 0:
        # An all-zero tensor is exactly representable; avoid dividing by zero.
        return weight.clone()
    scale = max_abs / 127.0
    return torch.clamp(torch.round(weight / scale), -127, 127) * scale


def analyze_layer_sensitivity(model, reference_inputs, layer_names=None):
    """
    Analyze quantization sensitivity for each layer.

    For each layer, its ``weight`` is replaced in place by an int8
    round-tripped copy (symmetric, per-tensor, scale = max|w| / 127). The
    model is then re-run on ``reference_inputs``, and the L2 distance
    (``torch.dist``) between this output and the full-precision output is
    recorded. The original weight is always restored afterwards. Only weight
    quantization error is measured; activations are not quantized.

    Args:
        model: module to analyze. It is switched to eval mode. Expects
            ``model(reference_inputs)`` to return a single tensor.
        reference_inputs: input batch passed to ``model``.
        layer_names: names (as in ``named_modules``) to analyze. Defaults to
            every ``nn.Linear`` and ``nn.Conv2d`` in the model. Names whose
            module has no tensor ``weight`` are skipped.

    Returns:
        Dictionary mapping layer names to sensitivity scores (floats; larger
        means more sensitive).
    """
    model.eval()
    if layer_names is None:
        layer_names = [
            name for name, module in model.named_modules()
            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d))
        ]

    sensitivity = {}
    with torch.no_grad():
        ref_output = model(reference_inputs)
        for name in layer_names:
            layer = model.get_submodule(name)
            weight = getattr(layer, 'weight', None)
            if not isinstance(weight, torch.Tensor):
                continue
            original = weight.detach().clone()
            try:
                # Quantize only this layer; copy_ keeps the Parameter object
                # (and anything holding a reference to it) intact.
                weight.copy_(_fake_quantize_int8(original))
                q_output = model(reference_inputs)
            finally:
                weight.copy_(original)
            sensitivity[name] = torch.dist(ref_output, q_output).item()
    return sensitivity
Practical mixed precision implementations also exploit KV-cache quantization. The key-value cache in attention mechanisms grows linearly with sequence length, becoming the dominant memory consumer for long contexts. Quantizing the cache to int4 or int8 dramatically reduces memory requirements, usually with only a small impact on output quality. Because every cached key and value is re-read at each decoding step, a smaller cache also reduces the memory bandwidth spent per generated token.
def kv_cache_quantization(k_cache, v_cache, quant_bits=8):
    """
    Apply asymmetric quantization to key-value caches, one head at a time.

    Args:
        k_cache: key cache, expected layout [batch, heads, seq, dim].
        v_cache: value cache with the same number of heads (dim 1) as k_cache.
        quant_bits: bit width forwarded to ``asymmetric_quantize``.

    Returns:
        (qk_cache, k_scales, qv_cache, v_scales): int8 tensors shaped like
        ``k_cache`` / ``v_cache``, plus one scale per head for each cache,
        allocated on the same device as that cache.

    Raises:
        ValueError: if the two caches disagree on the number of heads.
    """
    num_heads = k_cache.shape[1]
    if v_cache.shape[1] != num_heads:
        raise ValueError(
            f"k_cache has {num_heads} heads but v_cache has {v_cache.shape[1]}"
        )
    qk_cache = torch.empty_like(k_cache, dtype=torch.int8)
    qv_cache = torch.empty_like(v_cache, dtype=torch.int8)
    k_scales = torch.zeros(num_heads, device=k_cache.device)
    v_scales = torch.zeros(num_heads, device=v_cache.device)
    # NOTE(review): asymmetric dequantization normally needs a zero point as
    # well as a scale; only scales are kept here. Confirm this against
    # asymmetric_quantize's contract.
    # NOTE(review): if asymmetric_quantize emits unsigned codes (0..255 for
    # 8 bits) they will not fit int8 storage; confirm its output range.
    for head in range(num_heads):
        k_head = k_cache[:, head]
        v_head = v_cache[:, head]
        q_k, scale_k = asymmetric_quantize(
            k_head.reshape(-1), quant_bits, return_scale=True
        )
        q_v, scale_v = asymmetric_quantize(
            v_head.reshape(-1), quant_bits, return_scale=True
        )
        # The helper sees a flattened head, so restore the per-head
        # [batch, seq, dim] shape before writing into the 4-D cache.
        qk_cache[:, head] = q_k.reshape(k_head.shape)
        qv_cache[:, head] = q_v.reshape(v_head.shape)
        k_scales[head] = scale_k
        v_scales[head] = scale_v
    return qk_cache, k_scales, qv_cache, v_scales
Implement a mixed precision strategy for a transformer model that uses float16 for attention projections, int8 for feed-forward layers, and int4 for embedding tables. Benchmark the memory savings and throughput improvement against a full-int8 baseline, and measure the resulting accuracy degradation.