09. FSDP
Fully Sharded Data Parallelism (FSDP) shards model parameters, gradients, and optimizer state across GPUs. This sharply reduces memory usage per GPU and enables training of models that would otherwise require model parallelism.
FSDP vs DDP
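DDP replicates the full model on each GPU, so parameter memory per GPU equals the model size. FSDP shards the parameters, so parameter memory per GPU is roughly the model size divided by the GPU count (individual layers are gathered temporarily during forward and backward).
For the parameters of a 7B-parameter model (FP16, 2 bytes each, ~14GB):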
- DDP on 8 GPUs: 14GB per GPU
- FSDP on 8 GPUs: 14GB / 8 = 1.75GB per GPU
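The arithmetic above can be sketched as a small helper. It counts parameter memory only; the function name and the decimal-GB convention (1 GB = 1e9 bytes) are choices made for this illustration, and gradients, optimizer state, and activations add to the real footprint.

```python
def param_memory_gb(num_params, bytes_per_param=2, num_gpus=1, sharded=False):
    """Per-GPU parameter memory in GB (1 GB = 1e9 bytes), parameters only."""
    total_bytes = num_params * bytes_per_param
    # DDP keeps a full replica on every GPU; fully sharded FSDP keeps 1/num_gpus of it.
    per_gpu_bytes = total_bytes / num_gpus if sharded else total_bytes
    return per_gpu_bytes / 1e9


if __name__ == "__main__":
    n = 7_000_000_000  # 7B parameters, FP16 = 2 bytes each
    print(f"DDP  on 8 GPUs: {param_memory_gb(n, num_gpus=8):.2f} GB per GPU")
    print(f"FSDP on 8 GPUs: {param_memory_gb(n, num_gpus=8, sharded=True):.2f} GB per GPU")
```

Changing bytes_per_param to 4 gives the FP32 figures for the same model.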
Basic FSDP Setup
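import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy, MixedPrecision

def setup_fsdp(config):
    # LOCAL_RANK is set by the launcher (for example torchrun)
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    # FSDP needs an initialized process group before the model is wrapped
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")

    # Mixed precision for memory savings. FP16 usually also needs gradient
    # scaling; torch.bfloat16 avoids that on hardware that supports it.
    mixed_precision = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    )

    # Size-based policy: wrap a submodule once it holds more than 1M unwrapped parameters
    def auto_wrap_policy(module, recurse, nonwrapped_numel):
        return nonwrapped_numel > 1e6

    model = build_model(config)  # build_model is assumed to be defined elsewhere
    model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # Shard params, grads, and optimizer state
        mixed_precision=mixed_precision,
        auto_wrap_policy=auto_wrap_policy,
        device_id=torch.cuda.current_device(),
        use_orig_params=True,  # Expose original parameters (needed e.g. for torch.compile); not required for clipping
    )
    return model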
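For transformer models, a common alternative to a size threshold is to wrap each transformer block as its own FSDP unit. The sketch below builds such a policy with transformer_auto_wrap_policy. It assumes the repeated block is nn.TransformerEncoderLayer, which is only a stand-in; substitute the block class your model actually uses.

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Assumption: the model's repeated block is nn.TransformerEncoderLayer.
layer_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.TransformerEncoderLayer},
)

# The policy can be probed without a process group. With recurse=False it
# answers "should this module become its own FSDP unit?"
block = nn.TransformerEncoderLayer(d_model=16, nhead=2, dim_feedforward=32)
wrap_block = layer_policy(block, False, 0)
wrap_linear = layer_policy(nn.Linear(4, 4), False, 0)
print(wrap_block, wrap_linear)
```

Passing auto_wrap_policy=layer_policy to the FSDP constructor would then shard the model block by block, so only one block's parameters are gathered at a time.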
Training with FSDP
def train_with_fsdp(model, train_loader, config):
    model.train()
    # Create the optimizer after wrapping so it references the sharded parameters
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    for epoch in range(config.epochs):
        for batch in train_loader:
            optimizer.zero_grad()
            inputs = batch["input"].cuda()
            targets = batch["target"].cuda()
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)  # loss_fn is assumed to be defined elsewhere
            # FSDP handles gradient synchronization during backward
            loss.backward()
            # Use FSDP's own clip_grad_norm_: gradients are sharded, so the
            # total norm has to be computed across ranks
            model.clip_grad_norm_(max_norm=1.0)
            optimizer.step()
Common FSDP Issues
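Gradient checkpointing: Combine with FSDP to reduce activation memory. Checkpointing is applied to the model's layers separately; the FSDP wrapper options shown here control communication overlap (backward_prefetch) and optional CPU offloading: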
from torch.distributed.fsdp import BackwardPrefetch, CPUOffload, ShardingStrategy

model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    cpu_offload=None,  # Or CPUOffload(offload_params=True) for huge models
)
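A minimal sketch of activation (gradient) checkpointing itself, using PyTorch's checkpoint wrapper utilities. Block is a hypothetical stand-in for a transformer layer, and the example runs on a single process without FSDP so it can be tried locally. Note that the import path is underscore-prefixed, so it may move between PyTorch releases.

```python
import functools

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    CheckpointWrapper,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)


class Block(nn.Module):
    """Hypothetical stand-in for a transformer layer."""

    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.ff(x)


model = nn.Sequential(Block(16), Block(16), nn.Linear(16, 4))

# Wrap every Block in place so its activations are recomputed during
# backward instead of being stored during forward.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=functools.partial(
        checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
    ),
    check_fn=lambda module: isinstance(module, Block),
)

num_checkpointed = sum(isinstance(m, CheckpointWrapper) for m in model.modules())

# Forward and backward run as usual; only the memory/compute trade-off changes.
loss = model(torch.randn(8, 16)).sum()
loss.backward()
```

With FSDP, the same call is commonly made on the wrapped model, with check_fn matching the transformer block class. The trade-off is extra forward computation during backward in exchange for lower activation memory.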
Local verification checkpoint
Run the smallest example from this chapter in a local workspace and record the package version, runtime, data path, and observed output. If the result depends on model size, GPU count, CPU/GPU backend, or available memory, note that constraint beside the exercise so the lesson remains reproducible.
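Compare GPU memory usage of DDP vs FSDP for the same model using torch.cuda.memory_allocated(). With N GPUs, the memory held by parameters, gradients, and optimizer state under FSDP should approach 1/N of the DDP figure; the total will be somewhat higher because activations are not sharded and layers are gathered temporarily during forward and backward.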
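One way to record the numbers for this comparison is a small reporting helper, sketched below. The function name and the returned dictionary layout are choices made for this illustration; on a machine without CUDA it returns zeros so the call stays safe.

```python
import torch


def gpu_memory_report(tag, device=None):
    """Return current and peak allocated GPU memory in GB for this process."""
    if not torch.cuda.is_available():
        # No GPU in this environment: report zeros instead of raising.
        return {"tag": tag, "allocated_gb": 0.0, "peak_gb": 0.0}
    return {
        "tag": tag,
        "allocated_gb": torch.cuda.memory_allocated(device) / 1e9,
        "peak_gb": torch.cuda.max_memory_allocated(device) / 1e9,
    }


report = gpu_memory_report("after wrapping")
print(report)
```

Calling it on each rank right after wrapping and again after one training step, for both the DDP and FSDP runs, separates the sharded state from the activation peak. torch.cuda.reset_peak_memory_stats() can be called between measurements to restart the peak counter.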