15. Quantization Accuracy

Chapter 15 of 18 · 15 min

Maintaining accuracy during quantization requires understanding error propagation and applying appropriate calibration strategies.

Error Metrics

import torch
import numpy as np

def quantization_error_metrics(tensor_fp32, tensor_quantized):
    """Compute common error metrics between a reference tensor and its quantized copy.

    Args:
        tensor_fp32: Reference (full-precision) tensor.
        tensor_quantized: Quantized/dequantized tensor; it is cast with
            ``.float()`` before the element-wise difference is taken.

    Returns:
        dict mapping metric name to a Python float: ``mae``, ``mse``,
        ``rmse``, ``max_abs_error`` and ``relative_l2``
        (``||error||_2 / ||tensor_fp32||_2``). When the reference norm is
        zero, ``relative_l2`` is 0.0 if the error is also zero and ``inf``
        otherwise.
    """
    error = (tensor_fp32 - tensor_quantized.float()).flatten()
    abs_error = error.abs()
    mse = torch.mean(error ** 2)

    # Take .item() of each norm separately; the original applied .item() to
    # the denominator only, leaving this entry as a tensor.
    error_norm = torch.linalg.vector_norm(error).item()
    reference_norm = torch.linalg.vector_norm(tensor_fp32.flatten()).item()
    if reference_norm > 0:
        relative_l2 = error_norm / reference_norm
    else:
        # Zero reference: a perfect reconstruction is 0 error, anything else
        # is unbounded relative error.
        relative_l2 = 0.0 if error_norm == 0 else float('inf')

    return {
        'mae': torch.mean(abs_error).item(),
        'mse': mse.item(),
        'rmse': torch.sqrt(mse).item(),
        'max_abs_error': torch.max(abs_error).item(),
        'relative_l2': relative_l2,
    }

def per_channel_error_analysis(tensor_fp32, tensor_quantized, channel_dim=0):
    """Analyze quantization error (MSE) separately for each channel.

    Args:
        tensor_fp32: Reference tensor.
        tensor_quantized: Quantized/dequantized tensor of the same shape; it
            is cast with ``.float()`` before comparison.
        channel_dim: Dimension that indexes the channels.

    Returns:
        dict with
          ``per_channel_errors``: list of per-channel MSEs (Python floats),
          ``problematic_channels``: indices whose MSE exceeds
              mean + 2 * std of all channel MSEs (population std),
          ``worst_channel``: index (Python int) of the largest-MSE channel.

    Raises:
        ValueError: If the two shapes differ or there are zero channels.
    """
    if tensor_fp32.shape != tensor_quantized.shape:
        # Otherwise zip() over the channel slices silently drops channels and
        # broadcasting inside the subtraction can hide the mismatch.
        raise ValueError(
            f"shape mismatch: {tuple(tensor_fp32.shape)} vs "
            f"{tuple(tensor_quantized.shape)}"
        )
    if tensor_fp32.shape[channel_dim] == 0:
        raise ValueError("cannot analyze a tensor with zero channels")

    channel_errors = [
        torch.mean((ref - quant.float()) ** 2).item()
        for ref, quant in zip(
            tensor_fp32.unbind(channel_dim), tensor_quantized.unbind(channel_dim)
        )
    ]

    # Flag outlier channels: more than two standard deviations above the mean.
    errors = np.asarray(channel_errors)
    threshold = errors.mean() + 2 * errors.std()
    problematic_channels = [i for i, e in enumerate(channel_errors) if e > threshold]

    return {
        'per_channel_errors': channel_errors,
        'problematic_channels': problematic_channels,
        # int() so callers get a plain Python int rather than a NumPy integer.
        'worst_channel': int(np.argmax(errors)),
    }

Calibration Strategies

class CalibrationDataset:
    """Collect named tensors during calibration and derive symmetric int8 scales.

    Call :meth:`collect` during the calibration forward pass, then
    :meth:`compute_scales` to turn the recorded tensors into one scale
    (clip threshold / 127) per name.

    Args:
        dataset: Calibration data source. It is only stored on the instance;
            this class never iterates over it.
        num_samples: Intended number of calibration samples. Only stored;
            nothing in this class enforces it.
    """

    # torch.quantile rejects inputs with more than 2**24 elements; larger
    # inputs take a kthvalue-based path in _quantile.
    _QUANTILE_MAX_NUMEL = 2 ** 24

    def __init__(self, dataset, num_samples=1000):
        self.dataset = dataset
        self.num_samples = num_samples
        self.samples = []  # not populated by this class; kept for callers
        self.collected_tensors = {}  # name -> list of detached tensor copies

    def collect(self, name, tensor):
        """Collect tensor statistics during calibration forward pass.

        Stores a detached clone of ``tensor`` under ``name`` so that later
        in-place updates to the original cannot alter the recorded values.
        """
        self.collected_tensors.setdefault(name, []).append(tensor.detach().clone())

    def compute_scales(self, strategy='max', num_bins=2048, percentile=0.9999):
        """Compute quantization scales from collected data.

        Args:
            strategy: ``'max'`` uses the largest absolute value seen;
                ``'entropy'`` picks the clip threshold minimizing the KL
                divergence between the clipped histogram and its 128-level
                quantization (see :meth:`kl_scale_calibration`);
                ``'percentile'`` uses the ``percentile`` quantile of the
                absolute values.
            num_bins: Histogram resolution used by ``'entropy'``.
            percentile: Quantile in [0, 1] used by ``'percentile'``.

        Returns:
            dict mapping each collected name to a 0-d tensor scale.

        Raises:
            ValueError: If ``strategy`` is not one of the names above.
        """
        if strategy not in ('max', 'entropy', 'percentile'):
            raise ValueError(f"unknown calibration strategy: {strategy!r}")

        scales = {}
        for name, tensors in self.collected_tensors.items():
            # Flatten before concatenating: unlike torch.stack this accepts
            # calibration batches of different shapes (e.g. a short last batch).
            magnitudes = torch.cat([t.abs().flatten() for t in tensors])

            if strategy == 'max':
                scales[name] = magnitudes.max() / 127.0
            elif strategy == 'entropy':
                scales[name] = self._entropy_scale(magnitudes, num_bins)
            else:  # 'percentile'
                scales[name] = self._quantile(magnitudes, percentile) / 127.0

        return scales

    def _entropy_scale(self, magnitudes, num_bins):
        """Histogram absolute values over [0, amax] and run KL calibration."""
        amax = magnitudes.max()
        amax_value = amax.item()
        if amax_value == 0:
            # All-zero data: nothing to histogram; same result as 'max'.
            return amax / 127.0
        # Anchoring the histogram at 0 makes bin k cover
        # [k * width, (k + 1) * width) in tensor units.
        hist = torch.histc(magnitudes.float(), bins=num_bins, min=0.0, max=amax_value)
        scale = self.kl_scale_calibration(hist, bin_width=amax_value / num_bins)
        return torch.as_tensor(scale, device=magnitudes.device)

    @classmethod
    def _quantile(cls, values, q):
        """Return the q-quantile of 1-D ``values`` using linear interpolation.

        Uses torch.quantile when the input is small enough for it, otherwise
        reproduces the same interpolation with two kthvalue selections.
        """
        n = values.numel()
        if n <= cls._QUANTILE_MAX_NUMEL:
            return torch.quantile(values, q)
        position = q * (n - 1)
        lower = int(position)  # floor, since position >= 0
        upper = min(lower + 1, n - 1)
        lower_value = torch.kthvalue(values, lower + 1).values  # kthvalue is 1-based
        upper_value = torch.kthvalue(values, upper + 1).values
        return lower_value + (upper_value - lower_value) * (position - lower)

    @staticmethod
    def kl_scale_calibration(bin_counts, bin_width=1.0, num_quant_bins=128):
        """Find the clip threshold minimizing KL divergence; return its scale.

        Entropy calibration (TensorRT style). For every candidate number of
        kept bins ``i`` from ``num_quant_bins`` to ``len(bin_counts)``:

        * P is the first ``i`` bins, with all counts beyond them added to the
          last kept bin (the effect of clipping at that threshold);
        * Q is the first ``i`` bins merged into ``num_quant_bins`` levels,
          each level spread evenly back over the bins that are non-zero in P.

        The ``i`` with the smallest KL(P || Q) wins.

        Args:
            bin_counts: 1-D histogram of absolute values whose first bin
                starts at 0.
            bin_width: Width of one bin in tensor units; with the default 1.0
                the result is in bin units.
            num_quant_bins: Number of quantization levels (128 for symmetric
                int8).

        Returns:
            float: ``i * bin_width / (num_quant_bins - 1)``, i.e. the upper
            edge of the last kept bin divided by the largest quantized value.
            If the histogram is empty or has fewer than ``num_quant_bins``
            bins, the full histogram range is used.

        Raises:
            ValueError: If ``num_quant_bins`` is smaller than 2.
        """
        if num_quant_bins < 2:
            raise ValueError("num_quant_bins must be at least 2")

        hist = bin_counts.detach().flatten().double().cpu()
        num_bins = hist.numel()
        # counts_from[i] = total count in bins i.., the mass clipped when only
        # the first i bins are kept (trailing zero covers i == num_bins).
        counts_from = torch.cat([hist.flip(0).cumsum(0).flip(0), hist.new_zeros(1)])

        best_bins = num_bins  # fallback: keep the whole range
        best_kl = float('inf')
        if hist.sum() > 0:
            for i in range(num_quant_bins, num_bins + 1):
                kept = hist[:i]
                p = kept.clone()
                p[-1] += counts_from[i]  # clipped values saturate into the last bin
                nonzero = p > 0

                # Merge the i kept bins into groups of i // num_quant_bins
                # bins; the remainder joins the last group.
                group = torch.div(torch.arange(i), i // num_quant_bins, rounding_mode='floor')
                group = group.clamp(max=num_quant_bins - 1)
                level_counts = hist.new_zeros(num_quant_bins).index_add_(0, group, kept)
                level_support = hist.new_zeros(num_quant_bins).index_add_(
                    0, group, nonzero.double()
                )
                # Expand each level back over the bins that are non-zero in P.
                q = (level_counts / level_support.clamp(min=1))[group] * nonzero
                q_total = q.sum()
                if q_total <= 0:
                    continue  # everything kept was clipped mass; Q is undefined

                p_probs = p[nonzero] / p.sum()
                q_probs = q[nonzero] / q_total
                # Floor Q so bins holding only clipped mass give a large but
                # finite penalty instead of infinity.
                kl = torch.sum(p_probs * torch.log(p_probs / q_probs.clamp(min=1e-8))).item()
                if kl < best_kl:
                    best_kl = kl
                    best_bins = i

        return best_bins * bin_width / (num_quant_bins - 1)
EXERCISE

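Implement cross-layer quantization-aware scaling (CLQA), in which the quantization scales of adjacent layers are chosen jointly rather than independently, so that quantization error does not accumulate through deep networks. (A closely related published technique is cross-layer equalization from data-free quantization.)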