15. Quantization Accuracy
Chapter 15 of 18 · 15 min
Maintaining accuracy during quantization requires understanding error propagation and applying appropriate calibration strategies.
Error Metrics
import torch
import numpy as np
def quantization_error_metrics(tensor_fp32, tensor_quantized):
    """Compute common error metrics for quantized tensors.

    The quantized tensor is cast to float and subtracted element-wise from
    the reference; all metrics are computed over the flattened difference.

    Args:
        tensor_fp32: Reference tensor.
        tensor_quantized: Tensor to compare against the reference
            (presumably already dequantized to the same value range --
            confirm with the caller).

    Returns:
        dict mapping metric name to a Python float:
        ``mae``, ``mse``, ``rmse``, ``max_abs_error`` and ``relative_l2``
        (L2 norm of the error divided by the L2 norm of the reference).
        A zero-norm reference yields ``inf`` or ``nan`` for ``relative_l2``.
    """
    error = (tensor_fp32 - tensor_quantized.float()).flatten()
    abs_error = torch.abs(error)
    mse = torch.mean(error ** 2)
    reference_norm = torch.norm(tensor_fp32.flatten())
    return {
        'mae': torch.mean(abs_error).item(),
        'mse': mse.item(),
        'rmse': torch.sqrt(mse).item(),
        'max_abs_error': torch.max(abs_error).item(),
        # .item() must apply to the whole ratio; applying it to the
        # denominator only would leave a 0-dim tensor in the result dict.
        'relative_l2': (torch.norm(error) / reference_norm).item(),
    }
def per_channel_error_analysis(tensor_fp32, tensor_quantized, channel_dim=0):
    """Analyze quantization error per channel.

    Both tensors are split into one slice per index along ``channel_dim``
    and the mean squared error is computed for each pair of slices.

    Args:
        tensor_fp32: Reference tensor.
        tensor_quantized: Tensor to compare; each slice is cast to float.
        channel_dim: Dimension along which channels are laid out.

    Returns:
        dict with:
        ``per_channel_errors`` -- list of per-channel MSE values (floats);
        ``problematic_channels`` -- indices whose MSE exceeds
        mean + 2 * standard deviation of the per-channel errors;
        ``worst_channel`` -- index of the largest MSE, as a plain ``int``.
    """
    fp32_slices = tensor_fp32.chunk(tensor_fp32.shape[channel_dim], dim=channel_dim)
    quant_slices = tensor_quantized.chunk(
        tensor_quantized.shape[channel_dim], dim=channel_dim
    )

    channel_errors = [
        torch.mean((ref - quant.float()) ** 2).item()
        for ref, quant in zip(fp32_slices, quant_slices)
    ]

    # Flag channels whose error is a statistical outlier (> mean + 2 sigma).
    threshold = np.mean(channel_errors) + 2 * np.std(channel_errors)
    problematic_channels = [
        index for index, err in enumerate(channel_errors) if err > threshold
    ]

    return {
        'per_channel_errors': channel_errors,
        'problematic_channels': problematic_channels,
        # np.argmax returns np.int64; convert so the result is a plain int
        # like the entries of 'problematic_channels'.
        'worst_channel': int(np.argmax(channel_errors)),
    }
