07. Mixed Precision

Chapter 7 of 18 · 20 min

Mixed precision inference runs different components of a model at different numerical precisions. It optimizes the accuracy–speed–memory tradeoff by reserving high precision for sensitive operations and accelerating the rest through more aggressive quantization.

Modern GPU architectures excel at mixed precision computation through tensor cores, which perform matrix multiplications in float16 or bfloat16 while accumulating in float32. This preserves numerical stability without sacrificing throughput. Local AI inference extends the same principle to int8 and lower-precision formats wherever the hardware supports them.

import torch
from torch.amp import autocast

class MixedPrecisionConfig:
    """Configure a mixed precision strategy for model layers.

    Layers are identified by their dotted module name (as produced by
    ``nn.Module.named_modules()``). Patterns are checked in the order they
    were registered and the first match wins; names that match no pattern
    get ``default_precision``.
    """

    SUPPORTED_PRECISIONS = ('fp16', 'bf16', 'int8', 'int4', 'fp32')

    def __init__(self, default_precision='int8'):
        """
        Args:
            default_precision: precision returned for layers that match no
                registered pattern ('int8' unless overridden).

        Raises:
            ValueError: if ``default_precision`` is not a supported precision.
        """
        self._check_precision(default_precision)
        self.default_precision = default_precision
        # pattern -> precision; dict insertion order defines match priority.
        self.layer_configs = {}

    @classmethod
    def _check_precision(cls, precision):
        """Raise ValueError unless *precision* is in SUPPORTED_PRECISIONS."""
        if precision not in cls.SUPPORTED_PRECISIONS:
            raise ValueError(
                f"unsupported precision {precision!r}; expected one of "
                f"{cls.SUPPORTED_PRECISIONS}"
            )

    def set_precision(self, layer_pattern, precision):
        """
        Set precision for layers matching pattern.

        Args:
            layer_pattern: regex (matched at the start of the layer name, str
                or compiled ``re.Pattern``) or a plain substring of the name.
            precision: 'fp16', 'bf16', 'int8', 'int4', 'fp32'

        Raises:
            ValueError: if ``precision`` is not one of the values above.
        """
        self._check_precision(precision)
        self.layer_configs[layer_pattern] = precision

    @staticmethod
    def _matches(pattern, layer_name):
        """Return True if *pattern* matches *layer_name* as a regex prefix or a substring."""
        import re

        if isinstance(pattern, re.Pattern):
            return pattern.match(layer_name) is not None
        if pattern in layer_name:
            return True
        try:
            return re.match(pattern, layer_name) is not None
        except re.error:
            # Not a valid regex (e.g. a literal name containing '['); the
            # substring test above is the only meaningful comparison.
            return False

    def get_precision(self, layer_name):
        """Return the precision of the first matching pattern, else the default."""
        for pattern, precision in self.layer_configs.items():
            if self._matches(pattern, layer_name):
                return precision
        return self.default_precision
    
def apply_mixed_precision(model, config):
    """Convert model layers in place to the precision chosen by *config*.

    Every module is visited in ``named_modules()`` order (parents before
    children) and assigned ``config.get_precision(name)``:

    * 'fp16' / 'bf16' / 'fp32': the module's floating-point parameters and
      buffers (and, initially, its descendants') are cast to that dtype.
    * 'int8': the module is marked for PyTorch dynamic int8 quantization.
      Only module types supported by dynamic quantization (e.g. ``nn.Linear``)
      are actually converted; other types are left unchanged.
    * anything else (e.g. 'int4') is not implemented here: a warning is
      emitted and the module is left as is.

    Because children are visited after their parent, a child's own setting
    overrides whatever its parent applied.

    Args:
        model: the ``nn.Module`` to convert; it is modified in place.
        config: object with ``get_precision(layer_name) -> str``, such as
            ``MixedPrecisionConfig``.

    Returns:
        ``model`` itself, with int8 submodules swapped for dynamically
        quantized equivalents. Its train/eval mode is preserved.

    Note:
        Dynamically quantized modules expect float32 activations, so a model
        mixing them with fp16/bf16 modules needs dtype casts between those
        layers in its forward pass; this function does not insert them.
    """
    import warnings

    float_dtypes = {
        'fp32': torch.float32,
        'fp16': torch.float16,
        'bf16': torch.bfloat16,
    }
    # Per-module qconfig for quantize_dynamic. Non-int8 modules get an
    # explicit None so they do not inherit an int8 qconfig from an ancestor.
    qconfig_spec = {}

    for name, module in model.named_modules():
        precision = config.get_precision(name)

        if precision == 'int8':
            # Dynamic quantization converts float32 weights; undo any
            # fp16/bf16 cast inherited from an ancestor visited earlier.
            module.to(torch.float32)
            qconfig_spec[name] = torch.ao.quantization.default_dynamic_qconfig
            continue

        qconfig_spec[name] = None
        if precision in float_dtypes:
            # nn.Module.to() casts in place, so nothing has to be re-attached.
            module.to(float_dtypes[precision])
        else:
            warnings.warn(
                f"precision {precision!r} for layer {name!r} is not supported "
                "by apply_mixed_precision; leaving the layer unchanged"
            )

    # Swap all int8 modules in a single pass over the whole model.
    # quantize_dynamic only converts the *children* of the module it is
    # given, so calling it on each leaf individually would be a no-op.
    was_training = model.training  # quantize_dynamic switches to eval()
    torch.ao.quantization.quantize_dynamic(
        model, qconfig_spec=qconfig_spec, dtype=torch.qint8, inplace=True
    )
    model.train(was_training)
    return model

