13. Pruning: Structured vs Unstructured
Pruning removes weights deemed unimportant, reducing model size and computational requirements. The distinction between structured and unstructured pruning determines implementation complexity and hardware benefits.
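Unstructured pruning removes individual weights wherever they fall in the matrix, usually those with the smallest magnitude. A model at 70% sparsity retains 30% of its weights, but the surviving weights follow no regular pattern. Standard GPU kernels are built for dense matrices, so they still multiply the zeros. Turning irregular sparsity into real speedups requires specialized sparse kernels, which usually pay off only at high sparsity.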
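```python
# PyTorch pruning example
import torch.nn as nn
import torch.nn.utils.prune as prune
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# Unstructured magnitude pruning: zero the 30% lowest-|w| weights
# in every query and value projection
for name, module in model.named_modules():
    if isinstance(module, nn.Linear) and ("q_proj" in name or "v_proj" in name):
        prune.l1_unstructured(module, name="weight", amount=0.3)

# After pruning, fine-tune to recover accuracy; the mask keeps pruned weights at zero.
# Then make it permanent with prune.remove(module, "weight") on each pruned module.
# Pruned weights remain zero, but the pattern is irregular and the tensors stay dense.
```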
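Loading Llama-2-7B only to see what pruning does to a layer is expensive. The same `torch.nn.utils.prune` calls can be tried on a small stand-in. The 64 x 64 `nn.Linear` below is an arbitrary toy substitute for one projection. The sketch shows the mask reparametrization and the final `prune.remove` step.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(64, 64)  # toy stand-in for one q_proj / v_proj layer

# Zero the 30% of weights with the smallest |w|
prune.l1_unstructured(layer, name="weight", amount=0.3)

# While pruning is active, the module keeps weight_orig plus a weight_mask buffer
# and recomputes weight = weight_orig * weight_mask before every forward pass
print(sorted(n for n, _ in layer.named_buffers()))  # ['weight_mask']

sparsity = (layer.weight == 0).float().mean().item()
print(f"sparsity: {sparsity:.3f}, shape: {tuple(layer.weight.shape)}")

# The weight is still a dense 64x64 tensor, so this matmul costs the same as before
out = layer(torch.randn(8, 64))

# Make the pruning permanent: fold the mask into a plain weight parameter
prune.remove(layer, "weight")
print(hasattr(layer, "weight_mask"))  # False
```

The weight keeps its full 64 x 64 shape throughout. This is why unstructured pruning alone does not make dense matrix multiplication faster.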
Structured pruning removes weights in regular groups, such as entire attention heads, feed-forward neurons, channels, or whole layers. What remains is a set of smaller dense matrices, which standard hardware runs efficiently without special kernels.
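```python
# Structured pruning: remove whole attention heads by slicing projection weights
import torch
import torch.nn as nn

def slice_linear(layer, idx, dim):
    """Smaller bias-free nn.Linear keeping output rows (dim=0) or input columns (dim=1)."""
    w = layer.weight.data.index_select(dim, idx)
    new = nn.Linear(w.shape[1], w.shape[0], bias=False, device=w.device, dtype=w.dtype)
    new.weight.data.copy_(w)
    return new

def prune_heads(attn, keep_heads, head_dim=128):
    # Assumes the Llama-2-7B layout: 32 heads x 128 dims, bias-free q/k/v/o projections,
    # no grouped-query attention. Keeping 24 heads removes 25% of attention weights;
    # hidden_size stays 4096, only the q/k/v output width shrinks (4096 -> 3072).
    idx = torch.cat([torch.arange(h * head_dim, (h + 1) * head_dim) for h in keep_heads])
    idx = idx.to(attn.q_proj.weight.device)
    for name in ("q_proj", "k_proj", "v_proj"):
        setattr(attn, name, slice_linear(getattr(attn, name), idx, dim=0))
    attn.o_proj = slice_linear(attn.o_proj, idx, dim=1)
    # The attention module's head-count attribute (its name differs across
    # transformers versions) must also be updated before the model is run.
```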
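Removing a head by slicing is exact, not an approximation. Drop that head's rows from the query, key, and value projections and its columns from the output projection. The result matches the full layer run with that head's contribution zeroed. The check below uses a hand-written attention function with made-up toy sizes (4 heads of size 8, model width 32), not the Hugging Face Llama implementation.

```python
import torch

torch.manual_seed(0)
d_model, n_heads, head_dim = 32, 4, 8
# Random toy projections, scaled so activations stay around unit size
Wq, Wk, Wv = (torch.randn(n_heads * head_dim, d_model) / d_model ** 0.5 for _ in range(3))
Wo = torch.randn(d_model, n_heads * head_dim) / (n_heads * head_dim) ** 0.5
x = torch.randn(5, d_model)  # 5 tokens

def mha(x, Wq, Wk, Wv, Wo, head_mask=None):
    """Plain multi-head self-attention: no biases, no causal mask."""
    s, h = x.shape[0], Wq.shape[0] // head_dim
    q = (x @ Wq.T).view(s, h, head_dim).transpose(0, 1)  # (heads, tokens, head_dim)
    k = (x @ Wk.T).view(s, h, head_dim).transpose(0, 1)
    v = (x @ Wv.T).view(s, h, head_dim).transpose(0, 1)
    attn = torch.softmax(q @ k.transpose(-1, -2) / head_dim ** 0.5, dim=-1)
    ctx = attn @ v
    if head_mask is not None:
        ctx = ctx * head_mask.view(-1, 1, 1)  # zero the pruned heads' outputs
    return ctx.transpose(0, 1).reshape(s, -1) @ Wo.T

keep = [0, 2, 3]  # prune head 1
idx = torch.cat([torch.arange(i * head_dim, (i + 1) * head_dim) for i in keep])

masked = mha(x, Wq, Wk, Wv, Wo, head_mask=torch.tensor([1.0, 0.0, 1.0, 1.0]))
sliced = mha(x, Wq[idx], Wk[idx], Wv[idx], Wo[:, idx])  # smaller dense matrices

print(torch.allclose(masked, sliced, atol=1e-5))  # True
print(Wq[idx].shape)  # torch.Size([24, 32])
```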
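Magnitude pruning zeroes weights whose absolute value falls below a threshold. Here the threshold is chosen separately for each layer, so that a fixed fraction of each layer's weights is removed: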
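```python
import torch

def magnitude_prune(model, sparsity):
    """Zero the lowest-magnitude `sparsity` fraction of weights in each Linear layer."""
    with torch.no_grad():
        for module in model.modules():
            if not isinstance(module, torch.nn.Linear):
                continue  # leave embeddings and norm layers untouched
            w = module.weight
            k = int(sparsity * w.numel())  # number of weights to remove
            if k == 0:
                continue
            # kthvalue avoids torch.quantile, which rejects very large tensors
            threshold = w.abs().flatten().float().kthvalue(k).values
            w.masked_fill_(w.abs() <= threshold, 0.0)
    return model
```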
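The function above picks one threshold per layer, so every layer ends up with the same sparsity. PyTorch's `prune.global_unstructured` instead ranks the weights of all listed layers together. Layers whose weights are smaller overall therefore lose more of them. The two-layer model here is an arbitrary toy example.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
params = [(m, "weight") for m in model if isinstance(m, nn.Linear)]

# Rank all weights across both layers together and zero the smallest 50%
prune.global_unstructured(params, pruning_method=prune.L1Unstructured, amount=0.5)

total = zeros = 0
for m, _ in params:
    z = int((m.weight == 0).sum())
    zeros, total = zeros + z, total + m.weight.numel()
    print(f"{tuple(m.weight.shape)}: {z / m.weight.numel():.1%} zero")
print(f"overall: {zeros / total:.1%} zero")  # overall: 50.0% zero
```

With default initialization the second layer starts with smaller weights, because `nn.Linear`'s initialization bound scales as 1/sqrt(fan_in). It therefore ends up more heavily pruned, even though overall sparsity is exactly 50%.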
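Movement pruning instead scores weights by how they change during fine-tuning. Weights moving away from zero are kept even if they are currently small. Weights shrinking toward zero are removed even if they are currently large: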
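```python
import torch

def movement_prune(model, dataloader, optimizer, sparsity=0.4):
    """Simplified movement pruning: accumulate -W * dL/dW while fine-tuning, then zero
    the lowest-scoring fraction of each weight matrix once at the end. (The published
    method learns the scores and prunes gradually over training.)"""
    params = {n: p for n, p in model.named_parameters() if p.dim() == 2 and "weight" in n}
    scores = {n: torch.zeros_like(p) for n, p in params.items()}
    model.train()
    for batch in dataloader:
        loss = model(**batch).loss
        optimizer.zero_grad()
        loss.backward()
        with torch.no_grad():
            for n, p in params.items():
                if p.grad is not None:
                    scores[n] -= p.grad * p  # grows when the update moves W away from zero
        optimizer.step()
    with torch.no_grad():
        for n, p in params.items():
            k = int(sparsity * p.numel())
            if k > 0:
                threshold = scores[n].flatten().float().kthvalue(k).values
                p.masked_fill_(scores[n] <= threshold, 0.0)
    return model
```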
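The difference from magnitude pruning is visible with just three weights. The weights and gradients below are invented for illustration. The score -W · ∂L/∂W is the first-order movement criterion evaluated for a single step; in practice the scores are accumulated over many fine-tuning steps.

```python
import torch

# Three weights and the gradients from one fine-tuning step (made-up numbers)
W = torch.tensor([0.5, -0.1, 0.05])
grad = torch.tensor([0.2, 0.3, -0.4])

magnitude_score = W.abs()
# Movement score: -W * dL/dW is positive when a gradient-descent step
# (W -= lr * grad) pushes the weight away from zero
movement_score = -W * grad

# Prune one of the three weights (~33% sparsity) under each criterion
print(magnitude_score.argmin().item())  # 2: the smallest weight is removed
print(movement_score.argmin().item())   # 0: the largest weight is removed, since it is shrinking
```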
Whether pruning actually speeds up inference depends on whether the hardware and its kernels can exploit the resulting pattern:
| Pruning Type | Hardware Acceleration | Implementation | Typical Sparsity |
|---|---|---|---|
| Unstructured | Limited (needs specialized sparse kernels) | Simple | 50-80% |
| Structured (heads) | Full (smaller dense matrices) | Requires architecture change | 10-30% |
| Structured (channels) | Full (smaller dense matrices) | Requires architecture change | 20-50% |
| N:M structured (e.g., 2:4) | Sparse Tensor Cores on NVIDIA Ampere and newer (2:4 pattern) | Usually fine-tuned to recover accuracy | 50% for 2:4 |
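Sparse storage formats also spend memory on indices, which is part of why unstructured sparsity pays off only at high ratios. The arithmetic below assumes CSR storage with fp16 values and int32 indices for a 4096 x 4096 matrix. These are common choices, but they are assumptions here. Under them, CSR is smaller than the dense matrix only above roughly 67% sparsity.

```python
def dense_bytes(rows, cols, value_bytes=2):
    return rows * cols * value_bytes

def csr_bytes(rows, cols, sparsity, value_bytes=2, index_bytes=4):
    """CSR keeps each nonzero's value and column index, plus rows + 1 row pointers."""
    nnz = round(rows * cols * (1 - sparsity))
    return nnz * (value_bytes + index_bytes) + (rows + 1) * index_bytes

rows = cols = 4096  # size of one Llama-2-7B attention projection
for s in (0.5, 0.7, 0.8, 0.9):
    ratio = csr_bytes(rows, cols, s) / dense_bytes(rows, cols)
    print(f"sparsity {s:.0%}: CSR uses {ratio:.2f}x the dense fp16 size")
```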
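NVIDIA Ampere and newer GPUs, including the A100 and H100, have Sparse Tensor Cores that accelerate 2:4 structured sparsity. PyTorch exposes this through `torch.sparse.to_sparse_semi_structured`. Consumer RTX 30 and 40 series cards also have sparse tensor cores, though software support and real-world speedups on them are reportedly more limited.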
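Before Sparse Tensor Cores can be used, a weight matrix must satisfy the 2:4 pattern. Below is a minimal sketch of imposing that pattern by magnitude. It assumes groups of 4 consecutive weights along each row and keeps the two largest |w| in each group. The function name `two_four_mask` is hypothetical, not a library API.

```python
import torch

def two_four_mask(weight):
    """Hypothetical helper: boolean mask keeping the 2 largest-|w| entries
    in each group of 4 consecutive weights along every row."""
    rows, cols = weight.shape
    assert cols % 4 == 0, "columns must be divisible by 4"
    groups = weight.abs().reshape(rows, cols // 4, 4)
    top2 = groups.topk(2, dim=-1).indices  # positions of the 2 largest per group
    mask = torch.zeros_like(groups, dtype=torch.bool)
    mask.scatter_(-1, top2, torch.ones_like(top2, dtype=torch.bool))
    return mask.reshape(rows, cols)

torch.manual_seed(0)
w = torch.randn(8, 16)
pruned = w * two_four_mask(w)

# Every group of 4 now has exactly 2 nonzeros, i.e. exactly 50% sparsity
nonzeros_per_group = (pruned != 0).reshape(8, 4, 4).sum(-1)
print(nonzeros_per_group.unique())  # tensor([2])
```

On a supported GPU, a half-precision weight that already follows this pattern can then be passed to `torch.sparse.to_sparse_semi_structured`, subject to that function's shape and dtype requirements. That step needs CUDA hardware and is left out of the sketch. Picking the top two by magnitude is only one simple selection rule.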
Implement 2:4 structured pruning (keep 2 weights, remove 2 in each group of 4). Compare inference speed and memory usage against unstructured pruning at equivalent sparsity.