10. Model Pruning for Edge
Chapter 10 of 18 · 20 min
Pruning removes weights from a neural network based on how little they contribute to its output. Unstructured pruning zeroes individual weights, producing sparse tensors whose shapes are unchanged; structured pruning removes entire neurons, channels, or attention heads, producing smaller dense models that run efficiently on standard hardware.
Magnitude-based pruning identifies low-magnitude weights as candidates for removal:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
def global_magnitude_prune(model, param_name, amount=0.5):
    """Prune the lowest-magnitude entries of one parameter across the whole model.

    Every sub-module of ``model`` that holds a tensor named ``param_name`` is
    pooled into a single global L1-magnitude ranking, so layers with many small
    weights lose more entries than layers with large weights.

    Args:
        model: Module whose sub-modules are searched.
        param_name: Attribute name of the tensor to prune (e.g. ``'weight'``).
        amount: Passed through to ``prune.global_unstructured`` (a float is a
            fraction of the currently unpruned entries, an int is a count).

    Raises:
        ValueError: If no sub-module has a tensor called ``param_name``.
    """
    # Modules such as BatchNorm(affine=False) expose the attribute but store
    # None in it; those cannot be ranked and must be skipped.
    parameters_to_prune = [
        (module, param_name)
        for module in model.modules()
        if torch.is_tensor(getattr(module, param_name, None))
    ]
    if not parameters_to_prune:
        raise ValueError(f"no module in the model has a tensor named {param_name!r}")

    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=amount,
    )
# Apply 50% pruning globally
global_magnitude_prune(model, 'weight', amount=0.5)

# Check sparsity. Look for the `weight_mask` attribute itself rather than for
# `weight`: a module can have a `weight` attribute without having been pruned,
# and reading `weight_mask` on it would raise AttributeError.
for name, module in model.named_modules():
    mask = getattr(module, 'weight_mask', None)
    if mask is None:
        continue
    # Fraction of mask entries that are zero, as a plain Python float.
    sparsity = (mask == 0).sum().item() / mask.numel()
    if sparsity > 0.01:
        print(f"{name} sparsity: {sparsity:.1%}")
Structured pruning removes entire channels, simplifying model architecture:
def channel_prune_by_l2(model, layer_idx, prune_ratio=0.3):
    """Build a smaller copy of a Conv2d layer without its weakest output channels.

    Output channels are ranked by the L2 norm of their filters; channels whose
    norm is not above the ``prune_ratio`` quantile are dropped. The model is
    not modified: the caller is responsible for installing the returned layer
    and for adapting whatever consumes its output to the new channel count.

    Args:
        model: Object exposing the target layer as ``model.features[layer_idx]``.
        layer_idx: Index of the convolution inside ``model.features``.
        prune_ratio: Quantile (0..1) of channel norms used as the cut-off.
            ``0`` keeps every channel.

    Returns:
        ``(new_layer, prune_mask)`` where ``prune_mask`` is a bool tensor with
        one entry per original output channel; ``True`` marks a channel that
        was KEPT (despite the name).

    Raises:
        ValueError: If the layer is a grouped convolution; dropping arbitrary
            output channels would break its group structure.
    """
    layer = model.features[layer_idx]
    if layer.groups != 1:
        raise ValueError("channel_prune_by_l2 does not support grouped convolutions")

    weight = layer.weight.detach()
    # Output channels live on axis 0 of the weight, so reduce over the rest.
    channel_norms = torch.norm(weight, p=2, dim=(1, 2, 3))

    if prune_ratio == 0:
        # The quantile at 0 is the minimum, which a strict '>' would still drop.
        prune_mask = torch.ones_like(channel_norms, dtype=torch.bool)
    else:
        threshold = torch.quantile(channel_norms, prune_ratio)
        prune_mask = channel_norms > threshold
        if not prune_mask.any():
            # All norms tie at the threshold: keep the ties instead of
            # returning a layer with zero output channels.
            prune_mask = channel_norms >= threshold

    num_kept = int(prune_mask.sum().item())
    has_bias = layer.bias is not None
    new_layer = nn.Conv2d(
        layer.in_channels,
        num_kept,
        layer.kernel_size,
        stride=layer.stride,
        padding=layer.padding,
        dilation=layer.dilation,
        bias=has_bias,
        padding_mode=layer.padding_mode,
    ).to(device=weight.device, dtype=weight.dtype)

    # Copy the surviving filters (and their biases) in their original order.
    with torch.no_grad():
        new_layer.weight.copy_(weight[prune_mask])
        if has_bias:
            new_layer.bias.copy_(layer.bias.detach()[prune_mask])

    return new_layer, prune_mask
