04. Calibration Datasets

Chapter 4 of 18 · 20 min

Calibration datasets determine quantization parameters by providing representative input samples that establish activation value ranges. The quality of this calibration directly influences quantized model accuracy, making dataset selection a critical step in the quantization workflow.
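
To make the link between an observed activation range and quantization parameters concrete, here is a minimal sketch of the usual affine (asymmetric) mapping onto unsigned 8-bit integers. The function names `affine_qparams`, `quantize`, and `dequantize` are illustrative and not taken from any library, and the range values are made up.

```python
def affine_qparams(x_min, x_max, num_bits=8):
    """Map an observed float range to (scale, zero_point) for unsigned ints."""
    qmin, qmax = 0, 2 ** num_bits - 1
    # Widen the range to include zero so that 0.0 is exactly representable
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0:  # degenerate range: every observed value was zero
        return 1.0, 0
    zero_point = round(qmin - x_min / scale)
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point


def quantize(x, scale, zero_point, num_bits=8):
    """Round a float to its integer code and clamp to the valid range."""
    q = round(x / scale) + zero_point
    return max(0, min(2 ** num_bits - 1, q))


def dequantize(q, scale, zero_point):
    """Recover the float value represented by an integer code."""
    return (q - zero_point) * scale


# Hypothetical range observed during calibration
scale, zero_point = affine_qparams(-2.0, 6.0)
approx = dequantize(quantize(1.0, scale, zero_point), scale, zero_point)
print(scale, zero_point, approx)
```

The round-trip error for any in-range value is at most half the scale, so a range that calibration overestimates translates directly into coarser resolution.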

Effective calibration requires samples drawn from the same distribution as inference inputs. For a language model, this means calibration data should resemble the text the model will process in production: the same domain, a similar length distribution, and the same vocabulary.
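
One rough way to check the length-distribution part of this requirement is to compare bucketed token-length histograms of the calibration set and a sample of production inputs. The sketch below is an assumption about how such a check could look, not an established procedure; `length_histogram` and `total_variation` are hypothetical helpers, and the token lengths are invented.

```python
from collections import Counter


def length_histogram(lengths, edges=(64, 128, 256, 512, 1024, 2048)):
    """Fraction of samples per length bucket; the last bucket catches overflow."""
    counts = Counter()
    for n in lengths:
        bucket = next((e for e in edges if n <= e), edges[-1])
        counts[bucket] += 1
    total = max(len(lengths), 1)
    return {e: counts[e] / total for e in edges}


def total_variation(p, q):
    """Total variation distance between two histograms over the same buckets."""
    return 0.5 * sum(abs(p[k] - q[k]) for k in p)


# Invented token lengths for illustration
calibration_lengths = [40, 90, 300, 700, 1500]
production_lengths = [50, 100, 280, 900, 1900]
distance = total_variation(
    length_histogram(calibration_lengths),
    length_histogram(production_lengths),
)
print(distance)
```

A distance near 0 means the two sets fill the length buckets in similar proportions; a distance near 1 means they barely overlap. The same comparison could be applied to other properties, such as token frequency.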

import torch
from torch.utils.data import Dataset

class CalibrationDataset(Dataset):
    """Prepare calibration dataset from raw inputs."""
    
    def __init__(self, data_samples, tokenizer, max_length=2048, nsamples=512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.nsamples = nsamples
        self.data = self._prepare_samples(data_samples)
    
    def _prepare_samples(self, raw_data):
        """
        Select representative samples for calibration.
        Strategy: Random selection with length stratification.
        """
        samples = []
        
        # Group samples by length buckets
        length_buckets = {64: [], 128: [], 256: [], 512: [], 1024: [], 2048: []}
        
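        # Truncate to max_length (capped at the largest bucket) so that long
        # samples land in the last bucket instead of being silently dropped
        limit = min(self.max_length, max(length_buckets))
        
        for sample in raw_data:
            tokens = self.tokenizer.encode(sample)[:limit]
            length = len(tokens)
            
            # Assign to the smallest bucket that fits the sample
            for bucket_len in sorted(length_buckets.keys()):
                if length <= bucket_len:
                    length_buckets[bucket_len].append(tokens)
                    break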
        
        # Stratified sampling across buckets
        samples_per_bucket = self.nsamples // len(length_buckets)
        
        for bucket_samples in length_buckets.values():
            if bucket_samples:
                indices = torch.randperm(len(bucket_samples))[:samples_per_bucket]
                samples.extend([bucket_samples[i] for i in indices])
        
        return samples[:self.nsamples]
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return torch.tensor(self.data[idx], dtype=torch.long)
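
Because each item is a token tensor of a different length, the default `DataLoader` collation cannot stack a batch of more than one sample. A minimal sketch of a padding collate function follows; `pad_collate` and the `pad_token_id=0` default are assumptions for illustration, and the toy tensors stand in for real dataset items.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def pad_collate(batch, pad_token_id=0):
    """Pad a list of 1-D token tensors to a common length and build a mask."""
    lengths = torch.tensor([len(seq) for seq in batch])
    input_ids = pad_sequence(batch, batch_first=True, padding_value=pad_token_id)
    # True where a real token is present, False on padding positions
    positions = torch.arange(input_ids.size(1))
    attention_mask = positions[None, :] < lengths[:, None]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# Stand-in for dataset items: variable-length token tensors
toy_samples = [torch.tensor([5, 6, 7]), torch.tensor([8, 9]), torch.tensor([1])]
loader = DataLoader(toy_samples, batch_size=3, collate_fn=pad_collate)
batch = next(iter(loader))
print(batch["input_ids"])
print(batch["attention_mask"])
```

Padding positions still produce activations, so they can skew collected statistics. Using a batch size of 1, or masking padded positions before computing ranges, avoids that.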

The recommended calibration size depends on model architecture and quantization granularity. For int8 quantization, 512 to 1024 samples typically suffice for per-tensor schemes. Per-channel quantization may require larger datasets to capture channel-wise variation accurately. GPTQ-style weight quantization uses calibration data differently: it accumulates a per-layer Hessian approximation from the layer inputs, and the original GPTQ work reported results with only 128 sequences of 2048 tokens, so sample count alone is a poor guide for such methods.
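
To illustrate why granularity matters, the sketch below compares per-tensor and per-channel symmetric int8 quantization on synthetic activations whose channels differ widely in magnitude. The data is random and the helper `fake_quant_symmetric` is hypothetical, so the numbers only indicate the general effect.

```python
import torch


def fake_quant_symmetric(x, scale, num_bits=8):
    """Quantize to signed integers with the given scale, then dequantize."""
    qmax = 2 ** (num_bits - 1) - 1
    return torch.clamp(torch.round(x / scale), -qmax, qmax) * scale


torch.manual_seed(0)
# Synthetic activations: 4 channels with very different magnitudes
channel_magnitudes = torch.tensor([0.01, 0.1, 1.0, 10.0])
acts = torch.randn(1024, 4) * channel_magnitudes

qmax = 127
per_tensor_scale = acts.abs().max() / qmax          # one scale for everything
per_channel_scale = acts.abs().amax(dim=0) / qmax   # one scale per channel

# Mean squared quantization error, reported per channel
err_tensor = (fake_quant_symmetric(acts, per_tensor_scale) - acts).pow(2).mean(dim=0)
err_channel = (fake_quant_symmetric(acts, per_channel_scale) - acts).pow(2).mean(dim=0)
print("per-tensor :", err_tensor)
print("per-channel:", err_channel)
```

Per-channel scales cut the error on the small-magnitude channels, but each channel's range is now a separate estimate. That is one reason a per-channel scheme can need more calibration data before its ranges stabilize.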

Dataset preparation must tokenize calibration samples exactly as production inference does: same preprocessing, same vocabulary, same input formatting. Any discrepancy between calibration and inference conditions shows up as suboptimal quantization parameters and degraded accuracy.

import math

def collect_activation_stats(model, calibration_loader, percentile=99.9):
    """
    Collect activation statistics across calibration samples.
    Returns per-layer, per-channel min/max values at the specified percentile
    (percentile=100 gives the exact min/max). Percentiles are taken per batch
    and merged with a running min/max, which approximates the percentile over
    the whole calibration set.
    """
    activation_stats = {}
    q = percentile / 100.0
    
    model.eval()
    hooks = []
    
    def capture_activation(name, channel_dim):
        def hook_fn(module, inputs, output):
            act = output[0] if isinstance(output, tuple) else output
            # Move channels first and flatten everything else: (channels, n)
            flat = act.detach().float().movedim(channel_dim, 0)
            flat = flat.reshape(flat.shape[0], -1)
            n = flat.shape[1]
            # Nearest-rank percentile: k-th smallest value in each channel
            k_hi = min(n, max(1, math.ceil(q * n)))
            k_lo = n - k_hi + 1
            lo = torch.kthvalue(flat, k_lo, dim=1).values
            hi = torch.kthvalue(flat, k_hi, dim=1).values
            stats = activation_stats.get(name)
            if stats is None:
                activation_stats[name] = {'min': lo, 'max': hi, 'count': 1}
            else:
                stats['min'] = torch.minimum(stats['min'], lo)
                stats['max'] = torch.maximum(stats['max'], hi)
                stats['count'] += 1
        return hook_fn
    
    # Register hooks for quantizable layers. The channel axis is last for
    # Linear outputs and axis 1 for Conv2d outputs (assumes batched inputs).
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            hooks.append(module.register_forward_hook(capture_activation(name, -1)))
        elif isinstance(module, torch.nn.Conv2d):
            hooks.append(module.register_forward_hook(capture_activation(name, 1)))
    
    # Run calibration forward passes, removing hooks even if a pass fails
    try:
        with torch.no_grad():
            for batch in calibration_loader:
                model(batch)
    finally:
        for hook in hooks:
            hook.remove()
    
    return activation_stats
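
The `percentile` argument exists because a handful of outlier activations can inflate the range and waste most of the integer grid. The following sketch, on synthetic data with one injected outlier, compares a scale derived from the absolute maximum with one derived from the 99.9th percentile. The values and the helper `fake_quant` are illustrative assumptions.

```python
import torch

QMAX = 127  # largest magnitude representable in symmetric int8


def fake_quant(x, scale):
    """Symmetric int8 quantize-dequantize with clamping to the integer range."""
    return torch.clamp(torch.round(x / scale), -QMAX, QMAX) * scale


torch.manual_seed(0)
acts = torch.randn(10_000)
acts[0] = 80.0  # a single synthetic outlier

abs_max = acts.abs().max()
clip_999 = torch.quantile(acts.abs(), 0.999)

scale_max = abs_max / QMAX
scale_clip = clip_999 / QMAX

# Error on the bulk of the values (everything except the outlier)
bulk = acts[1:]
mse_max = (fake_quant(bulk, scale_max) - bulk).pow(2).mean()
mse_clip = (fake_quant(bulk, scale_clip) - bulk).pow(2).mean()
print(f"abs-max range {abs_max.item():.2f}, bulk MSE {mse_max.item():.6f}")
print(f"99.9% range   {clip_999.item():.2f}, bulk MSE {mse_clip.item():.6f}")
```

Clipping trades a large error on the rare outlier for a much finer grid on typical values. Whether that trade helps end-task accuracy depends on the model and should be checked empirically.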

Beyond random stratified sampling, advanced calibration strategies include hot-flip calibration that measures per-parameter sensitivity, Hessian-based methods for optimal rounding, and reinforcement learning approaches that adaptively select samples maximizing coverage of activation space.

EXERCISE

Implement an automated calibration dataset selector that clusters raw samples by embedding similarity and selects representatives from each cluster. Compare the resulting quantization accuracy against random sampling on a small test model.