09. 2-bit and 3-bit Quantization
Below 4-bit precision, quantization noise becomes increasingly problematic. 2-bit quantization (4 discrete values) and 3-bit quantization (8 discrete values) require specialized techniques to maintain acceptable accuracy. These extreme quantization schemes target microcontroller-class devices where memory reduction outweighs precision requirements.
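To make the effect of shrinking the bit width concrete, the following minimal sketch applies uniform asymmetric quantization to a synthetic weight tensor at 4, 3, and 2 bits and reports the reconstruction error. The Gaussian weights, the per-tensor range, and the helper names `quantize_dequantize` and `rms_error` are assumptions made for illustration; real layers will show different numbers.

```python
import numpy as np

def quantize_dequantize(weights, num_bits):
    """Uniform asymmetric quantization over the whole tensor, then reconstruction."""
    qmax = 2 ** num_bits - 1
    w_min, w_max = weights.min(), weights.max()
    scale = (w_max - w_min) / qmax
    # Integer codes in [0, qmax]
    codes = np.clip(np.round((weights - w_min) / scale), 0, qmax)
    return codes * scale + w_min, codes

rng = np.random.default_rng(0)
# Synthetic stand-in for trained weights (assumption, not real model data)
weights = rng.normal(0.0, 0.05, size=10_000).astype(np.float32)

def rms_error(num_bits):
    restored, _ = quantize_dequantize(weights, num_bits)
    return float(np.sqrt(np.mean((weights - restored) ** 2)))

for bits in (4, 3, 2):
    _, codes = quantize_dequantize(weights, bits)
    print(f"{bits}-bit: {len(np.unique(codes))} levels used, RMS error {rms_error(bits):.5f}")
```

Each bit removed halves the number of available levels, so the step between levels roughly doubles and the rounding error grows with it.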
As a reference point, 4-bit quantization approximates weights with 16 discrete values, typically using asymmetric per-channel quantization ranges:
import numpy as np

def asymmetric_quantize(fp_weights, num_bits=4):
    # Unsigned code range, e.g. [0, 15] for 4 bits
    qmin, qmax = 0, 2 ** num_bits - 1
    # Compute quantization range per channel
    # (assumes 4-D weights with the output channel on the last axis)
    min_val = np.min(fp_weights, axis=(0, 1, 2), keepdims=True)
    max_val = np.max(fp_weights, axis=(0, 1, 2), keepdims=True)
    # Compute scale and zero point; guard against constant channels
    scale = np.maximum(max_val - min_val, 1e-8) / qmax
    zero_point = np.round(-min_val / scale).clip(qmin, qmax).astype(np.int32)
    # Quantize
    quantized = np.round(fp_weights / scale + zero_point).clip(qmin, qmax)
    return quantized.astype(np.uint8), scale, zero_point

def dequantize(quantized, scale, zero_point):
    return scale * (quantized.astype(np.float32) - zero_point)
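The quantizer above stores each 4-bit code in its own 8-bit integer, so the memory saving only appears once codes are packed. The sketch below packs two 4-bit codes into each byte and unpacks them again. The helper names `pack_int4` and `unpack_int4` are hypothetical, and the nibble order (first code in the high four bits) is an arbitrary choice for this example.

```python
import numpy as np

def pack_int4(codes):
    """Pack 4-bit codes (values 0..15) two per byte; pads with one zero if the count is odd."""
    flat = np.asarray(codes, dtype=np.uint8).ravel()
    if flat.size % 2:
        flat = np.append(flat, np.uint8(0))
    # First code of each pair goes in the high nibble, second in the low nibble
    return (flat[0::2] << 4) | flat[1::2]

def unpack_int4(packed, count):
    """Recover `count` 4-bit codes from the packed bytes."""
    high = packed >> 4
    low = packed & 0x0F
    # Interleave high and low nibbles back into the original order
    return np.stack([high, low], axis=1).ravel()[:count]

codes = np.array([3, 15, 0, 7, 9], dtype=np.uint8)
packed = pack_int4(codes)
restored = unpack_int4(packed, codes.size)
print(packed.nbytes, "bytes hold", codes.size, "codes")
```

The same idea extends to 2-bit codes, where four codes share one byte and the shifts become 6, 4, 2, and 0.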
Lookup Table (LUT) quantization improves 2-bit accuracy by learning the codebook values instead of fixing them on a uniform grid:
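import torch
import torch.nn as nn
import torch.nn.functional as F

class LUTQuantizedLinear(nn.Module):
    def __init__(self, original_weight, num_bits=2, temperature=0.01):
        super().__init__()
        self.num_bits = num_bits
        self.num_codes = 2 ** num_bits
        w = original_weight.detach().clone()
        # Initialize codebook with evenly spaced values over the weight range
        self.codebook = nn.Parameter(
            torch.linspace(w.min().item(), w.max().item(), self.num_codes)
        )
        # Soft assignment gates (learned): one logit per weight per code,
        # initialized so that the nearest code receives the highest logit.
        # temperature sets how sharp the initial assignment is (illustrative value)
        dist = (w.unsqueeze(-1) - self.codebook.detach()).abs()
        self.gates = nn.Parameter(-dist / temperature)

    def forward(self, x):
        # Soft selection of codebook entries; for deployment, the argmax over
        # the gates gives the num_bits index stored for each weight
        probs = torch.softmax(self.gates, dim=-1)
        weights = torch.einsum('oic,c->oi', probs, self.codebook)
        return F.linear(x, weights)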
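The class above learns its codebook by gradient descent. A common gradient-free way to obtain codebook values, and the hard 2-bit indices that go with them, is 1-D k-means over the weights. The sketch below is an illustration under that assumption rather than the method used above; `fit_codebook` is a hypothetical helper and the Gaussian weights are synthetic.

```python
import numpy as np

def fit_codebook(weights, num_bits=2, iterations=20):
    """Fit 2**num_bits scalar codebook values with 1-D k-means (Lloyd's algorithm)."""
    flat = weights.ravel()
    # Start from an evenly spaced grid over the weight range
    codebook = np.linspace(flat.min(), flat.max(), 2 ** num_bits)
    for _ in range(iterations):
        # Assign every weight to its nearest codebook value
        indices = np.abs(flat[:, None] - codebook[None, :]).argmin(axis=1)
        for k in range(codebook.size):
            members = flat[indices == k]
            if members.size:  # keep the old value if no weight chose this entry
                codebook[k] = members.mean()
    # Final assignment against the fitted codebook
    indices = np.abs(flat[:, None] - codebook[None, :]).argmin(axis=1)
    return codebook, indices.reshape(weights.shape).astype(np.uint8)

rng = np.random.default_rng(0)
weights = rng.normal(0.0, 0.05, size=(64, 32))  # synthetic stand-in for a trained layer
codebook, indices = fit_codebook(weights)
restored = codebook[indices]  # table lookup reconstructs the weights

# Compare against the uniform grid used as the starting point
grid = np.linspace(weights.min(), weights.max(), 4)
grid_restored = grid[np.abs(weights[..., None] - grid).argmin(axis=-1)]
print("uniform grid MSE:   ", np.mean((weights - grid_restored) ** 2))
print("fitted codebook MSE:", np.mean((weights - restored) ** 2))
```

Only the indices (2 bits per weight once packed) and the four codebook values need to be stored; dequantization is a table lookup.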
BinaryConnect training binarizes the weights for the forward and backward passes, while the parameter update is applied to a full-precision copy:
import torch

def binarize_weights(fp_weights):
    """Deterministic binarization to {-1, +1}"""
    # sign() would return 0 for zero-valued weights, so map those to +1
    ones = torch.ones_like(fp_weights)
    return torch.where(fp_weights >= 0, ones, -ones)

def binary_connect_training(model, inputs, targets, loss_fn, lr=0.001):
    """One step: forward and backward use binary weights, the update uses full precision"""
    full_precision = {}
    for name, param in model.named_parameters():
        if 'weight' in name:
            # Store full precision copy for the update
            full_precision[name] = param.data
            param.data = binarize_weights(param.data)
    # Forward and backward passes run with the binarized weights
    model.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in full_precision:
                param.data = full_precision[name]  # Recover full precision
            if param.grad is not None:
                # Plain SGD step on the full-precision values
                param -= lr * param.grad
    return loss.item()
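Simulated (fake) quantization in TensorFlow lets you estimate the accuracy impact of INT4 during training. Values stay in floating point, so it models rounding error only and says nothing about INT4 latency or memory use: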
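import tensorflow as tf

# Use 4-bit simulated quantization during training
class QActivation(tf.keras.layers.Layer):
    def __init__(self, num_bits=4, **kwargs):
        super().__init__(**kwargs)
        self.num_bits = num_bits

    def call(self, inputs):
        # Quantize to low precision (simulation only; values remain float)
        qmax = 2.0 ** (self.num_bits - 1) - 1  # 7 for signed 4-bit
        scale = tf.maximum(tf.reduce_max(tf.abs(inputs)) / qmax, 1e-8)
        quantized = tf.clip_by_value(tf.round(inputs / scale), -qmax, qmax) * scale
        # Straight-through estimator: tf.round has zero gradient,
        # so gradients are passed through to the inputs unchanged
        return inputs + tf.stop_gradient(quantized - inputs)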
Implement 2-bit asymmetric quantization for a linear layer, measure the accuracy degradation on a test dataset, and compare the memory footprint against a float32 baseline.
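As a starting point for the memory comparison, the arithmetic can be sketched as follows. The helpers `linear_layer_bytes` and `float32_bytes` are hypothetical, and the per-row overhead assumes one float32 scale and one uint8 zero point per output row, which is only one possible storage layout.

```python
def linear_layer_bytes(out_features, in_features, num_bits, per_row_overhead_bytes=5):
    """Bytes to store a packed num_bits weight matrix plus per-row quantization parameters.

    per_row_overhead_bytes assumes one float32 scale (4 bytes) and one uint8
    zero point (1 byte) per output row; adjust for other layouts.
    """
    weight_bits = out_features * in_features * num_bits
    packed_bytes = (weight_bits + 7) // 8  # round up to whole bytes
    return packed_bytes + out_features * per_row_overhead_bytes

def float32_bytes(out_features, in_features):
    """Bytes to store the same weight matrix in float32."""
    return out_features * in_features * 4

baseline = float32_bytes(1024, 1024)
two_bit = linear_layer_bytes(1024, 1024, num_bits=2)
print(f"float32: {baseline} bytes, 2-bit: {two_bit} bytes, ratio {baseline / two_bit:.1f}x")
```

The ratio falls slightly short of the ideal 16x because the scales and zero points are stored at higher precision than the weights.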