05. Tensor Parallelism
Tensor parallelism splits the computation of individual layers across GPUs, so a model too large for a single GPU's memory can run across several devices. The technique parallelizes matrix multiplications across devices and requires frequent collective communication to synchronize partial results.
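The core mechanism operates at the level of individual operations. For a matrix multiplication Y = XW, W can be split in one of two ways. If W is split into column partitions, each GPU computes a slice of Y's columns from the full X, and an all-gather concatenates the slices. If W is split into row partitions (with X split along its matching feature dimension), each GPU computes a full-shape partial sum of Y, and an all-reduce adds the partial sums. This pattern repeats for each linear layer in transformer architectures.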
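The two ways of sharding Y = XW can be checked in a single process. The following is a minimal sketch, assuming NumPy arrays stand in for per-GPU shards and plain concatenation and summation stand in for the all-gather and all-reduce collectives. The shapes and the name `world_size` are illustrative choices, not values from any particular framework.

```python
import numpy as np

rng = np.random.default_rng(0)
world_size = 2                      # simulated number of GPUs (assumption)
X = rng.standard_normal((4, 6))     # (batch, in_features)
W = rng.standard_normal((6, 8))     # (in_features, out_features)
Y = X @ W                           # single-device reference result

# Column split: each rank holds a slice of W's columns and the full X.
col_shards = np.split(W, world_size, axis=1)
col_outputs = [X @ w for w in col_shards]        # each is a (4, 4) slice of Y
Y_col = np.concatenate(col_outputs, axis=1)      # stands in for all-gather

# Row split: each rank holds a slice of W's rows and the matching
# slice of X's feature dimension.
row_shards = np.split(W, world_size, axis=0)
x_shards = np.split(X, world_size, axis=1)
row_partials = [x @ w for x, w in zip(x_shards, row_shards)]  # each is (4, 8)
Y_row = sum(row_partials)                        # stands in for all-reduce (SUM)

col_ok = np.allclose(Y, Y_col)
row_ok = np.allclose(Y, Y_row)
print("column split matches:", col_ok)
print("row split matches:", row_ok)
```

The column split produces output slices that only need to be placed side by side, while the row split produces full-shape partial sums that must be added. That difference is why the two layouts use different collectives.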
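Megatron-LM-style frameworks apply this to both attention and feedforward layers. In attention, the Q, K, V projections are split column-wise across tensor ranks so that each rank handles a subset of heads; the output projection is split row-wise, and an all-reduce combines its partial results. Feedforward layers split the same way along the intermediate dimension: the first linear layer is column-parallel, the second is row-parallel, and a single all-reduce in the forward pass synchronizes the block's output.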
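The feedforward split along the intermediate dimension can be simulated the same way. This is a sketch under stated assumptions, not Megatron-LM's actual code: NumPy arrays stand in for GPU shards, the first weight matrix `A` is split by columns, the second matrix `B` by rows, and a tanh-approximated GELU is used as the activation. All names and sizes here are hypothetical.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU; elementwise, so it can run on each shard locally
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

rng = np.random.default_rng(0)
world_size = 2                         # simulated tensor-parallel ranks (assumption)
hidden, intermediate = 8, 32           # illustrative layer sizes
X = rng.standard_normal((4, hidden))
A = rng.standard_normal((hidden, intermediate))   # first linear layer
B = rng.standard_normal((intermediate, hidden))   # second linear layer

reference = gelu(X @ A) @ B            # unsharded feedforward block

A_shards = np.split(A, world_size, axis=1)   # column split along intermediate dim
B_shards = np.split(B, world_size, axis=0)   # row split along intermediate dim

partials = []
for a, b in zip(A_shards, B_shards):
    h = gelu(X @ a)          # local slice of the intermediate activation
    partials.append(h @ b)   # full-shape partial sum of the block output

output = sum(partials)       # stands in for the single forward all-reduce
mlp_ok = np.allclose(reference, output)
print("sharded feedforward matches:", mlp_ok)
```

Because the activation is elementwise, each rank can apply it to its own slice of the intermediate tensor without communicating, so the only synchronization in the forward pass is the final sum.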
Scaling behavior: tensor parallelism reduces per-GPU memory for the sharded weights roughly linearly with the number of GPUs. An 8-GPU tensor-parallel scheme cuts the memory for sharded weights by roughly 8x per device, which allows roughly 8x more parameters in those layers; replicated parameters and unsharded activations do not shrink, so the overall saving is somewhat smaller. Communication overhead from frequent all-reduce calls limits the useful degree of tensor parallelism. Typical sweet spots are 2-8 GPUs before pipeline parallelism becomes necessary.
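A back-of-the-envelope calculation makes the trade-off concrete. The sketch below is illustrative only: the 7-billion-parameter model, 2 bytes per parameter, the 1 GiB message size, and the `replicated_fraction` parameter are all assumptions chosen for the example. The communication estimate assumes a ring all-reduce, in which each rank sends 2(N-1)/N times the message size; real collective libraries may pick other algorithms.

```python
GIB = 2**30

def weight_gib_per_gpu(num_params, bytes_per_param, tp_degree, replicated_fraction=0.0):
    """Weight memory per GPU in GiB; replicated_fraction is the share of
    parameters kept whole on every rank (hypothetical knob for illustration)."""
    total = num_params * bytes_per_param
    sharded = total * (1.0 - replicated_fraction) / tp_degree
    replicated = total * replicated_fraction
    return (sharded + replicated) / GIB

def ring_all_reduce_bytes_per_rank(message_bytes, tp_degree):
    """Bytes each rank sends in a ring all-reduce of one message (assumed algorithm)."""
    return 2 * (tp_degree - 1) / tp_degree * message_bytes

# Hypothetical model: 7e9 parameters stored in 2 bytes each.
for tp in (1, 2, 4, 8):
    mem = weight_gib_per_gpu(7_000_000_000, 2, tp)
    sent = ring_all_reduce_bytes_per_rank(GIB, tp) / GIB
    print(f"tp={tp}: weights {mem:.2f} GiB/GPU, ring all-reduce sends {sent:.2f} GiB per 1 GiB message")
```

Weight memory falls as 1/N, but the per-rank traffic for each all-reduce rises toward twice the message size, and that cost is paid on every layer. This is one way to see why the useful degree levels off.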
Local verification checkpoint
Run the smallest example from this chapter in a local workspace and record the package version, runtime, data path, and observed output. If the result depends on model size, vector count, CPU/GPU backend, or available memory, note that constraint beside the exercise so the lesson remains reproducible.
Implement a tensor-parallel linear layer in PyTorch to understand the communication pattern. Distribute a matrix across 2 GPUs, compute partial results independently, then all-reduce the output. Measure communication time versus computation time to understand overhead scaling.
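import torch
import torch.distributed as dist

def tensor_parallel_linear(x, weight_shards, world_size):
    # Row-parallel layout: weight_shards[r] holds rank r's slice of W's rows.
    # Assumes dist.init_process_group() has already been called.
    rank = dist.get_rank()
    local_weight = weight_shards[rank]
    # Take the matching slice of x's feature (last) dimension
    local_x = torch.chunk(x, world_size, dim=-1)[rank]
    # Each rank computes a full-shape partial sum of the output
    partial_output = torch.matmul(local_x, local_weight)
    # All-reduce sums the partial outputs in place across ranks
    dist.all_reduce(partial_output, op=dist.ReduceOp.SUM)
    return partial_output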