09. 2-bit and 3-bit Quantization

Chapter 9 of 18 · 20 min

Below 4-bit precision, quantization noise grows quickly: for a fixed weight range, each bit removed roughly doubles the step size between levels. 2-bit quantization (4 discrete values) and 3-bit quantization (8 discrete values) therefore need specialized techniques, such as learned codebooks and quantization-aware training, to maintain acceptable accuracy. These extreme schemes target microcontroller-class devices where memory savings matter more than precision.
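To make the memory argument concrete, here is a minimal sketch of how 2-bit codes can be stored four to a byte. The helper names `pack_2bit` and `unpack_2bit` are hypothetical, the bit layout (code i in bits 2i and 2i+1) is one arbitrary choice among several, and the per-channel scale and zero-point overhead is ignored.

```python
import numpy as np

def pack_2bit(codes):
    """Pack an array of 2-bit codes (values 0..3) into bytes, four codes per byte."""
    codes = np.asarray(codes, dtype=np.uint8).ravel()
    pad = (-len(codes)) % 4  # pad to a multiple of four codes
    codes = np.concatenate([codes, np.zeros(pad, dtype=np.uint8)])
    groups = codes.reshape(-1, 4)
    # Code i of each group occupies bits 2*i and 2*i+1 of the byte
    packed = (groups[:, 0]
              | (groups[:, 1] << 2)
              | (groups[:, 2] << 4)
              | (groups[:, 3] << 6))
    return packed.astype(np.uint8)

def unpack_2bit(packed, count):
    """Recover the first `count` 2-bit codes from a packed byte array."""
    packed = np.asarray(packed, dtype=np.uint8)
    shifts = np.array([0, 2, 4, 6], dtype=np.uint8)
    codes = (packed[:, None] >> shifts) & 0b11
    return codes.ravel()[:count]

codes = np.array([0, 1, 2, 3, 3, 2, 1, 0, 1], dtype=np.uint8)
packed = pack_2bit(codes)
restored = unpack_2bit(packed, len(codes))

fp32_bytes = codes.size * 4  # the same values stored as float32
print(f"packed: {packed.nbytes} bytes, float32: {fp32_bytes} bytes")
```

Ignoring padding and metadata, the packed form is 16 times smaller than float32, which is the ratio the exercise at the end of the chapter asks you to measure.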

For reference, 4-bit asymmetric quantization maps weights onto 16 discrete levels spanning each channel's [min, max] range:

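import numpy as np

def asymmetric_quantize(fp_weights, num_bits=4):
    qmax = 2 ** num_bits - 1  # 15 for 4-bit

    # Compute quantization range per channel
    # (assumes a 4-D conv kernel with output channels on the last axis)
    min_val = np.min(fp_weights, axis=(0, 1, 2), keepdims=True)
    max_val = np.max(fp_weights, axis=(0, 1, 2), keepdims=True)
    # Extend the range to include 0.0 so the zero point is representable
    min_val = np.minimum(min_val, 0.0)
    max_val = np.maximum(max_val, 0.0)

    # Compute scale and zero point (guard against all-zero channels)
    scale = np.maximum((max_val - min_val) / qmax, 1e-8)
    zero_point = np.round(-min_val / scale).clip(0, qmax).astype(np.uint8)

    # Quantize to unsigned integer codes in [0, qmax]
    quantized = np.round(fp_weights / scale + zero_point).clip(0, qmax)

    return quantized.astype(np.uint8), scale, zero_point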

def dequantize(quantized, scale, zero_point):
    return scale * (quantized.astype(np.float32) - zero_point)
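The effect of dropping from 4 to 3 to 2 bits can be checked with a quantize-then-dequantize round trip. The sketch below is self-contained and uses a simplified per-tensor variant (one scale and zero point for the whole tensor) on synthetic Gaussian weights; the helper name `quantize_dequantize` and the weight distribution are assumptions for illustration, not measurements from a real model.

```python
import numpy as np

def quantize_dequantize(w, num_bits):
    """Per-tensor asymmetric quantization followed by dequantization."""
    qmax = 2 ** num_bits - 1
    w_min, w_max = float(w.min()), float(w.max())
    scale = max((w_max - w_min) / qmax, 1e-8)
    zero_point = round(-w_min / scale)
    q = np.clip(np.round(w / scale + zero_point), 0, qmax)
    return scale * (q - zero_point)

# Synthetic weights standing in for a real layer
rng = np.random.default_rng(0)
w = rng.normal(0.0, 0.05, size=(64, 64)).astype(np.float32)

# Mean squared reconstruction error at each bit-width
mse = {bits: float(np.mean((w - quantize_dequantize(w, bits)) ** 2))
       for bits in (4, 3, 2)}
for bits, err in mse.items():
    print(f"{bits}-bit MSE: {err:.2e}")
```

Reconstruction error is only a proxy; task accuracy on held-out data is the measurement that matters.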

Lookup Table (LUT) quantization improves 2-bit accuracy by training codebook values:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LUTQuantizedLinear(nn.Module):
    def __init__(self, original_weight, num_bits=2):
        super().__init__()
        self.num_bits = num_bits
        self.num_codes = 2 ** num_bits
        w = original_weight.detach().float()

        # Initialize codebook with values spread evenly over the weight range
        self.codebook = nn.Parameter(
            torch.linspace(w.min().item(), w.max().item(), self.num_codes)
        )

        # Latent full-precision weight for fine-tuning
        self.latent = nn.Parameter(w.clone())

    def forward(self, x):
        # Assign every latent weight to its nearest codebook entry
        idx = (self.latent.unsqueeze(-1) - self.codebook).abs().argmin(dim=-1)
        w_q = self.codebook[idx]  # the lookup lets gradients reach the codebook

        # Straight-through estimator: the forward value is w_q,
        # while the gradient also flows to the latent weights
        weights = w_q + self.latent - self.latent.detach()
        return F.linear(x, weights)
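Codebook values are often initialized from the weight distribution before any fine-tuning. A minimal sketch, assuming a scalar codebook fitted with a few iterations of Lloyd's algorithm (1-D k-means); the function name `kmeans_codebook` and the iteration count are hypothetical choices, and the weights are synthetic.

```python
import numpy as np

def kmeans_codebook(weights, num_bits=2, iters=20):
    """Fit 2**num_bits scalar codebook values to a weight tensor (Lloyd's algorithm)."""
    w = np.asarray(weights, dtype=np.float32).ravel()
    num_codes = 2 ** num_bits
    # Start from evenly spaced values over the weight range
    codebook = np.linspace(w.min(), w.max(), num_codes).astype(np.float32)
    for _ in range(iters):
        # Assignment step: nearest codebook entry for every weight
        idx = np.abs(w[:, None] - codebook[None, :]).argmin(axis=1)
        # Update step: each entry moves to the mean of its assigned weights
        for c in range(num_codes):
            members = w[idx == c]
            if members.size:  # leave empty clusters unchanged
                codebook[c] = members.mean()
    # Final assignment against the fitted codebook
    idx = np.abs(w[:, None] - codebook[None, :]).argmin(axis=1)
    return codebook, idx.reshape(np.shape(weights)).astype(np.uint8)

rng = np.random.default_rng(0)
w = rng.normal(0.0, 0.05, size=(32, 16)).astype(np.float32)

codebook, idx = kmeans_codebook(w, num_bits=2)
w_hat = codebook[idx]  # a table lookup reconstructs the weights
print("codebook:", codebook)
print("MSE:", float(np.mean((w - w_hat) ** 2)))
```

Only `idx` (2 bits per weight) and the four codebook values need to be stored; the lookup `codebook[idx]` is what gives LUT quantization its name.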

BinaryConnect training uses binarized weights in the forward and backward passes, while the optimizer updates a full-precision copy of the weights:

import torch

def binarize_weights(fp_weights):
    """Deterministic binarization to {-1, +1} (zeros map to +1)"""
    return torch.where(fp_weights >= 0,
                       torch.ones_like(fp_weights),
                       -torch.ones_like(fp_weights))

def binary_connect_training(model, inputs, targets, loss_fn, optimizer):
    """One step: forward/backward use binary weights, the update uses full precision"""
    # Swap in binarized weights and keep the full-precision copies
    full_precision = {}
    for name, param in model.named_parameters():
        if 'weight' in name:
            full_precision[name] = param.data
            param.data = binarize_weights(param.data)

    # Forward and backward pass with binary weights
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()

    # Recover full precision, then apply the gradients to it
    for name, param in model.named_parameters():
        if name in full_precision:
            param.data = full_precision[name]
    optimizer.step()
    return loss.item()
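The straight-through estimator mentioned above can also be written as a custom autograd function, so that binarization happens inside the graph instead of by swapping `param.data`. The sketch below is one common variant, not necessarily the one this chapter has in mind: the class name `SignSTE` is hypothetical, and zeroing the gradient where |x| > 1 is an assumption borrowed from later binary-network work.

```python
import torch

class SignSTE(torch.autograd.Function):
    """sign() in the forward pass, clipped identity gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        # Map zeros to +1 so the output is strictly in {-1, +1}
        return torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Pass the gradient through unchanged where |x| <= 1, zero it elsewhere
        return grad_output * (x.abs() <= 1).to(grad_output.dtype)

w = torch.tensor([-1.5, -0.2, 0.0, 0.7, 2.0], requires_grad=True)
w_bin = SignSTE.apply(w)   # binary weights used in the forward pass
w_bin.sum().backward()     # gradient arrives at the full-precision tensor

print(w_bin.detach())
print(w.grad)
```

With this formulation the optimizer always sees the full-precision parameter, so no save-and-restore step is needed around the backward pass.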

Simulated (fake) quantization in TensorFlow lets you estimate INT4 accuracy without INT4 kernels; the tensors stay in float32, so it does not measure INT4 speed or memory:

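import tensorflow as tf

# Use 4-bit simulated quantization during training
class QActivation(tf.keras.layers.Layer):
    def __init__(self, num_bits=4, **kwargs):
        super().__init__(**kwargs)
        self.num_bits = num_bits

    def call(self, inputs):
        # Quantize to low precision (simulation only; values stay float32)
        qmax = 2.0 ** (self.num_bits - 1) - 1  # 7 for signed 4-bit
        scale = tf.maximum(tf.reduce_max(tf.abs(inputs)) / qmax, 1e-8)
        quantized = tf.clip_by_value(tf.round(inputs / scale), -qmax - 1, qmax) * scale
        # Straight-through estimator: tf.round has zero gradient, so pass
        # the gradient of the output directly to the inputs
        return inputs + tf.stop_gradient(quantized - inputs)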

EXERCISE

Implement 2-bit asymmetric quantization for a linear layer, measure the accuracy degradation on a test dataset, and compare the memory footprint against a float32 baseline.