KEY INSIGHT
Successful deployment requires not just compression but also proper export, runtime configuration, and monitoring infrastructure to maintain model quality in production.
Deploying compressed models involves more than converting weights. The entire inference pipeline must adapt to compressed representations while maintaining reliability.
### Model Export
```python
import torch
import onnx
class CompressedModelExporter:
    """Export a compressed PyTorch model to deployment formats.

    Args:
        model: The ``torch.nn.Module`` to export. ``export_to_onnx`` switches
            it to eval mode in place; ``export_with_quantization`` works on a
            deep copy and leaves it untouched.
    """

    def __init__(self, model):
        self.model = model

    def export_to_onnx(self, output_path, input_shape, opset_version=13):
        """Export the model to ONNX and sanity-check the exported file.

        The first axis of both the input and the output is exported as a
        dynamic ``batch_size`` axis.

        Args:
            output_path: Destination of the ``.onnx`` file.
            input_shape: Full input shape, including the batch dimension,
                used to build the tracing input.
            opset_version: ONNX opset to target.

        Returns:
            ``output_path``, unchanged.

        Raises:
            AssertionError: If the exported model yields no outputs, or its
                first output contains NaN or Inf values.
        """
        self.model.eval()

        # Only the shape of the tracing input matters, not its values.
        dummy_input = torch.randn(input_shape)

        torch.onnx.export(
            self.model,
            dummy_input,
            output_path,
            export_params=True,
            opset_version=opset_version,
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'batch_size'},
                'output': {0: 'batch_size'},
            },
        )

        self._verify_onnx(output_path, input_shape)
        return output_path

    def _verify_onnx(self, onnx_path, input_shape):
        """Load the exported ONNX file and check that it yields finite outputs.

        Args:
            onnx_path: Path of the exported ONNX model.
            input_shape: Shape of the random float32 input fed to the model.

        Raises:
            AssertionError: If there are no outputs, or the first output
                contains NaN or Inf.
        """
        import numpy as np
        import onnxruntime as ort

        # Since ONNX Runtime 1.9 the providers must be named explicitly on
        # builds that ship more than the CPU provider; the CPU provider is
        # enough for a smoke test.
        session = ort.InferenceSession(
            onnx_path, providers=['CPUExecutionProvider']
        )

        input_data = np.random.randn(*input_shape).astype(np.float32)
        output = session.run(None, {'input': input_data})

        # Explicit raises instead of ``assert`` so the checks still run under
        # ``python -O``; AssertionError is kept for backward compatibility.
        if len(output) == 0:
            raise AssertionError("ONNX model produced no outputs")
        if np.any(np.isnan(output[0])):
            raise AssertionError("ONNX output contains NaN")
        if np.any(np.isinf(output[0])):
            raise AssertionError("ONNX output contains Inf")

    def export_with_quantization(self, output_path, calibration_data,
                                 backend='fbgemm'):
        """Apply eager-mode post-training static quantization and save it.

        A deep copy of the model is prepared with observers, fed the
        calibration batches, converted to a quantized model, and its state
        dict is saved with ``torch.save``.

        NOTE(review): eager-mode quantization only quantizes the regions of
        the network placed between ``QuantStub`` and ``DeQuantStub`` modules;
        confirm the model defines them.

        Args:
            output_path: Path or file-like object passed to ``torch.save``.
            calibration_data: Iterable of calibration batches. A batch that is
                a tuple or list is assumed to be ``(inputs, ...)`` and only
                its first element is fed to the model.
            backend: Quantization backend whose default qconfig is used
                (for example ``'fbgemm'`` or ``'qnnpack'``).

        Returns:
            ``output_path``, unchanged.

        Raises:
            ValueError: If ``backend`` is not supported by this PyTorch build
                or ``calibration_data`` yields no batches.
        """
        import copy

        import torch.ao.quantization as tq

        if backend not in torch.backends.quantized.supported_engines:
            raise ValueError(
                f"Quantization backend {backend!r} is not supported by this "
                f"PyTorch build"
            )

        qconfig = tq.get_default_qconfig(backend)

        # Work on a copy so the caller's float model stays usable.
        quantized_model = copy.deepcopy(self.model)
        quantized_model.eval()
        quantized_model.qconfig = qconfig
        tq.prepare(quantized_model, inplace=True)

        # Calibration: let the observers record activation ranges.
        batches_seen = 0
        with torch.no_grad():
            for batch in calibration_data:
                inputs = batch[0] if isinstance(batch, (tuple, list)) else batch
                quantized_model(inputs)
                batches_seen += 1
        if batches_seen == 0:
            raise ValueError(
                "calibration_data is empty; cannot calibrate quantization"
            )

        tq.convert(quantized_model, inplace=True)

        # convert() strips the qconfig attribute from the model, so record the
        # configuration captured above. It is stored as text so the checkpoint
        # does not depend on pickling observer factories.
        torch.save({
            'state_dict': quantized_model.state_dict(),
            'quantization_config': repr(qconfig),
            'quantization_backend': backend,
            'architecture': type(self.model).__name__,
        }, output_path)

        return output_path
```
### Runtime Configuration
```python
class CompressedInferenceEngine:
    """Load a compressed model and run inference on NumPy inputs.

    Args:
        model_path: Path to an ``.onnx`` file or a TorchScript ``.pt``/``.pth``
            archive.
        device: ``'cpu'``, ``'cuda'`` or ``'tensorrt'`` for ONNX models; for
            TorchScript models any torch device string (``'tensorrt'`` is
            mapped to ``'cuda'``).

    Raises:
        ValueError: If the file extension is not supported.
    """

    def __init__(self, model_path, device='cpu'):
        self.device = device
        self.model = self._load_model(model_path)

    def _load_model(self, model_path):
        """Dispatch to the loader matching the file extension.

        Args:
            model_path: ``str`` or path-like location of the model file.

        Returns:
            The loaded ONNX Runtime session or torch module.

        Raises:
            ValueError: If the extension is not ``.onnx``, ``.pt`` or ``.pth``.
        """
        import os

        # Accept pathlib.Path as well as str, and match case-insensitively.
        path = os.fspath(model_path)
        suffix = os.path.splitext(path)[1].lower()

        if suffix == '.onnx':
            return self._load_onnx(path)
        if suffix in ('.pt', '.pth'):
            return self._load_torch(path)
        raise ValueError(f"Unsupported format: {model_path}")

    def _load_onnx(self, model_path):
        """Create an ONNX Runtime session with all graph optimizations enabled.

        The execution providers are chosen from ``self.device``; an unknown
        device falls back to the CPU provider, with a warning.
        """
        import logging

        import onnxruntime as ort

        providers = {
            'cpu': ['CPUExecutionProvider'],
            'cuda': ['CUDAExecutionProvider', 'CPUExecutionProvider'],
            'tensorrt': [
                'TensorrtExecutionProvider',
                'CUDAExecutionProvider',
                'CPUExecutionProvider',
            ],
        }
        if self.device not in providers:
            logging.warning(
                "Unknown device %r; falling back to CPUExecutionProvider",
                self.device,
            )

        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = (
            ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        )

        return ort.InferenceSession(
            model_path,
            sess_options,
            providers=providers.get(self.device, providers['cpu']),
        )

    def _torch_device(self):
        """Return the torch device that corresponds to ``self.device``."""
        import torch

        # 'tensorrt' names an ONNX Runtime provider, not a torch device.
        return torch.device('cuda' if self.device == 'tensorrt' else self.device)

    def _load_torch(self, model_path):
        """Load a TorchScript archive onto the configured device in eval mode.

        TorchScript is used because the archive is self-contained: it needs
        neither the original Python class nor unpickling of arbitrary objects.
        A plain state-dict checkpoint cannot be loaded here, since it carries
        no architecture.
        """
        import torch

        model = torch.jit.load(model_path, map_location=self._torch_device())
        model.eval()
        return model

    def predict(self, inputs):
        """Run inference, delegating to the fallback if the model fails.

        Args:
            inputs: NumPy array (or array-like) holding one input batch.

        Returns:
            The first model output as a NumPy array.

        Raises:
            RuntimeError: If inference fails and the fallback fails too; the
                original error is attached as the exception context.
        """
        import logging

        try:
            # Duck-type the backend: ONNX Runtime sessions expose
            # get_inputs()/run(), torch modules are plain callables. This
            # avoids importing onnxruntime just to do an isinstance check.
            if hasattr(self.model, 'get_inputs') and hasattr(self.model, 'run'):
                import numpy as np

                input_name = self.model.get_inputs()[0].name
                output_name = self.model.get_outputs()[0].name
                outputs = self.model.run(
                    [output_name], {input_name: np.asarray(inputs)}
                )
                return outputs[0]

            import torch

            with torch.no_grad():
                tensor = torch.as_tensor(inputs).to(self._torch_device())
                # .cpu() is required before .numpy() for CUDA tensors.
                return self.model(tensor).cpu().numpy()
        except Exception as e:
            # Boundary handler: log, then give the fallback a chance.
            logging.error(f"Inference failed: {e}")
            return self._fallback_predict(inputs)

    def _fallback_predict(self, inputs):
        """Fallback used when the primary model fails.

        No backup model is configured, so this always raises. Subclasses could
        load a backup model or return cached predictions instead.

        Raises:
            RuntimeError: Always.
        """
        raise RuntimeError("Both primary and fallback inference failed")
```
### Monitoring in Production
```python
import logging
class ModelMonitor:
    """Wrap a model to report latency, volume and output anomalies.

    Args:
        model: Object exposing ``predict(inputs)``.
        metrics_backend: Object exposing ``gauge(name, value)`` and
            ``increment(name)``.
        baseline_accuracy: Reference accuracy used by ``report_accuracy``.
            When ``None`` the accuracy is still reported but the degradation
            check is skipped.
        accuracy_tolerance: Allowed drop below ``baseline_accuracy`` before an
            alert is raised.
        outlier_threshold: Absolute output value above which a prediction is
            counted as an outlier.
    """

    def __init__(self, model, metrics_backend, baseline_accuracy=None,
                 accuracy_tolerance=0.05, outlier_threshold=100):
        self.model = model
        self.backend = metrics_backend
        self.baseline_accuracy = baseline_accuracy
        self.accuracy_tolerance = accuracy_tolerance
        self.outlier_threshold = outlier_threshold
        self.prediction_count = 0

    def predict_with_monitoring(self, inputs):
        """Run a prediction and report latency, count and anomaly metrics.

        Args:
            inputs: Passed through unchanged to ``model.predict``.

        Returns:
            Whatever ``model.predict`` returned.
        """
        import time

        import numpy as np

        start_time = time.perf_counter()
        outputs = self.model.predict(inputs)
        latency_ms = (time.perf_counter() - start_time) * 1000

        # Counts calls to this method, not individual samples in a batch.
        self.prediction_count += 1

        self.backend.gauge('model_latency_ms', latency_ms)
        self.backend.gauge('predictions_total', self.prediction_count)

        values = np.asarray(outputs)
        if np.any(np.isnan(values)):
            logging.warning(f"NaN detected in prediction {self.prediction_count}")
            self.backend.increment('nan_predictions')

        # NaN compares False here, so NaNs are not double-counted as outliers.
        if np.any(np.abs(values) > self.outlier_threshold):
            logging.warning("Unusual output magnitude detected")
            self.backend.increment('outlier_predictions')

        return outputs

    def report_accuracy(self, y_true, y_pred):
        """Report batch accuracy and alert when it drops below the baseline.

        Args:
            y_true: Ground-truth labels (array-like).
            y_pred: Predicted labels (array-like), same shape as ``y_true``.

        Raises:
            ValueError: If ``y_true`` and ``y_pred`` differ in shape.
        """
        import numpy as np

        # Convert first: comparing two plain lists with == yields one bool,
        # not an element-wise result.
        truth = np.asarray(y_true)
        predicted = np.asarray(y_pred)
        if truth.shape != predicted.shape:
            raise ValueError(
                f"y_true and y_pred must have the same shape, "
                f"got {truth.shape} and {predicted.shape}"
            )
        if truth.size == 0:
            # The mean of an empty batch is NaN; skip rather than report it.
            logging.warning("report_accuracy called with an empty batch; skipping")
            return

        accuracy = float(np.mean(truth == predicted))
        self.backend.gauge('accuracy', accuracy)

        if self.baseline_accuracy is None:
            return
        if accuracy < self.baseline_accuracy - self.accuracy_tolerance:
            self._alert_accuracy_degradation()

    def _alert_accuracy_degradation(self):
        """Log a critical message that accuracy fell below the threshold."""
        logging.critical(
            f"Model accuracy degraded below acceptable threshold. "
            f"Predictions: {self.prediction_count}"
        )
```
### Deployment Checklist
Before production deployment:
- [ ] Verify accuracy on held-out test set
- [ ] Benchmark latency on target hardware
- [ ] Check model file size and memory requirements
- [ ] Validate ONNX export produces correct outputs
- [ ] Test with production traffic patterns
- [ ] Set up monitoring and alerting
- [ ] Prepare rollback procedure
- [ ] Document compression configuration and date