Calibration Strategies
class CalibrationDataset:
    """Collect tensors during calibration and derive quantization scales.

    Tensors are recorded per name via :meth:`collect`; :meth:`compute_scales`
    then turns the recorded absolute values into one scale per name.
    """

    _STRATEGIES = ('max', 'entropy', 'percentile')

    def __init__(self, dataset, num_samples=1000):
        """Create an empty collector.

        Args:
            dataset: Calibration data source. It is only stored on the
                instance; this class does not iterate over it itself.
            num_samples: Intended number of calibration samples. Stored for
                callers; not enforced by this class.
        """
        self.dataset = dataset
        self.num_samples = num_samples
        self.samples = []
        # name -> list of detached tensor copies recorded by collect()
        self.collected_tensors = {}

    def collect(self, name, tensor):
        """Record a detached copy of ``tensor`` under ``name``."""
        self.collected_tensors.setdefault(name, []).append(tensor.detach().clone())

    def compute_scales(self, strategy='max', num_bins=2048, percentile=0.9999):
        """Compute quantization scales from collected data.

        Args:
            strategy: ``'max'`` (largest magnitude / 127),
                ``'percentile'`` (``percentile`` quantile of the magnitudes
                / 127) or ``'entropy'`` (KL-divergence based clipping, see
                :meth:`kl_scale_calibration`).
            num_bins: Histogram resolution for the ``'entropy'`` strategy.
            percentile: Quantile in [0, 1] for the ``'percentile'`` strategy.

        Returns:
            dict mapping each collected name to its scale. ``'max'`` and
            ``'percentile'`` produce 0-dim tensors; ``'entropy'`` produces
            a Python float.

        Raises:
            ValueError: If ``strategy`` is not one of the supported names.
        """
        if strategy not in self._STRATEGIES:
            raise ValueError(
                f"Unknown calibration strategy {strategy!r}; "
                f"expected one of {self._STRATEGIES}"
            )

        scales = {}
        for name, tensors in self.collected_tensors.items():
            # Concatenate flattened copies rather than stacking: recorded
            # tensors may differ in shape (e.g. a smaller final batch), and
            # every strategy only looks at the flat set of magnitudes.
            magnitudes = torch.cat([t.flatten() for t in tensors]).abs()

            if strategy == 'max':
                scales[name] = magnitudes.max() / 127.0
            elif strategy == 'entropy':
                max_value = magnitudes.max().item()
                # Pin the histogram to [0, max] so bin index i corresponds
                # to magnitude i * max / num_bins.
                bins = torch.histc(magnitudes, bins=num_bins, min=0.0, max=max_value)
                scales[name] = self.kl_scale_calibration(bins, max_value=max_value)
            else:  # 'percentile'
                scales[name] = torch.quantile(magnitudes, percentile) / 127.0
        return scales

    @staticmethod
    def kl_scale_calibration(bin_counts, max_value=None, num_levels=128):
        """Find the scale whose clipping threshold minimizes KL divergence.

        ``bin_counts`` is treated as a histogram of magnitudes over equally
        wide bins starting at zero. For every candidate threshold (a bin
        edge at or beyond ``num_levels`` bins) the histogram is clipped at
        that edge with the clipped-off mass folded into the last kept bin,
        re-binned into ``num_levels`` quantization levels, expanded back,
        and compared with the clipped reference using KL divergence. The
        threshold with the smallest divergence wins; ties keep the smallest
        threshold.

        Args:
            bin_counts: 1-D tensor of histogram counts.
            max_value: Magnitude at the upper edge of the last bin. When
                ``None`` the scale is expressed in units of one bin width.
            num_levels: Number of non-negative quantization levels; the
                threshold is divided by ``num_levels - 1`` (127 by default).

        Returns:
            The scale as a Python float. If the histogram has no more than
            ``num_levels`` bins, or contains no mass, the full range is used.

        Raises:
            ValueError: If ``num_levels`` is smaller than 2.
        """
        if num_levels < 2:
            raise ValueError("num_levels must be at least 2")

        hist = bin_counts.detach().flatten().to(dtype=torch.float64, device='cpu')
        num_bins = hist.numel()
        bin_width = 1.0 if max_value is None else float(max_value) / max(num_bins, 1)
        top_level = num_levels - 1

        total = hist.sum().item()
        if num_bins <= num_levels or total <= 0:
            # Nothing to search over: keep the whole observed range.
            return num_bins * bin_width / top_level

        eps = 1e-12
        best_edge = num_bins
        best_kl = float('inf')
        for edge in range(num_levels, num_bins + 1):
            kept = hist[:edge]
            kept_mass = kept.sum().item()
            if kept_mass <= 0:
                # Everything would be clipped; not a usable threshold.
                continue

            # Reference distribution: clipped histogram with the outlier
            # mass saturated into the last kept bin.
            reference = kept.clone()
            reference[-1] += hist[edge:].sum()
            p = reference / total

            # Map each kept bin to one of num_levels quantization buckets.
            bucket = torch.div(
                torch.arange(edge) * num_levels, edge, rounding_mode='floor'
            )
            occupied = (kept > 0).to(hist.dtype)
            merged = torch.zeros(num_levels, dtype=hist.dtype).index_add_(0, bucket, kept)
            occupancy = torch.zeros(num_levels, dtype=hist.dtype).index_add_(
                0, bucket, occupied
            )
            # Spread each bucket's mass evenly over its non-empty source bins.
            expanded = merged[bucket] / occupancy[bucket].clamp(min=1.0) * occupied
            q = expanded / kept_mass

            support = p > 0
            kl = torch.sum(
                p[support] * torch.log(p[support] / q[support].clamp(min=eps))
            ).item()
            if kl < best_kl:
                best_kl = kl
                best_edge = edge

        return best_edge * bin_width / top_level
EXERCISE
Implement cross-layer free quantization (CLQA), in which the scales of adjacent layers are constrained jointly so that quantization error does not accumulate through deep networks.