10. Model Pruning for Edge

Chapter 10 of 18 · 20 min

Pruning removes weights from a neural network based on their estimated contribution to its output. Unstructured pruning removes individual weights, producing sparse weight tensors that keep their original shape (they only run faster with sparse-aware kernels). Structured pruning removes entire neurons, channels, or attention heads, producing smaller dense models that run efficiently on standard hardware.

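Magnitude-based pruning treats the lowest-magnitude weights as the least important and removes them first. PyTorch's `torch.nn.utils.prune` does this by attaching a binary mask (e.g. `weight_mask`) to each pruned parameter rather than deleting entries: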

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def global_magnitude_prune(model, param_name, amount=0.5):
    """Prune the globally smallest-magnitude entries of ``param_name``.

    Every submodule of ``model`` whose ``param_name`` attribute is a tensor
    is pooled together, and ``amount`` is applied to that pool as a whole
    (``prune.global_unstructured`` with L1 magnitude). Layers holding many
    small weights therefore lose more than layers holding a few large ones.

    Args:
        model: Module whose submodules are pruned in place.
        param_name: Name of the parameter to prune in each submodule,
            e.g. ``'weight'``.
        amount: Fraction (float in [0, 1]) or absolute count (int) of the
            not-yet-pruned entries to remove.

    Returns:
        The list of ``(module, param_name)`` pairs that were pruned. It is
        empty, and nothing is done, if no submodule holds such a tensor.
    """
    parameters_to_prune = [
        (module, param_name)
        for module in model.modules()
        # hasattr() alone is not enough: layers such as BatchNorm(affine=False)
        # or LayerNorm(elementwise_affine=False) register the parameter as
        # None, which global_unstructured cannot handle.
        if isinstance(getattr(module, param_name, None), torch.Tensor)
    ]

    # global_unstructured concatenates all candidates and fails on an empty list.
    if not parameters_to_prune:
        return parameters_to_prune

    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=amount,
    )
    return parameters_to_prune

# Apply 50% pruning globally
global_magnitude_prune(model, 'weight', amount=0.5)

# Check sparsity. Only modules whose weight was actually pruned carry a
# `weight_mask` buffer, so look it up rather than assuming every module with
# a `weight` attribute has one (unpruned or weight=None layers do not).
for name, module in model.named_modules():
    mask = getattr(module, 'weight_mask', None)
    if mask is None:
        continue
    sparsity = (mask == 0).sum().item() / mask.numel()
    if sparsity > 0.01:
        print(f"{name} sparsity: {sparsity:.1%}")

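Structured pruning removes entire output channels, producing a smaller dense layer. Any layer that consumes this output (the next convolution's input channels, a BatchNorm) must be shrunk to match, which is why the function below also returns the mask of kept channels: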

def channel_prune_by_l2(model, layer_idx, prune_ratio=0.3):
    """Build a copy of ``model.features[layer_idx]`` without its weakest output channels.

    Output channels are ranked by the L2 norm of their filter weights, and
    the lowest ``floor(prune_ratio * out_channels)`` are dropped. Exactly that
    many are removed even when norms tie, and at least one channel always
    survives. Weights and bias of the surviving channels are copied in their
    original order; the original layer and model are not modified.

    The caller is responsible for shrinking whatever consumes this layer's
    output (the next conv's input channels, a BatchNorm, ...) using the
    returned mask.

    Args:
        model: Module with an indexable ``features`` container holding the
            ``nn.Conv2d`` to prune.
        layer_idx: Index of that convolution inside ``model.features``.
        prune_ratio: Fraction of output channels to remove, in ``[0, 1)``.

    Returns:
        ``(new_layer, prune_mask)``: the smaller ``nn.Conv2d`` and a boolean
        tensor of length ``out_channels`` that is True for kept channels.

    Raises:
        ValueError: if ``prune_ratio`` is outside ``[0, 1)``, or if the layer
            is a grouped/depthwise convolution (``groups != 1``), whose output
            channels cannot be removed independently of its input channels.
    """
    layer = model.features[layer_idx]
    if not 0.0 <= prune_ratio < 1.0:
        raise ValueError(f"prune_ratio must be in [0, 1), got {prune_ratio}")
    if layer.groups != 1:
        raise ValueError("channel_prune_by_l2 only supports convolutions with groups == 1")

    weight = layer.weight.detach()
    num_channels = weight.shape[0]

    # L2 norm of each output channel's filter (all non-output dims flattened).
    channel_norms = torch.linalg.vector_norm(weight.flatten(start_dim=1), ord=2, dim=1)

    # Rank instead of thresholding at a quantile: a `> quantile` test prunes a
    # channel even at prune_ratio=0 and drops every channel when norms tie.
    num_kept = max(1, num_channels - int(prune_ratio * num_channels))
    keep_idx = torch.topk(channel_norms, num_kept).indices
    prune_mask = torch.zeros(num_channels, dtype=torch.bool, device=weight.device)
    prune_mask[keep_idx] = True

    new_layer = nn.Conv2d(
        layer.in_channels,
        num_kept,
        layer.kernel_size,
        stride=layer.stride,
        padding=layer.padding,
        dilation=layer.dilation,
        bias=layer.bias is not None,
        padding_mode=layer.padding_mode,
        device=weight.device,
        dtype=weight.dtype,
    )

    # Boolean indexing along the output-channel axis keeps survivors in order.
    with torch.no_grad():
        new_layer.weight.copy_(weight[prune_mask])
        if layer.bias is not None:
            new_layer.bias.copy_(layer.bias.detach()[prune_mask])

    return new_layer, prune_mask

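Pruning in several small steps and fine-tuning after each one lets the network recover. This typically preserves more accuracy than pruning to the target sparsity in a single shot: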

def _weight_mask_sparsity(model):
    """Return the fraction of zeros across every ``weight_mask`` buffer in ``model`` (0.0 if none)."""
    masks = [module.weight_mask for module in model.modules() if hasattr(module, 'weight_mask')]
    total = sum(mask.numel() for mask in masks)
    if total == 0:
        return 0.0
    zeros = sum(int((mask == 0).sum()) for mask in masks)
    return zeros / total


def iterative_prune_finetune(model, train_loader, device, target_sparsity=0.7, steps=20, lr=1e-4):
    """Prune ``model`` towards ``target_sparsity`` over ``steps`` prune/fine-tune rounds.

    Each round calls ``global_magnitude_prune`` on ``'weight'`` with the same
    fraction of the still-unpruned weights. It then fine-tunes for one epoch
    over ``train_loader`` with Adam and cross-entropy so the network can
    recover before the next cut. The per-round fraction is chosen so that
    after ``steps`` rounds the zeroed share of all masked weights is
    approximately ``target_sparsity``; per-round rounding of the pruned
    count can leave it slightly off.

    Args:
        model: Classifier pruned and fine-tuned in place. It is not moved;
            presumably it already lives on ``device``.
        train_loader: Iterable of ``(inputs, targets)`` batches.
        device: Device each batch is moved to.
        target_sparsity: Desired final fraction of zeroed weights, in [0, 1].
        steps: Number of prune/fine-tune rounds (>= 1).
        lr: Adam learning rate used for fine-tuning.

    Returns:
        List of the measured sparsity after each round.

    Raises:
        ValueError: if ``steps < 1`` or ``target_sparsity`` is outside [0, 1].
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    if not 0.0 <= target_sparsity <= 1.0:
        raise ValueError(f"target_sparsity must be in [0, 1], got {target_sparsity}")

    # Each call prunes `prune_amount` of the weights that are still unpruned,
    # so after `steps` rounds the surviving fraction is
    # (1 - prune_amount) ** steps == 1 - target_sparsity.
    prune_amount = 1.0 - (1.0 - target_sparsity) ** (1.0 / steps)
    criterion = nn.CrossEntropyLoss()
    history = []

    for _ in range(steps):
        global_magnitude_prune(model, 'weight', amount=prune_amount)

        # Fine-tune for one epoch; a fresh optimizer each round so Adam's
        # moment estimates restart after every cut.
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

        # Parameters never carry a `.mask`; pruning stores masks as
        # `<name>_mask` buffers on the owning module.
        current_sparsity = _weight_mask_sparsity(model)
        history.append(current_sparsity)
        print(f"Current sparsity: {current_sparsity:.1%}")

    return history

Pruning recipes for common architectures (the module names below are placeholders for layers in your own model):

# ResNet: prune the second convolution of a residual block (placeholder:
# `residual_block` is one block of your model). amount=0.4 removes the 40% of
# its weights with the smallest absolute value (L1 magnitude); it is a
# fraction of weights, not a norm threshold.
prune.l1_unstructured(residual_block.conv2, name='weight', amount=0.4)

# MobileNetV2: a depthwise convolution has only one small filter per channel,
# so prune it less aggressively than standard/pointwise convolutions.
prune.l1_unstructured(depthwise_module, name='weight', amount=0.3)

# Transformer attention heads: remove whole heads structurally (drop the
# matching slices of the Q/K/V and output projections) so the result is a
# smaller dense layer that runs fast on standard kernels. Sparse matrix
# formats such as scipy.sparse are CPU-only and generally pay off only at
# high unstructured sparsity.
EXERCISE

Implement iterative magnitude pruning on a CNN classifier, track accuracy across pruning steps, and measure final inference speedup on edge hardware.