13. Pruning: Structured vs Unstructured

Chapter 13 of 18 · 20 min

Pruning removes weights deemed unimportant, reducing model size and computational requirements. The distinction between structured and unstructured pruning determines implementation complexity and hardware benefits.

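Unstructured pruning removes individual weights wherever the importance criterion (for example, smallest magnitude) selects them, with no constraint on their position. A model at 70% sparsity retains 30% of its weights, but the surviving weights form no regular pattern. The pruned matrices are still stored and multiplied as dense tensors, and modern GPUs are built for dense matrix multiplication, so irregular zeros rarely turn into real speedups unless sparsity is very high and specialized sparse kernels are used.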

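# Torch pruning example
import torch.nn.utils.prune as prune
from transformers import AutoModelForCausalLM

# Unstructured magnitude pruning: zero the 30% lowest-|w| weights in each target layer
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
targets = [m for n, m in model.named_modules() if n.endswith(("q_proj", "v_proj"))]

for module in targets:
    # Reparametrizes module.weight as weight_orig * weight_mask
    prune.l1_unstructured(module, name="weight", amount=0.3)

# Fine-tune here to recover accuracy; the mask keeps pruned weights at zero.

# Afterwards, make the pruning permanent (drops weight_orig and weight_mask)
for module in targets:
    prune.remove(module, "weight")

# The zeros now live in ordinary dense tensors: the irregular pattern saves
# neither memory nor compute without sparse storage and sparse kernels.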
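The storage argument can be checked without downloading a checkpoint. A minimal NumPy sketch, assuming a random 512×512 matrix stands in for one projection weight (unstructured_mask is a hypothetical helper, not a PyTorch API): it removes the 70% smallest-magnitude entries and shows that the pruned matrix still occupies the same dense storage.

```python
import numpy as np

def unstructured_mask(w, sparsity):
    """Boolean mask that drops the `sparsity` fraction of smallest-|w| entries."""
    k = int(sparsity * w.size)  # number of weights to remove
    mask = np.ones(w.size, dtype=bool)
    if k > 0:
        # Indices of the k smallest magnitudes (exact count; ties broken arbitrarily)
        drop = np.argpartition(np.abs(w).ravel(), k - 1)[:k]
        mask[drop] = False
    return mask.reshape(w.shape)

rng = np.random.default_rng(0)
w = rng.standard_normal((512, 512)).astype(np.float32)
pruned = w * unstructured_mask(w, 0.7)

print(f"nonzero fraction: {np.count_nonzero(pruned) / pruned.size:.2f}")  # 0.30
print(pruned.shape, pruned.nbytes == w.nbytes)  # (512, 512) True
```

Only 30% of the entries are nonzero, yet the array has the same shape and byte size, and a matrix multiply against it does the same amount of work as against the unpruned matrix.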

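Structured pruning removes weights in regular groups, such as entire attention heads, feed-forward neurons (channels), or whole layers. Because removing a head or channel deletes complete rows or columns of the weight matrices, the result is a smaller dense model that runs on standard dense kernels without special hardware support.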

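# Structured pruning - remove attention heads by slicing projection weights.
# Editing the config alone (num_attention_heads = 24, hidden_size = 3584) prunes nothing:
# checkpoint shapes no longer match, and 3584 / 24 is not even an integer head size.
import torch
import torch.nn as nn

def slice_linear(layer, idx, dim):
    """Copy of a bias-free nn.Linear keeping only indices idx along dim (0 = outputs, 1 = inputs)."""
    w = layer.weight.data.index_select(dim, idx)
    new = nn.Linear(w.shape[1], w.shape[0], bias=False, dtype=w.dtype, device=w.device)
    new.weight.data.copy_(w)
    return new

# Llama-2-7B: 32 heads x head_dim 128 = hidden_size 4096 (32 key/value heads, no GQA)
head_dim = 128
keep = list(range(24))  # illustrative; in practice keep the 24 most important heads
idx = torch.cat([torch.arange(h * head_dim, (h + 1) * head_dim) for h in keep])

attn = model.model.layers[0].self_attn  # `model` loaded as in the previous example
for proj in ("q_proj", "k_proj", "v_proj"):  # outputs are laid out head by head
    setattr(attn, proj, slice_linear(getattr(attn, proj), idx, dim=0))
attn.o_proj = slice_linear(attn.o_proj, idx, dim=1)  # inputs are laid out head by head
# hidden_size (the residual stream) stays 4096. The attention module's head-count
# bookkeeping (config/attributes, version-dependent in transformers) must also be
# set to 24 before a forward pass; repeat for every layer you prune.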
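The channel case mentioned above can be sketched in a few lines. A minimal NumPy sketch, assuming a toy ReLU MLP with random weights and an L2-norm importance score (both illustrative choices; Llama's feed-forward block actually uses SwiGLU): removing a hidden unit deletes one row of the up-projection and the matching column of the down-projection, and the smaller dense layer computes exactly what the original layer computes with those units zeroed.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_ff = 8, 32
w_up = rng.standard_normal((d_ff, d_model))    # hidden unit i reads row i
w_down = rng.standard_normal((d_model, d_ff))  # hidden unit i writes column i

def mlp(x, w_up, w_down):
    # Toy ReLU MLP (illustrative stand-in for a transformer feed-forward block)
    return w_down @ np.maximum(w_up @ x, 0.0)

# Score each hidden unit by the L2 norm of its input row; keep the top 75%
scores = np.linalg.norm(w_up, axis=1)
keep = np.sort(np.argsort(scores)[-(3 * d_ff // 4):])

# Structured pruning: slice whole rows/columns, giving smaller dense matrices
w_up_small, w_down_small = w_up[keep], w_down[:, keep]

# Reference: the original layer with the pruned units zeroed in place
x = rng.standard_normal(d_model)
mask = np.zeros(d_ff)
mask[keep] = 1.0
masked = w_down @ (mask * np.maximum(w_up @ x, 0.0))

print(w_up_small.shape, w_down_small.shape)  # (24, 8) (8, 24)
print(np.allclose(mlp(x, w_up_small, w_down_small), masked))  # True
```

Unlike the unstructured case, the 25% reduction shows up directly in the matrix shapes, so both memory and FLOPs shrink on ordinary dense hardware.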

Magnitude pruning zeroes every weight whose absolute value falls at or below a per-layer threshold chosen to reach the target sparsity:

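import torch
import torch.nn as nn

@torch.no_grad()
def magnitude_prune(model, sparsity):
    """Zero the `sparsity` fraction of lowest-magnitude weights in each Linear layer."""
    for module in model.modules():
        # Only prune Linear weight matrices; skip norms and embeddings
        if isinstance(module, nn.Linear):
            w = module.weight
            k = int(sparsity * w.numel())
            if k == 0:
                continue
            # k-th smallest |w|; kthvalue avoids torch.quantile's input-size limit
            threshold = w.abs().flatten().float().kthvalue(k).values
            w.masked_fill_(w.abs() <= threshold, 0)  # ties at the threshold are also zeroed
    return model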
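To make the threshold concrete, here is a worked example on a hypothetical 2×4 weight matrix at 50% sparsity, written in NumPy so the numbers are easy to follow. It applies the same rule as magnitude_prune: find the k-th smallest magnitude and zero every weight at or below it.

```python
import numpy as np

w = np.array([[0.10, -0.80, 0.05, 0.30],
              [-0.20, 0.60, -0.01, 0.90]])
sparsity = 0.5
k = int(sparsity * w.size)                     # remove 4 of 8 weights
threshold = np.sort(np.abs(w).ravel())[k - 1]  # 4th-smallest |w| is 0.20
pruned = np.where(np.abs(w) > threshold, w, 0.0)
# Survivors: -0.80, 0.30, 0.60, 0.90 (the four largest magnitudes)
```

The sorted magnitudes are 0.01, 0.05, 0.10, 0.20, 0.30, 0.60, 0.80, 0.90, so the threshold is 0.20 and the weight -0.20 is removed along with the three smaller ones. If several weights tied at the threshold, slightly more than k would be zeroed.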

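Movement pruning scores weights by how they change during fine-tuning rather than by their current size: weights moving away from zero are kept, and weights shrinking toward zero are pruned, even if they are still large.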

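# Movement pruning (simplified, first-order): score each weight by -W * dL/dW
# accumulated during fine-tuning, so weights moving toward zero score low.
# Scores are collected over one pass and pruning is applied once at the end.
import torch
import torch.nn as nn

def movement_prune(model, dataloader, optimizer, sparsity=0.4):
    linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
    scores = {m: torch.zeros_like(m.weight, dtype=torch.float32) for m in linears}
    model.train()
    for batch in dataloader:
        loss = model(**batch).loss
        optimizer.zero_grad()
        loss.backward()
        with torch.no_grad():
            for m in linears:
                if m.weight.grad is not None:
                    scores[m] -= (m.weight * m.weight.grad).float()
        optimizer.step()
    with torch.no_grad():  # zero the lowest-scoring `sparsity` fraction per layer
        for m in linears:
            k = int(sparsity * m.weight.numel())
            if k > 0:
                threshold = scores[m].flatten().kthvalue(k).values
                m.weight.masked_fill_(scores[m] <= threshold, 0)
    return model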
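The difference between the magnitude and movement criteria already shows up with two weights. A minimal NumPy sketch with hand-picked constant gradients and plain SGD (hypothetical values, not from a real model): one weight is large but being pushed toward zero, the other is small but growing.

```python
import numpy as np

# Two hypothetical weights with constant gradients
w = np.array([1.0, 0.1])
grad = np.array([0.5, -0.5])  # SGD pushes w[0] toward 0 and w[1] away from 0
lr, steps = 0.1, 3

score = np.zeros_like(w)
for _ in range(steps):
    score -= w * grad   # first-order movement score: accumulate -W * dL/dW
    w = w - lr * grad   # plain SGD update

print(w)      # final weights
print(score)  # movement scores
```

After three steps the weights are about 0.85 and 0.25, so magnitude pruning would remove the second one. The movement scores are about -1.43 and 0.23, so movement pruning removes the first one instead, because it is heading toward zero.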

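Hardware support determines whether pruning actually speeds up inference:

| Pruning type | Hardware acceleration | Implementation | Typical sparsity |
|---|---|---|---|
| Unstructured | Limited: general sparse kernels rarely beat dense ones at moderate sparsity | Simple (mask individual weights) | 50-80% |
| Structured (heads) | Full: the result is a smaller dense model | Requires architecture change | 10-30% |
| Structured (channels) | Full: the result is a smaller dense model | Requires architecture change | 20-50% |
| N:M (e.g. 2:4) | Sparse Tensor Cores on NVIDIA Ampere and later GPUs (2:4 pattern) | Prune to the pattern, then fine-tune | 50% for 2:4 |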

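NVIDIA A100 and H100 GPUs accelerate 2:4 semi-structured sparsity with Sparse Tensor Cores, and PyTorch exposes this through torch.sparse.to_sparse_semi_structured. Consumer RTX 30 and 40 series GPUs also include Sparse Tensor Cores, but in practice sparse library and kernel support for them is more limited.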
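The N:M row of the table is easiest to see in code. A minimal NumPy sketch of a 2:4 magnitude mask (mask_2_4 is a hypothetical helper, and grouping consecutive entries along each row is an assumption about layout): in every run of four weights, the two largest magnitudes survive, so sparsity is exactly 50%.

```python
import numpy as np

def mask_2_4(w):
    """Keep the 2 largest-|w| entries in each contiguous group of 4 along each row."""
    rows, cols = w.shape
    assert cols % 4 == 0, "2:4 sparsity needs the row length to be a multiple of 4"
    groups = np.abs(w).reshape(rows, cols // 4, 4)
    # argsort is ascending, so the last two positions in each group are the largest
    top2 = np.argsort(groups, axis=-1)[..., 2:]
    mask = np.zeros(groups.shape, dtype=bool)
    np.put_along_axis(mask, top2, True, axis=-1)
    return mask.reshape(rows, cols)

w = np.array([[0.9, -0.1, 0.4, 0.05, -0.3, 0.2, -0.7, 0.6]])
m = mask_2_4(w)
pruned = w * m  # keeps 0.9, 0.4 in the first group and -0.7, 0.6 in the second
```

Because every group has the same number of survivors, the hardware can store only the kept values plus small per-group indices, which is what makes this pattern acceleratable where unstructured sparsity is not.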

EXERCISE

Implement 2:4 structured pruning (in each contiguous group of 4 weights, keep 2 and zero 2). Compare inference speed and memory usage against unstructured pruning at the same 50% sparsity.