KEY INSIGHT
Joint optimization of multiple compression techniques often outperforms sequential application, but requires careful gradient coordination to avoid conflicting objectives.
While sequential pipelines are simpler to implement, joint compression allows techniques to adapt to each other's effects. This is especially important when compression methods have interdependent effects on the loss landscape.
### Joint Optimization Framework
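```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class JointCompression(nn.Module):
    def __init__(self, model, compression_params=None):
        super().__init__()
        self.model = model  # assumed to expose .weight and .bias (e.g. nn.Linear)
        self.compression_params = compression_params or {}
        # Soft pruning mask in [0, 1] (learnable; sharpened toward binary later)
        self.prune_mask = nn.Parameter(torch.ones_like(model.weight))
        # Scale factors for quantization (learnable)
        self.quant_scales = nn.Parameter(torch.ones_like(model.weight))

    @staticmethod
    def round_ste(w):
        # Straight-through estimator: round in the forward pass,
        # pass the gradient through unchanged in the backward pass
        return w + (torch.round(w) - w).detach()

    def forward(self, x):
        # Apply pruning mask
        weight_pruned = self.model.weight * self.prune_mask.clamp(0, 1)
        # Apply quantization scaling
        weight_quant = weight_pruned * self.quant_scales
        # Quantize to the integer grid (STE keeps gradients flowing)
        weight_discrete = self.round_ste(weight_quant)
        # Forward pass with discrete weights
        return F.linear(x, weight_discrete, self.model.bias)

    def loss(self, output, target, teacher_output=None):
        task_loss = F.cross_entropy(output, target)
        # Pruning regularization: encourage sparsity
        prune_reg = 0.01 * self.prune_mask.clamp(0, 1).mean()
        # Quantization regularization: encourage scales toward uniform.
        # The epsilon sits inside the sqrt so the gradient is finite at zero variance.
        scale_reg = 0.01 * torch.sqrt(self.quant_scales.var() + 1e-6)
        # Distillation loss if teacher available
        distill_loss = 0.0
        if teacher_output is not None:
            T = 4.0  # distillation temperature
            distill_loss = 0.5 * F.kl_div(
                F.log_softmax(output / T, dim=-1),
                F.log_softmax(teacher_output.detach() / T, dim=-1),
                reduction="batchmean",
                log_target=True,  # the target is given as log-probabilities
            )
        return task_loss + prune_reg + scale_reg + distill_loss
```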
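The `round_ste` call above relies on a straight-through estimator (STE). As a minimal sketch of one common way to write it (the standalone `round_ste` function and the toy values here are assumptions for illustration, not necessarily the exact variant intended above), rounding is applied in the forward pass while the backward pass treats the operation as the identity:

```python
import torch


def round_ste(w: torch.Tensor) -> torch.Tensor:
    # Forward: round to the nearest integer.
    # Backward: the rounding residual is detached, so d(out)/d(w) = 1.
    return w + (torch.round(w) - w).detach()


w = torch.tensor([0.2, 0.7, 1.4, -0.6], requires_grad=True)
q = round_ste(w)
q.sum().backward()

print(q.detach())  # rounded values on the integer grid
print(w.grad)      # all ones: the gradient passes straight through
```

Without the detach trick, `torch.round` alone would give a zero gradient almost everywhere, and neither the weights nor the scales would receive a useful training signal.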
### Gradient Coordination
When pruning and quantization gradients conflict, optimization can become unstable. The pruning term pushes mask entries toward zero, while the quantization term pushes the scale factors toward a uniform value. If the two updates repeatedly pull the same weight positions in opposite directions, training tends to oscillate instead of converging.
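```python
import torch


def detect_conflicts(grads_a, grads_b):
    # Assumed definition (not given in the original):
    # fraction of entries where the two gradients have opposite signs
    conflicts = sum((a * b < 0).sum().item() for a, b in zip(grads_a, grads_b))
    total = sum(a.numel() for a in grads_a)
    return conflicts / max(total, 1)


def compute_coordinated_gradients(loss, model, prune_mask, quant_scales):
    # Gradients of the shared loss with respect to each parameter group.
    # prune_mask and quant_scales must be float tensors with requires_grad=True.
    params = [p for p in model.parameters() if p.requires_grad]
    grad_task = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
    grad_prune = torch.autograd.grad(loss, [prune_mask], retain_graph=True)
    grad_quant = torch.autograd.grad(loss, [quant_scales])  # last call frees the graph
    # Detect conflicts: opposite-sign gradients at the same weight position
    prune_conflict = detect_conflicts(grad_prune, grad_quant)
    if prune_conflict > 0.3:  # Threshold for conflict
        # Damp the pruning gradients (acts like a reduced learning rate)
        lr_reduction = 0.5
        grad_prune = tuple(g * lr_reduction for g in grad_prune)
    return grad_task, grad_prune, grad_quant
```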
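To make the conflict test concrete, here is a small worked example on toy gradients. The helper name `conflict_fraction` and the numbers are hypothetical; the sketch assumes a "conflict" means that the two gradients have opposite signs at the same position, as the comment in the function above suggests.

```python
import torch


def conflict_fraction(grad_a: torch.Tensor, grad_b: torch.Tensor) -> float:
    # An entry conflicts when the two gradients point in opposite directions,
    # i.e. their elementwise product is strictly negative.
    return (grad_a * grad_b < 0).float().mean().item()


grad_prune = torch.tensor([0.5, -0.2, 0.1, -0.4])
grad_quant = torch.tensor([-0.3, -0.1, -0.2, 0.6])

frac = conflict_fraction(grad_prune, grad_quant)  # 3 of 4 entries disagree
if frac > 0.3:  # same threshold as in the function above
    grad_prune = grad_prune * 0.5  # damp the pruning update

print(frac)        # 0.75
print(grad_prune)  # halved, because the threshold was exceeded
```

Damping the whole pruning gradient is the coarse option used in this section; scaling only the conflicting entries would be a finer-grained alternative, at the cost of more bookkeeping.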
### Practical Implementation
Joint compression works best with iterative updates:
1. Initialize all masks and scales uniformly
2. Perform several gradient steps jointly
3. Periodically sharpen masks (push toward binary) and scales
4. Evaluate after each cycle for convergence
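```python
import torch


def joint_compress_loop(model, train_loader, eval_loader, epochs=100):
    compressor = JointCompression(model, compression_params={})
    optimizer = torch.optim.Adam([
        {'params': model.parameters()},
        {'params': [compressor.prune_mask]},
        {'params': [compressor.quant_scales], 'lr': 0.01},
    ], lr=0.001)

    for epoch in range(epochs):
        for batch in train_loader:
            optimizer.zero_grad()
            output = compressor(batch['input'])
            loss = compressor.loss(output, batch['target'])
            # Keep the graph so the coordination step can query it again
            loss.backward(retain_graph=True)
            # Gradient coordination
            _, grad_prune, _ = compute_coordinated_gradients(
                loss, model,
                compressor.prune_mask,
                compressor.quant_scales
            )
            # Replace the mask gradient with the (possibly damped) one
            compressor.prune_mask.grad = grad_prune[0]
            optimizer.step()

        # Periodic sharpening
        if epoch % 10 == 0:
            compressor.sharpen_masks()  # not defined in this section
        # Evaluate after each cycle (evaluate is also not defined in this section)
        compressor.evaluate(model, eval_loader)
    return compressor
```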
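The loop calls a sharpening step that this section does not define. One possible sketch, assuming the mask is stored as a float tensor in [0, 1], is a power-based sharpening that leaves 0, 0.5, and 1 fixed and pushes everything else toward the nearer end. The function name `sharpen_mask` and the `power` parameter are hypothetical and chosen only for illustration.

```python
import torch


def sharpen_mask(mask: torch.Tensor, power: float = 3.0) -> torch.Tensor:
    # m^p / (m^p + (1 - m)^p): entries above 0.5 move toward 1,
    # entries below 0.5 move toward 0; 0, 0.5 and 1 are unchanged.
    m = mask.clamp(0.0, 1.0)
    up = m.pow(power)
    down = (1.0 - m).pow(power)
    return up / (up + down)


soft = torch.tensor([0.0, 0.1, 0.4, 0.5, 0.6, 0.9, 1.0])
sharp = sharpen_mask(soft)
print(sharp)
```

Inside a training loop, such a function would typically be applied under `torch.no_grad()` and copied back into the mask parameter in place, so that the optimizer state stays attached to the same tensor. A larger `power` sharpens more aggressively per call.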
### When Joint Beats Sequential
Joint compression excels when:
- Compression techniques compete for the same weights
- Target compression ratio is aggressive (>80%)
- Limited fine-tuning data makes each technique's accuracy recovery critical
Sequential pipelines remain valuable for simpler models or when interpretability of each stage matters.