Training & optimization

Batch Normalization

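Batch normalization is a training technique that normalizes the inputs to a layer across a mini-batch of data. It computes the mean and variance of the layer's activations over the batch, normalizes the activations with them, and then scales and shifts the result using learned parameters (gamma and beta). It was originally motivated as a way to reduce "internal covariate shift"; whatever the exact mechanism, it stabilizes training and typically allows higher learning rates and faster convergence. During inference, batch normalization uses running averages of the mean and variance accumulated during training, so the output for a given input no longer depends on the rest of the batch and the operation is deterministic. Operators training models with frameworks like PyTorch or TensorFlow will encounter batch normalization as a standard layer (e.g., torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, or tf.keras.layers.BatchNormalization). It is largely irrelevant for inference-only workflows such as running a quantized LLM via llama.cpp: the transformer architectures involved use layer normalization or RMSNorm rather than batch normalization, and in CNN inference pipelines that do contain BN, its parameters are usually folded into the preceding layer's weights.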

Deeper dive

Batch normalization works by normalizing each feature dimension to zero mean and unit variance over the mini-batch (for convolutional layers, each channel is normalized over the batch and both spatial dimensions). For a layer with output x, the normalized value is x_norm = (x - μ_batch) / sqrt(σ²_batch + ε), where ε is a small constant for numerical stability. The learned parameters γ and β then scale and shift the result: y = γ * x_norm + β. During training, the batch statistics are used; during inference, the running averages (tracked during training) replace them. This reduces sensitivity to weight initialization and allows higher learning rates. The original 2015 paper reported matching baseline accuracy with 14 times fewer training steps on an ImageNet classifier, but the speedup varies widely with the model and training recipe.

Variants include layer normalization (normalizes across features within each sample), instance normalization (per sample, per channel), and group normalization (per sample, over groups of channels). Because these variants compute statistics per sample, they do not depend on batch size. Batch normalization is standard in CNNs (e.g., ResNet) but rare in transformers, which favor layer normalization or its variant RMSNorm. For operators, batch normalization layers are typically fused into the preceding convolutional layer at inference time, which removes a separate pass over the activations; this is especially common in optimized and quantized deployments.
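To make the formula concrete, here is a minimal pure-Python sketch of the training-mode computation, assuming a 2-D input of shape (batch, features) stored as a list of rows. The function names batch_norm_train and layer_norm are hypothetical, not a framework API. The sketch uses only the current batch's statistics (running averages are left out) and contrasts batch normalization with layer normalization, which applies the same formula along the other axis.

```python
import math

def batch_norm_train(x, gamma, beta, eps=1e-5):
    """Normalize each feature (column) of x over the mini-batch (rows).

    x: list of rows, shape (batch, features); gamma/beta: one value per feature.
    Uses the biased (divide-by-N) batch variance, as in the formula above.
    """
    n, d = len(x), len(x[0])
    out = [[0.0] * d for _ in range(n)]
    for j in range(d):
        col = [row[j] for row in x]
        mu = sum(col) / n                                   # μ_batch for feature j
        var = sum((v - mu) ** 2 for v in col) / n           # σ²_batch for feature j
        for i in range(n):
            x_norm = (x[i][j] - mu) / math.sqrt(var + eps)
            out[i][j] = gamma[j] * x_norm + beta[j]         # learned scale and shift
    return out

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize each sample (row) across its own features instead of across the batch."""
    out = []
    for row in x:
        d = len(row)
        mu = sum(row) / d
        var = sum((v - mu) ** 2 for v in row) / d
        out.append([g * (v - mu) / math.sqrt(var + eps) + b
                    for v, g, b in zip(row, gamma, beta)])
    return out

# Two features on very different scales, four samples in the batch.
x = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]
bn = batch_norm_train(x, gamma=[1.0, 1.0], beta=[0.0, 0.0])
ln = layer_norm(x, gamma=[1.0, 1.0], beta=[0.0, 0.0])
```

With these inputs each column of bn has mean about 0 and variance about 1, while every row of ln comes out close to [-1, 1] because layer normalization only looks within a single sample.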

Practical example

When training a ResNet-50 image classifier on an RTX 3090 (24 GB VRAM), batch normalization can allow a batch size of 128 with a learning rate of 0.1, whereas without BN the same model might need a learning rate around 0.01 and smaller batches to avoid divergence. At inference, tools such as ONNX Runtime and TensorRT typically fuse each BN layer into the preceding convolution as a graph optimization, which eliminates a separate read and write of the activations and reduces memory bandwidth and latency.
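The conv-BN fusion is plain algebra: inference-mode BN is a fixed per-channel affine map, so it can be absorbed into the weights and bias of the preceding layer. The sketch below assumes a 1x1 convolution (equivalently, a linear layer applied at one pixel) with weights W[out][in] and bias b; a k x k convolution folds the same way, with every weight of an output channel multiplied by that channel's scale. The function names are hypothetical, and this is not the code ONNX Runtime or TensorRT actually run.

```python
import math

def linear(W, b, x):
    """y[o] = sum_i W[o][i] * x[i] + b[o]  (a 1x1 convolution at a single pixel)."""
    return [sum(w * xi for w, xi in zip(row, x)) + bo for row, bo in zip(W, b)]

def bn_eval(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode BN: normalizes with stored running statistics, not batch statistics."""
    return [g * (v - m) / math.sqrt(s + eps) + bt
            for v, g, bt, m, s in zip(y, gamma, beta, mean, var)]

def fold_bn(W, b, gamma, beta, mean, var, eps=1e-5):
    """Return (W_f, b_f) such that linear(W_f, b_f, x) == bn_eval(linear(W, b, x), ...)."""
    W_f, b_f = [], []
    for row, bo, g, bt, m, s in zip(W, b, gamma, beta, mean, var):
        scale = g / math.sqrt(s + eps)          # per-output-channel scale
        W_f.append([w * scale for w in row])    # scale that channel's weights
        b_f.append((bo - m) * scale + bt)       # absorb the mean shift and beta into the bias
    return W_f, b_f

# Hypothetical layer: 3 input channels, 2 output channels, with BN running statistics.
W = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]]
b = [0.1, -0.2]
gamma, beta = [1.2, 0.8], [0.05, -0.3]
mean, var = [0.4, -1.0], [2.0, 0.5]
x = [1.0, 2.0, 3.0]

unfused = bn_eval(linear(W, b, x), gamma, beta, mean, var)
W_f, b_f = fold_bn(W, b, gamma, beta, mean, var)
fused = linear(W_f, b_f, x)
```

The fused layer produces the same outputs as conv followed by BN, up to floating-point rounding, but does one pass over the data instead of two. The folding is only valid once the statistics are frozen, which is why it is applied to inference graphs.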

Workflow example

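In a PyTorch training script, batch normalization appears as nn.BatchNorm2d(num_features) in the model definition. Calling model.train() puts the layer in training mode, where each forward pass normalizes with the current batch statistics and updates the running mean and variance; calling model.eval() switches it to the stored running statistics instead. Call model.eval() before exporting to ONNX for inference with onnxruntime: when the model is exported in eval mode (training=torch.onnx.TrainingMode.EVAL), the exporter can fold the BN parameters into the preceding conv weights, and onnxruntime's graph optimizations can also perform this fusion when the model is loaded. In llama.cpp, batch normalization does not appear at all; the transformer architectures it runs use layer normalization or RMSNorm instead.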
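To illustrate what switching between model.train() and model.eval() changes, the sketch below mimics a single-feature BN layer in pure Python. It assumes PyTorch's documented defaults: momentum=0.1, running mean initialized to 0 and running variance to 1, and the running variance updated with the unbiased (N-1) batch variance while normalization uses the biased one. The class BatchNormSketch is hypothetical and is not PyTorch's implementation.

```python
import math

class BatchNormSketch:
    """Single-feature BN modeled on PyTorch's documented default behavior."""

    def __init__(self, eps=1e-5, momentum=0.1):
        self.eps, self.momentum = eps, momentum
        self.gamma, self.beta = 1.0, 0.0            # learned parameters (fixed here)
        self.running_mean, self.running_var = 0.0, 1.0
        self.training = True

    def train(self):
        self.training = True
        return self

    def eval(self):
        self.training = False
        return self

    def __call__(self, batch):
        if self.training:
            n = len(batch)
            if n < 2:
                raise ValueError("need more than 1 value per channel when training")
            mu = sum(batch) / n
            var = sum((v - mu) ** 2 for v in batch) / n      # biased: used to normalize
            unbiased = var * n / (n - 1)                     # unbiased: used for running stats
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mu
            self.running_var = (1 - m) * self.running_var + m * unbiased
        else:
            mu, var = self.running_mean, self.running_var    # frozen statistics, no updates
        return [self.gamma * (v - mu) / math.sqrt(var + self.eps) + self.beta
                for v in batch]

bn = BatchNormSketch()
bn.train()
bn([1.0, 2.0, 3.0, 4.0])       # one training step updates the running statistics
bn.eval()
first = bn([1.0, 5.0])
second = bn([1.0, 5.0])         # eval mode: same input, same output, stats unchanged
```

After one training step on [1, 2, 3, 4], running_mean is 0.25 and running_var is 0.9 + 0.1 * 5/3, about 1.067. In eval mode, repeated calls return identical outputs, and a sample's output no longer depends on what else is in the batch.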

Reviewed by Fredoline Eruo. See our editorial policy.