08. Extreme Quantization
Quantization reduces numerical precision, most commonly from 32-bit floating point to 8-bit integers. This cuts weight storage by 4x (32 bits / 8 bits), typically increases throughput by 2-4x on hardware with fast integer arithmetic, and lowers power consumption because more values fit into each SIMD register and each memory transfer. Extreme quantization pushes further, toward 4-bit, 2-bit, or even binary (1-bit) representations.
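To make the bit-width trade-off concrete, the sketch below applies uniform affine (min/max) quantization to a random weight matrix at 8, 4, and 2 bits. It assumes NumPy and a single per-tensor range; `quantize_affine` and `dequantize_affine` are illustrative helpers, not a framework API, and the integer codes are kept in `int32` arrays for clarity rather than bit-packed as a real low-bit kernel would store them.

```python
import numpy as np


def quantize_affine(x: np.ndarray, num_bits: int):
    """Uniform affine quantization of x to unsigned codes in [0, 2**num_bits - 1]."""
    qmin, qmax = 0, 2 ** num_bits - 1
    # Widen the range to include 0 so that zero is exactly representable
    x_min = min(float(x.min()), 0.0)
    x_max = max(float(x.max()), 0.0)
    scale = (x_max - x_min) / (qmax - qmin) if x_max > x_min else 1.0
    zero_point = int(round(qmin - x_min / scale))
    # Held in int32 for clarity; a real low-bit kernel would bit-pack these codes
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax).astype(np.int32)
    return q, scale, zero_point


def dequantize_affine(q: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    """Map integer codes back to approximate float values."""
    return (q.astype(np.float32) - zero_point) * scale


rng = np.random.default_rng(0)
weights = rng.normal(0.0, 0.05, size=(256, 256)).astype(np.float32)
for bits in (8, 4, 2):
    q, scale, zp = quantize_affine(weights, bits)
    err = np.abs(weights - dequantize_affine(q, scale, zp)).mean()
    print(f"{bits}-bit: {32 // bits}x smaller than float32, "
          f"{2 ** bits} levels, mean abs error {err:.5f}")
```

Going from 8 to 4 bits halves storage but cuts the number of levels from 256 to 16, so the rounding step over the same range grows by roughly 17x (255/15). Reconstruction error therefore rises much faster than memory shrinks at extreme bit widths.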
Post-training static INT8 quantization in PyTorch requires calibration with representative data. With FX graph mode quantization, the workflow is:
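```python
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

# Static post-training quantization requires eval mode
model.eval()

# Default INT8 configuration for x86 CPUs ("fbgemm" backend)
qconfig_mapping = get_default_qconfig_mapping("fbgemm")
# example_inputs: a tuple of sample inputs used to trace the model;
# prepare_fx inserts observers that record activation ranges
prepared_model = prepare_fx(model, qconfig_mapping, example_inputs)

# Calibration: run representative samples through the *prepared* model
with torch.no_grad():
    for batch in calibration_loader:
        prepared_model(batch)

# Replace observers with quantize/dequantize ops and INT8 kernels
quantized_model = convert_fx(prepared_model)
```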
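The calibration loop only pushes data through the prepared model; the observers it inserted do the actual work by tracking activation ranges and turning them into a scale and zero point. The sketch below is a simplified min/max observer in NumPy, assuming signed INT8 activations. `MinMaxObserver` here is a hypothetical stand-in, and PyTorch's default `fbgemm` configuration uses a more elaborate histogram-based observer for activations.

```python
import numpy as np


class MinMaxObserver:
    """Simplified calibration observer that tracks the running min/max of activations."""

    def __init__(self, num_bits: int = 8):
        # Signed range, e.g. [-128, 127] for INT8
        self.qmin = -(2 ** (num_bits - 1))
        self.qmax = 2 ** (num_bits - 1) - 1
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, activations: np.ndarray) -> None:
        self.min_val = min(self.min_val, float(activations.min()))
        self.max_val = max(self.max_val, float(activations.max()))

    def qparams(self):
        """Return (scale, zero_point) covering the observed range plus zero."""
        lo, hi = min(self.min_val, 0.0), max(self.max_val, 0.0)
        scale = max(hi - lo, 1e-8) / (self.qmax - self.qmin)
        zero_point = round(self.qmin - lo / scale)
        return scale, int(min(max(zero_point, self.qmin), self.qmax))


rng = np.random.default_rng(0)
observer = MinMaxObserver()
for _ in range(10):  # stand-in for iterating over calibration_loader
    batch_activations = np.maximum(rng.normal(size=(32, 64)), 0.0)  # ReLU-like output
    observer.observe(batch_activations)
scale, zero_point = observer.qparams()
# zero_point is -128 here because ReLU outputs are never negative
print(f"scale={scale:.5f}, zero_point={zero_point}")
```

Because the range comes straight from the calibration data, unrepresentative calibration samples produce scales that either clip real activations or waste resolution.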
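Quantization-aware training (QAT) produces better accuracy for aggressive quantization. In PyTorch's eager mode, `QuantStub` and `DeQuantStub` mark where tensors enter and leave the quantized part of the model: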
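```python
import torch.nn as nn
from torch.ao.quantization import (QuantStub, DeQuantStub, convert,
                                   get_default_qat_qconfig, prepare_qat)

class QuantizedNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # float -> quantized at the model input
        self.conv = nn.Conv2d(3, 64, 3, padding=1)
        self.relu = nn.ReLU()
        self.dequant = DeQuantStub()  # quantized -> float at the model output

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x

# QAT: insert fake-quantization modules, fine-tune, then convert to INT8
model = QuantizedNet().train()
model.qconfig = get_default_qat_qconfig("fbgemm")
qat_model = prepare_qat(model)
# ... fine-tune qat_model with the usual training loop ...
quantized_model = convert(qat_model.eval())
```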
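QAT works because the prepared model applies fake quantization: the forward pass rounds and clips values onto the INT8 grid but keeps them in float, and the backward pass uses a straight-through estimator so gradients can pass through the non-differentiable rounding. The NumPy sketch below isolates that mechanism for a single per-tensor scale. `fake_quantize` and `fake_quantize_grad` are hypothetical helper names, not PyTorch functions; PyTorch's counterpart is the `torch.ao.quantization.FakeQuantize` module.

```python
import numpy as np


def fake_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """Forward pass: quantize then immediately dequantize, so x stays float
    but can only take values on the INT8 grid."""
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale


def fake_quantize_grad(x, upstream_grad, scale, zero_point, qmin=-128, qmax=127):
    """Backward pass (straight-through estimator): treat round() as the identity,
    but block the gradient for inputs that were clipped to qmin/qmax."""
    q = np.round(x / scale) + zero_point
    in_range = (q >= qmin) & (q <= qmax)
    return upstream_grad * in_range


x = np.array([-2.0, -0.013, 0.0, 0.5, 2.0])
print(fake_quantize(x, scale=0.01, zero_point=0))
print(fake_quantize_grad(x, np.ones_like(x), scale=0.01, zero_point=0))
```

The network thus sees the rounding error during training and can adjust its weights to compensate, which is what post-training quantization cannot do.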
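Per-channel quantization improves accuracy by giving each output channel its own scale factor. In TensorFlow Lite, full-integer post-training quantization (enabled by supplying a representative dataset) quantizes convolution weights per channel: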
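```python
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# Calibration data: `dataset` is assumed to be an iterable of input images.
# Each yielded list must match the model's input signature, including batch dimension.
def representative_dataset_generator():
    for image in dataset:
        yield [tf.cast(image, tf.float32)]

converter.representative_dataset = representative_dataset_generator
# MLIR-based converter (already the default in TF 2.x)
converter.experimental_new_converter = True
# Conv weights are quantized per channel with symmetric INT8 ranges (zero point 0)
tflite_model = converter.convert()
```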
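The benefit of per-channel scales shows up when output channels have very different weight ranges: with a single per-tensor scale, the largest channel sets the step size and small channels lose most of their resolution. The NumPy sketch below compares both schemes on synthetic convolution weights laid out as `(out_channels, in_channels, kH, kW)`; the layout, the synthetic data, and the helper names are assumptions for illustration. It uses the symmetric `[-127, 127]` INT8 range with a zero point of 0.

```python
import numpy as np


def quantize_symmetric(w: np.ndarray, axis=None, num_bits: int = 8):
    """Symmetric quantization. axis=None: one scale for the whole tensor;
    axis=0: one scale per output channel."""
    qmax = 2 ** (num_bits - 1) - 1  # 127 for INT8; zero point is fixed at 0
    if axis is None:
        max_abs = np.abs(w).max()
    else:
        reduce_axes = tuple(i for i in range(w.ndim) if i != axis)
        max_abs = np.abs(w).max(axis=reduce_axes, keepdims=True)
    scale = np.maximum(max_abs, 1e-12) / qmax
    q = np.clip(np.round(w / scale), -qmax, qmax).astype(np.int8)
    return q, scale


def dequantize_symmetric(q: np.ndarray, scale) -> np.ndarray:
    return q.astype(np.float32) * scale


rng = np.random.default_rng(0)
# Synthetic conv weights whose per-channel magnitudes span two orders of magnitude
channel_magnitude = np.logspace(-2, 0, num=16).reshape(16, 1, 1, 1)
weights = (rng.normal(size=(16, 8, 3, 3)) * channel_magnitude).astype(np.float32)

for axis, label in ((None, "per-tensor"), (0, "per-channel")):
    q, scale = quantize_symmetric(weights, axis=axis)
    err = np.abs(weights - dequantize_symmetric(q, scale)).mean()
    print(f"{label}: mean abs error {err:.6f}")
```

The cost of per-channel quantization is storing one scale per output channel (16 floats here) instead of one per tensor, which is negligible next to the weights themselves.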
Accuracy degradation manifests differently across tasks. Image classification usually tolerates INT8 quantization with no visible quality loss; output probabilities typically shift by less than 1%. Object detection suffers more, because anchor assignment can amplify small quantization errors into changed detections. Generative tasks (style transfer, image enhancement) can exhibit banding artifacts at INT8, which INT16 intermediate calculations can mitigate.
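The probability-shift figure for classification is straightforward to check on a specific model by running the same inputs through the float32 and quantized versions and comparing their outputs. A minimal NumPy sketch, assuming both models return raw logits of shape `(num_samples, num_classes)`; `compare_outputs` is a hypothetical helper, and detection or generative models need task-specific metrics instead.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    z = logits - logits.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def compare_outputs(float_logits: np.ndarray, quant_logits: np.ndarray) -> dict:
    """Summarize how far quantized predictions drift from the float32 baseline."""
    p_float, p_quant = softmax(float_logits), softmax(quant_logits)
    shift = np.abs(p_float - p_quant)
    return {
        # Fraction of samples where both models predict the same class
        "top1_agreement": float(np.mean(p_float.argmax(-1) == p_quant.argmax(-1))),
        # Largest and average change in any class probability
        "max_prob_shift": float(shift.max()),
        "mean_prob_shift": float(shift.mean()),
    }
```

Top-1 agreement is often more informative than accuracy alone: two models can reach the same accuracy while disagreeing on many individual samples.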
Benchmarking quantized models requires comparing both accuracy and latency. Latency for a TensorFlow Lite model can be measured with a warm-up phase followed by timed runs:
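```python
import time
import numpy as np

def benchmark_inference(interpreter, num_runs=100, warmup=10):
    # Assumes interpreter.allocate_tensors() was called and inputs were set via set_tensor()
    # Warm-up runs keep one-time costs (cache misses, lazy initialization) out of the timings
    for _ in range(warmup):
        interpreter.invoke()
    # Timed runs, recorded in milliseconds
    timings = []
    for _ in range(num_runs):
        start = time.perf_counter()
        interpreter.invoke()
        timings.append((time.perf_counter() - start) * 1000)
    return {
        "mean_ms": np.mean(timings),
        "std_ms": np.std(timings),
        "p95_ms": np.percentile(timings, 95),
    }
```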
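`benchmark_inference` is tied to a TensorFlow Lite interpreter. To compare a float32 model against its quantized counterpart in any framework, the same warm-up-then-time loop can wrap arbitrary zero-argument callables and report the speedup. This is a sketch; `benchmark_callable` and `compare_latency` are hypothetical helper names.

```python
import time

import numpy as np


def benchmark_callable(fn, num_runs=100, warmup=10):
    """Latency statistics, in milliseconds, for a zero-argument callable."""
    for _ in range(warmup):
        fn()
    timings = np.empty(num_runs)
    for i in range(num_runs):
        start = time.perf_counter()
        fn()
        timings[i] = (time.perf_counter() - start) * 1000.0
    return {
        "mean_ms": float(timings.mean()),
        "std_ms": float(timings.std()),
        "p95_ms": float(np.percentile(timings, 95)),
    }


def compare_latency(baseline_fn, candidate_fn, **kwargs):
    """Benchmark both callables under identical settings; speedup > 1 means candidate is faster."""
    base = benchmark_callable(baseline_fn, **kwargs)
    cand = benchmark_callable(candidate_fn, **kwargs)
    return {"baseline": base, "candidate": cand,
            "speedup": base["mean_ms"] / cand["mean_ms"]}
```

For PyTorch models, the callables might be `lambda: float_model(x)` and `lambda: quantized_model(x)`, with the `compare_latency` call wrapped in `torch.no_grad()`; for TensorFlow Lite, `interpreter.invoke` can be passed directly. Both models should be timed with the same input shape and thread count.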
Quantize a PyTorch model using dynamic quantization, compare inference speed and memory usage against the float32 baseline, and verify output quality degradation.
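In PyTorch, a starting point for this exercise is `torch.ao.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)`, which stores `Linear` weights as INT8 and quantizes activations on the fly at inference time. The NumPy sketch below shows what that means for a single linear layer, so the memory and error comparison can be reasoned about without a framework. It is a simplification under stated assumptions: symmetric per-row weight scales and a symmetric per-tensor activation scale computed from each input. `DynamicQuantLinear` is a hypothetical class, and PyTorch's kernels differ in detail.

```python
import numpy as np


class DynamicQuantLinear:
    """Sketch of a dynamically quantized linear layer computing y = x @ W.T + b."""

    def __init__(self, weight: np.ndarray, bias: np.ndarray):
        # Weights are quantized once, ahead of time: symmetric INT8, one scale per output row
        max_abs = np.abs(weight).max(axis=1, keepdims=True)
        self.w_scale = np.maximum(max_abs, 1e-12) / 127.0
        self.w_q = np.clip(np.round(weight / self.w_scale), -127, 127).astype(np.int8)
        self.bias = bias.astype(np.float32)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # "Dynamic": the activation scale is derived from the current input at run time
        x_scale = max(float(np.abs(x).max()), 1e-12) / 127.0
        x_q = np.clip(np.round(x / x_scale), -127, 127).astype(np.int8)
        # Integer matmul with int32 accumulation, then rescale the result back to float
        acc = x_q.astype(np.int32) @ self.w_q.astype(np.int32).T
        return acc * (x_scale * self.w_scale.T) + self.bias


rng = np.random.default_rng(0)
W = rng.normal(scale=0.1, size=(256, 512)).astype(np.float32)
b = np.zeros(256, dtype=np.float32)
x = rng.normal(size=(8, 512)).astype(np.float32)

layer = DynamicQuantLinear(W, b)
reference = x @ W.T + b
output = layer(x)

rel_err = np.abs(output - reference).max() / np.abs(reference).max()
print(f"weight memory: {W.nbytes} B float32 -> {layer.w_q.nbytes} B int8")
# weight memory: 524288 B float32 -> 131072 B int8
print(f"max relative output error: {rel_err:.4f}")
```

The per-row scales add a small float32 overhead (256 values here), so the total weight footprint shrinks by slightly less than 4x.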