07. 4-bit NormalFloat
Standard integer quantization maps floating-point values to discrete integer representations. For 4-bit quantization, this provides only 16 discrete values to represent the entire range of model weights. The choice of how to distribute these 16 values critically affects quantization accuracy.
Neural network weights typically follow a roughly normal distribution centered near zero. Uniform quantization spaces its levels evenly across the full range. To cover rare extreme values, it spends many levels on the sparsely populated tails, which leaves comparatively coarse resolution near zero, where most weights cluster. The NF4 (4-bit NormalFloat) format addresses this with non-uniform quantization levels designed for normally distributed data.
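NF4 places its 16 levels at quantiles of the standard normal distribution, so that each level covers roughly the same share of normally distributed values, and then rescales them to the range [-1, 1]. As a result, the levels are packed most densely near zero, where the probability density is highest, which gives finer precision where most weights lie. The table is also asymmetric: it has 7 negative levels, 8 positive levels, and an exact zero, so that zero is represented without error. For normally distributed weights, this yields lower quantization error than uniform quantization with the same number of levels.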
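To make the quantile construction concrete, here is a minimal sketch of how such a level table can be built, assuming the recipe described for QLoRA: take 8 positive and 7 negative standard-normal quantiles, add an exact zero, and divide by the largest value. The function names are hypothetical, and the `offset` of 0.9677083 (the largest probability passed to the inverse CDF, which keeps the outermost quantile finite) is an assumption that we believe mirrors the bitsandbytes reference construction. Only the standard library's `statistics.NormalDist` is used.

```python
from statistics import NormalDist


def linspace(start: float, stop: float, num: int) -> list[float]:
    """Evenly spaced values from start to stop, inclusive (like numpy.linspace)."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def nf4_levels(offset: float = 0.9677083) -> list[float]:
    """Build 16 NF4-style levels in [-1, 1] from standard-normal quantiles.

    offset (assumed value) is used instead of 1.0 because the normal
    quantile function is infinite at probability 1.
    """
    inv_cdf = NormalDist().inv_cdf  # standard normal quantile function
    # 8 positive quantiles: probabilities from offset down toward 0.5 (0.5 itself dropped)
    positive = [inv_cdf(p) for p in linspace(offset, 0.5, 9)[:-1]]
    # 7 negative quantiles, mirrored onto the negative side
    negative = [-inv_cdf(p) for p in linspace(offset, 0.5, 8)[:-1]]
    levels = sorted(negative + [0.0] + positive)  # exact zero included
    top = max(levels)
    return [v / top for v in levels]  # rescale so the table spans [-1, 1]


levels = nf4_levels()
print([round(v, 4) for v in levels])
```

Printing the gaps between neighbouring levels shows them widening away from zero, which is the "more levels near zero" property described above.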
NF4 quantization is applied block-wise: the weights are split into small blocks (64 weights per block in QLoRA), and each block gets its own quantization constant, its absolute maximum (absmax). Dividing a block by its absmax maps it into [-1, 1], the range of the NF4 levels. Because every block is normalized independently, an outlier only coarsens the quantization of its own block instead of the entire tensor.
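The block-wise absmax scheme can be sketched in NumPy as follows. This is a minimal illustration, assuming a 1-D weight array whose length is a multiple of the block size; the level table is the 16-value NF4 table from QLoRA rounded to four decimals, and `quantize_blockwise` / `dequantize_blockwise` are hypothetical names. For clarity it stores one code per `uint8`, whereas real implementations pack two 4-bit codes into each byte.

```python
import numpy as np

# NF4 level table (QLoRA), rounded to 4 decimals: 7 negative, exact zero, 8 positive
NF4 = np.array([-1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
                0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0])


def quantize_blockwise(w: np.ndarray, block_size: int = 64):
    """Return 4-bit codes (one uint8 each here) and one absmax constant per block."""
    blocks = w.reshape(-1, block_size)                  # assumes len(w) % block_size == 0
    absmax = np.abs(blocks).max(axis=1, keepdims=True)  # one constant per block
    safe = np.where(absmax == 0, 1.0, absmax)           # avoid 0/0 for all-zero blocks
    normalized = blocks / safe                          # every block now lies in [-1, 1]
    # Index (0..15) of the nearest NF4 level for every weight
    codes = np.abs(normalized[..., None] - NF4).argmin(axis=-1).astype(np.uint8)
    return codes, absmax


def dequantize_blockwise(codes: np.ndarray, absmax: np.ndarray) -> np.ndarray:
    """Look up each code's level and undo the per-block scaling."""
    return (NF4[codes] * absmax).reshape(-1)


rng = np.random.default_rng(0)
w = rng.standard_normal(64 * 16).astype(np.float32)
w[5] = 25.0  # inject an outlier into the first block

codes, absmax = quantize_blockwise(w)
w_hat = dequantize_blockwise(codes, absmax)
err = (w - w_hat).reshape(-1, 64) ** 2
print("MSE in outlier block:", err[0].mean())
print("MSE in other blocks: ", err[1:].mean())
```

The outlier inflates the error only inside its own block: the other 15 blocks keep their own, smaller absmax and therefore their finer effective resolution.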
Double quantization saves more memory by quantizing the quantization constants themselves. In QLoRA, the 32-bit absmax constants (one per 64-weight block) are quantized again to 8-bit floats in blocks of 256, with one 32-bit constant stored per block of constants. The additional error this introduces is small compared to the memory saved.
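The memory arithmetic and the idea of quantizing the constants can be sketched as follows. This is an illustration, not the QLoRA implementation: QLoRA quantizes the constants to 8-bit floats, whereas this sketch swaps in simpler symmetric int8 quantization of the mean-centered constants (absmax values are all positive, so centering them lets a symmetric code use its full range). `constant_overhead_bits` and `double_quantize` are hypothetical names.

```python
from typing import Optional

import numpy as np


def constant_overhead_bits(block_size: int = 64, dq_block_size: Optional[int] = None) -> float:
    """Extra bits per weight spent on storing quantization constants."""
    if dq_block_size is None:
        return 32 / block_size  # one fp32 absmax per weight block
    # one 8-bit constant per weight block + one fp32 constant per block of constants
    return 8 / block_size + 32 / (block_size * dq_block_size)


def double_quantize(absmax: np.ndarray, dq_block_size: int = 256):
    """Stand-in for FP8 double quantization: symmetric int8 on mean-centered constants."""
    mean = absmax.mean()
    centered = (absmax - mean).reshape(-1, dq_block_size)  # assumes len % dq_block_size == 0
    scale = np.abs(centered).max(axis=1, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)
    q = np.round(centered / scale).astype(np.int8)  # 8-bit codes for the constants
    restored = (q * scale).reshape(-1) + mean       # dequantized constants
    return q, scale, mean, restored


print(constant_overhead_bits())                   # 0.5
print(constant_overhead_bits(dq_block_size=256))  # 0.126953125

rng = np.random.default_rng(0)
weights = rng.standard_normal(64 * 1024).astype(np.float32)
absmax = np.abs(weights.reshape(-1, 64)).max(axis=1)  # 1024 first-level constants
q, scale, mean, restored = double_quantize(absmax)
print("max relative error of constants:", np.max(np.abs(restored - absmax) / absmax))
```

The saving is 0.5 - 0.127 ≈ 0.373 bits per weight; for a 65-billion-parameter model that is about 0.373 × 65 billion / 8 ≈ 3.0 billion bytes, roughly 3 GB.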
The bitsandbytes library implements NF4 quantization together with CUDA kernels for dequantization. The weights stay in 4-bit storage, and the kernels dequantize them on the fly to a higher-precision compute dtype (bfloat16 in QLoRA) during the forward and backward passes, so the matrix multiplications themselves run in higher precision. This adds some computation per step in exchange for a large reduction in weight memory.
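The split between 4-bit storage and higher-precision compute can be illustrated with a small NumPy simulation. This is not the bitsandbytes kernel, only a sketch under stated assumptions: it packs two 4-bit codes into each byte, keeps one absmax per 64-weight block, and dequantizes to float32 (standing in for bfloat16, which NumPy lacks) every time the layer is called. `NF4LinearSketch` is a hypothetical name, and the level table is the QLoRA NF4 table rounded to four decimals.

```python
import numpy as np

# NF4 level table (QLoRA), rounded to 4 decimals
NF4 = np.array([-1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
                0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0])


class NF4LinearSketch:
    """Toy linear layer: weights kept only as packed 4-bit codes + per-block absmax."""

    def __init__(self, weight: np.ndarray, block_size: int = 64):
        self.shape = weight.shape
        blocks = weight.reshape(-1, block_size)  # assumes weight.size % block_size == 0
        self.absmax = np.abs(blocks).max(axis=1, keepdims=True).astype(np.float32)
        normalized = blocks / np.where(self.absmax == 0, 1.0, self.absmax)
        codes = np.abs(normalized[..., None] - NF4).argmin(axis=-1).astype(np.uint8).reshape(-1)
        # Two codes per byte: even positions in the high nibble, odd positions in the low nibble
        self.packed = (codes[0::2] << 4) | codes[1::2]

    def dequantize(self) -> np.ndarray:
        """Unpack the nibbles and rebuild float32 weights (done on every call)."""
        codes = np.empty(self.packed.size * 2, dtype=np.uint8)
        codes[0::2] = self.packed >> 4
        codes[1::2] = self.packed & 0x0F
        w = NF4[codes].reshape(self.absmax.shape[0], -1) * self.absmax
        return w.astype(np.float32).reshape(self.shape)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # Dequantize on the fly, multiply in float32; the full weights are not kept
        return x @ self.dequantize().T


rng = np.random.default_rng(0)
W = rng.standard_normal((128, 256)).astype(np.float32)  # out_features x in_features
x = rng.standard_normal((4, 256)).astype(np.float32)
layer = NF4LinearSketch(W)
y = layer(x)
print("fp32 bytes:", W.nbytes, "| stored bytes:", layer.packed.nbytes + layer.absmax.nbytes)
```

In practice these kernels are usually reached through Hugging Face Transformers, by passing `BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)` as the `quantization_config` when loading a model.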
Compare quantization error between uniform 4-bit and NF4 for a sample of normally distributed values. Calculate mean squared error for each approach.
# nf4_quantization_demo.py
import numpy as np
import torch
from scipy import stats


def uniform_quantize(tensor: torch.Tensor, bits: int = 4, block_size: int = 64) -> torch.Tensor:
    """Uniform (min-max) quantization with a per-block offset and scale.

    Returns the dequantized values so the error can be measured directly.
    """
    num_levels = 2 ** bits
    dequantized_blocks = []
    for i in range(0, tensor.numel(), block_size):
        block = tensor[i:i + block_size]
        block_min = block.min()
        scale = (block.max() - block_min) / (num_levels - 1)
        if scale > 0:
            codes = torch.round((block - block_min) / scale)  # integers 0 .. num_levels - 1
            dequantized = codes * scale + block_min
        else:
            dequantized = block.clone()  # a constant block is represented exactly
        dequantized_blocks.append(dequantized)
    return torch.cat(dequantized_blocks)


def nf4_levels(offset: float = 0.9677083) -> torch.Tensor:
    """16 NF4 levels in [-1, 1]: 7 negative, an exact zero, 8 positive (QLoRA-style)."""
    positive = stats.norm.ppf(np.linspace(offset, 0.5, 9)[:-1])   # 8 quantiles above 0
    negative = -stats.norm.ppf(np.linspace(offset, 0.5, 8)[:-1])  # 7 quantiles below 0
    levels = np.sort(np.concatenate([negative, [0.0], positive]))
    return torch.tensor(levels / levels.max(), dtype=torch.float32)


def nf4_quantize(tensor: torch.Tensor, block_size: int = 64) -> torch.Tensor:
    """Simplified NF4: absmax-normalize each block to [-1, 1], then snap to the nearest level.

    Returns the dequantized values.
    """
    levels = nf4_levels().to(tensor.dtype)
    dequantized_blocks = []
    for i in range(0, tensor.numel(), block_size):
        block = tensor[i:i + block_size]
        absmax = block.abs().max()
        if absmax == 0:
            dequantized_blocks.append(block.clone())
            continue
        normalized = block / absmax
        # The 4-bit code of each value is the index of its nearest NF4 level
        codes = (normalized.unsqueeze(1) - levels.unsqueeze(0)).abs().argmin(dim=1)
        dequantized_blocks.append(levels[codes] * absmax)
    return torch.cat(dequantized_blocks)


# Test on random normal data
torch.manual_seed(42)
test_weights = torch.randn(10000)

uniform_deq = uniform_quantize(test_weights, bits=4)
nf4_deq = nf4_quantize(test_weights)

uniform_mse = torch.mean((test_weights - uniform_deq) ** 2).item()
nf4_mse = torch.mean((test_weights - nf4_deq) ** 2).item()

print(f"Test data: {test_weights.numel()} parameters")
print(f"Distribution: mean={test_weights.mean():.4f}, std={test_weights.std():.4f}")
print(f"Uniform 4-bit MSE: {uniform_mse:.6f}")
print(f"NF4 MSE:           {nf4_mse:.6f}")
# Simplified: one code is stored per element and absmax stays in fp32;
# bitsandbytes packs two 4-bit codes per byte and can also double-quantize absmax.