Identifying which layers require higher precision involves sensitivity analysis. Gradient-based methods measure how much each layer contributes to the final loss gradients and mark sensitive layers for preservation. Forward-based methods compare outputs against a full-precision baseline, quantifying the deviation introduced by quantizing each layer. The function below takes the forward-based approach: it quantizes one layer at a time and measures how far the model's output moves.

def _fake_quantize_int8(weight):
    """Return *weight* rounded to symmetric per-tensor int8 and dequantized back."""
    w = weight.float()  # compute in float32 so fp16/bf16 weights get a usable scale
    scale = w.abs().max() / 127.0
    if scale == 0:
        # All-zero weight: quantization is exact.
        return weight.clone()
    q = torch.clamp(torch.round(w / scale), -128, 127)
    return (q * scale).to(weight.dtype)


def analyze_layer_sensitivity(model, reference_inputs, layer_names=None):
    """
    Analyze int8 quantization sensitivity for each layer.

    For one layer at a time, the layer's weight is replaced by its int8
    quantize-dequantize approximation (symmetric, per-tensor — similar to
    PyTorch's default dynamic-quantization weight observer). The model is
    re-run on ``reference_inputs``, and the Euclidean distance between that
    output and the full-precision output is recorded. The original weight is
    restored afterwards, even if the forward pass raises.

    Args:
        model: ``nn.Module`` to analyze. It is switched to (and left in) eval mode.
        reference_inputs: input passed as ``model(reference_inputs)``.
        layer_names: dotted module names to analyze; defaults to every
            ``nn.Linear`` and ``nn.Conv2d`` in ``model``. Modules without a
            tensor ``weight`` attribute are skipped.

    Returns:
        Dictionary mapping layer names to sensitivity scores (larger means
        quantizing that layer perturbs the output more).

    Note:
        Assumes ``model`` returns a single tensor, since ``torch.dist`` needs
        one — TODO adapt for models that return tuples or dicts.
    """
    model.eval()
    sensitivity = {}

    if layer_names is None:
        layer_names = [
            name for name, module in model.named_modules()
            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d))
        ]

    with torch.no_grad():
        ref_output = model(reference_inputs)

        # Measure sensitivity through layerwise quantization
        for name in layer_names:
            weight = getattr(model.get_submodule(name), 'weight', None)
            if not isinstance(weight, torch.Tensor):
                continue

            original = weight.detach().clone()
            try:
                # Quantize only this layer; copy_ keeps the Parameter object intact.
                weight.copy_(_fake_quantize_int8(original))
                q_output = model(reference_inputs)
            finally:
                weight.copy_(original)

            sensitivity[name] = torch.dist(ref_output, q_output).item()

    return sensitivity

Practical mixed precision implementations also exploit KV-cache quantization. The key-value cache in attention grows linearly with sequence length and becomes the dominant memory consumer for long contexts. Quantizing the cache to int8 or int4 dramatically reduces its memory footprint, typically with minimal impact on output quality. The cache holds intermediate activations that are only read back by attention, and these tolerate modest rounding error. Keys, however, tend to be more sensitive than values.

def kv_cache_quantization(k_cache, v_cache, quant_bits=8):
    """
    Apply asymmetric quantization to key-value caches, one scale per head.

    Args:
        k_cache: key cache indexed as [batch, heads, seq, dim]; dim 1 is
            treated as the head axis.
        v_cache: value cache with the same number of heads as ``k_cache``.
        quant_bits: bit width passed to ``asymmetric_quantize``. Results are
            stored in int8 tensors, so it must be between 1 and 8. Widths
            below 8 are not bit-packed here and save no memory by themselves.

    Returns:
        (qk_cache, k_scales, qv_cache, v_scales): int8 tensors shaped like
        the inputs, plus per-head scale vectors of length ``heads`` on the
        corresponding cache's device.

    Raises:
        ValueError: if ``quant_bits`` is outside [1, 8] or the caches have
            different numbers of heads.

    NOTE(review): asymmetric quantization normally also needs a per-head
    zero point for dequantization, but only scales are returned. Confirm this
    against ``asymmetric_quantize`` (defined elsewhere). That helper is assumed
    to take a 1-D tensor and return ``(quantized, scale)`` when
    ``return_scale=True``.
    """
    if not 1 <= quant_bits <= 8:
        raise ValueError(
            f"quant_bits must be in [1, 8] for int8 storage, got {quant_bits}"
        )
    num_heads = k_cache.shape[1]
    if v_cache.shape[1] != num_heads:
        raise ValueError(
            f"k_cache has {num_heads} heads but v_cache has {v_cache.shape[1]}"
        )

    qk_cache = torch.empty_like(k_cache, dtype=torch.int8)
    qv_cache = torch.empty_like(v_cache, dtype=torch.int8)
    k_scales = torch.zeros(num_heads, device=k_cache.device)
    v_scales = torch.zeros(num_heads, device=v_cache.device)

    for head in range(num_heads):
        for src, dst, scales in (
            (k_cache, qk_cache, k_scales),
            (v_cache, qv_cache, v_scales),
        ):
            head_slice = src[:, head, :, :]
            q_flat, scale = asymmetric_quantize(
                head_slice.reshape(-1), quant_bits, return_scale=True
            )
            # The helper only sees a flattened tensor; restore the
            # [batch, seq, dim] slice shape before writing it back.
            dst[:, head, :, :] = q_flat.reshape(head_slice.shape)
            scales[head] = scale

    return qk_cache, k_scales, qv_cache, v_scales
EXERCISE

Implement a mixed precision strategy for a transformer model that uses float16 for attention projections, int8 for feed-forward layers, and int4 for embedding tables. Benchmark memory savings and throughput improvement against a full-precision (fp32) baseline and against an all-int8 configuration, and measure the accuracy degradation of each.