Iterative pruning with fine-tuning preserves accuracy:
def _masked_weight_sparsity(model):
    """Return the fraction of zero entries over all ``weight_mask`` buffers (0.0 if none)."""
    total = 0
    zeros = 0
    for module in model.modules():
        mask = getattr(module, 'weight_mask', None)
        if mask is None:
            continue
        total += mask.numel()
        zeros += int((mask == 0).sum().item())
    return zeros / total if total else 0.0


def iterative_prune_finetune(model, train_loader, device, target_sparsity=0.7, steps=20):
    """Alternate global magnitude pruning and one epoch of fine-tuning.

    Each round prunes a slice of the still-unpruned weights, trains for one
    pass over ``train_loader``, then measures sparsity from the ``weight_mask``
    buffers. The per-round amount is chosen so that the target is reached in
    about ``steps`` rounds: pruning acts on the remaining weights, so the kept
    fraction shrinks geometrically rather than linearly.

    Args:
        model: Network to prune in place.
        train_loader: Iterable of ``(inputs, targets)`` batches.
        device: Device the batches are moved to before the forward pass.
        target_sparsity: Desired fraction of pruned weights; must be below 1.
        steps: Number of prune/fine-tune rounds to spread the pruning over.

    Raises:
        ValueError: If ``steps < 1`` or ``target_sparsity >= 1``.
    """
    if steps < 1:
        raise ValueError("steps must be at least 1")
    if target_sparsity >= 1.0:
        raise ValueError("target_sparsity must be below 1.0")

    criterion = nn.CrossEntropyLoss()
    current_sparsity = _masked_weight_sparsity(model)
    rounds_done = 0

    while current_sparsity < target_sparsity:
        # Fraction of the remaining weights to drop this round so that the
        # rounds still planned end exactly at the target.
        rounds_left = max(steps - rounds_done, 1)
        keep_ratio = (1.0 - target_sparsity) / (1.0 - current_sparsity)
        prune_amount = 1.0 - keep_ratio ** (1.0 / rounds_left)
        global_magnitude_prune(model, 'weight', amount=prune_amount)

        # Fine-tune for one epoch. The optimizer is built after pruning so it
        # tracks whatever model.parameters() yields once pruning has
        # re-parametrized the modules.
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

        # Measure current sparsity
        rounds_done += 1
        previous_sparsity = current_sparsity
        current_sparsity = _masked_weight_sparsity(model)
        print(f"Current sparsity: {current_sparsity:.1%}")

        if current_sparsity <= previous_sparsity:
            # The requested amount rounded down to zero weights; stop rather
            # than loop forever just short of the target.
            break
Pruning recipes for common architectures:
# ResNet: unstructured L1-magnitude pruning of a block's `conv2` weights, amount=0.4
# (a float amount is a fraction, i.e. the 40% smallest-magnitude entries are masked).
# NOTE(review): `module_name` is a placeholder for a residual block — bind it before running.
prune.l1_unstructured(module_name.conv2, name='weight', amount=0.4)
# MobileNetV2: prune depthwise separable convolutions less aggressively (amount=0.3)
# NOTE(review): `depthwise_module` is a placeholder — bind it to the depthwise conv first.
prune.l1_unstructured(depthwise_module, name='weight', amount=0.3)
# Transformer attention heads: structured removal of entire heads
# Better to use scipy.sparse matrices for efficient attention
# NOTE(review): no code accompanies the two notes above, and the scipy.sparse
# suggestion is not demonstrated in this chapter — confirm it before relying on it.
EXERCISE
Implement iterative magnitude pruning on a CNN classifier, track accuracy across pruning steps, and measure final inference speedup on edge hardware.