17. Clinical Validation

Chapter 17 of 18 · 25 min

Local AI deployment shifts validation responsibility to the implementing organization. Cloud services may come with vendor validation and, where applicable, FDA clearance; self-hosted models require internal validation processes to establish clinical safety and efficacy.

Validation encompasses multiple dimensions: accuracy compared to clinical truth, consistency across similar inputs, bias detection across demographic groups, and resilience to adversarial inputs.

# clinical_validation.py
import json
import math
import re
from dataclasses import dataclass
from datetime import datetime, timezone
from statistics import NormalDist, mean, stdev
from typing import Dict, List, Optional

@dataclass
class ValidationResult:
    """Outcome of a single validation metric compared against its target."""
    # Metric name, e.g. "accuracy", "precision", "recall" or "consistency".
    metric: str
    # Observed metric value; the metrics in this module are proportions in [0, 1].
    value: float
    # Minimum acceptable value for the metric.
    target: float
    # Whether the metric met its target (set by the producer, e.g. value >= target).
    passed: bool
    # (lower, upper) bounds around `value`.
    confidence_interval: tuple

@dataclass
class ValidationReport:
    """Aggregate result of validating one model version on one use case."""
    # When the validation run produced this report.
    timestamp: datetime
    # Identifier of the model that was validated.
    model_version: str
    # Use-case key whose metric targets were applied.
    use_case: str
    # One entry per validation metric.
    results: List[ValidationResult]
    # Demographic group -> {"accuracy": ..., "sample_size": ...}.
    demographic_parity: Dict
    # One dict per mismatching case (input, expected, predicted, category, severity).
    edge_case_failures: List[Dict]
    # Final verdict for the whole run.
    overall_pass: bool

class ClinicalValidator:
    """Validate clinical AI systems against established standards.

    Test cases are dicts with at least ``"input"`` and ``"expected"`` string
    keys, plus optional ``"demographics"`` and ``"category"`` keys. The model
    is run once per case, and every metric in a run is computed from that same
    set of outputs.
    """

    # Whole-word match so that e.g. "eyes" is not read as "yes".
    _POSITIVE_PATTERN = re.compile(r"\b(?:positive|yes)\b", re.IGNORECASE)

    # Minimum acceptable metric values per use case; other combinations
    # fall back to 0.85 in _get_target.
    _TARGETS = {
        "clinical_note_processing": {
            "accuracy": 0.95,
            "precision": 0.90,
            "recall": 0.90,
        },
        "decision_support": {
            "accuracy": 0.98,
            "precision": 0.95,
            "recall": 0.95,
        },
        "medical_coding": {
            "accuracy": 0.90,
            "precision": 0.85,
            "recall": 0.85,
        },
    }

    def __init__(self, ollama_client, gold_standard_dataset_path: str,
                 *, model_name: str = "llama3.2"):
        """Create a validator.

        Args:
            ollama_client: Client used for predictions (``generate``) and model
                metadata (``show``). NOTE(review): this code expects
                ``generate(text)`` to return a string and ``show(name)`` to
                return a mapping; the official ``ollama`` package's
                ``generate`` takes ``model``/``prompt`` and returns a response
                object, so confirm which client is passed in.
            gold_standard_dataset_path: Path to a JSON file holding the
                default list of test cases.
            model_name: Model whose version is recorded in reports.
        """
        self.ollama = ollama_client
        self.model_name = model_name
        self.gold_standard = self._load_gold_standard(gold_standard_dataset_path)

    def validate_use_case(self, use_case: str,
                          test_cases: Optional[List[Dict]] = None) -> ValidationReport:
        """Run validation for a specific use case.

        Args:
            use_case: Key selecting the metric targets (see ``_get_target``).
            test_cases: Cases to evaluate. Defaults to the gold-standard
                dataset loaded at construction.

        Returns:
            A ValidationReport with per-metric results, per-demographic
            accuracy and every case whose prediction did not match.
        """
        if test_cases is None:
            test_cases = self.gold_standard

        # Query the model exactly once per case. LLM output can vary between
        # calls, so re-predicting inside each metric would score different
        # outputs (and cost 5x the model calls).
        predictions = [self._predict(case["input"]) for case in test_cases]

        predicted_positives = sum(
            1 for prediction in predictions
            if self._is_positive_prediction(prediction)
        )
        actual_positives = sum(
            1 for case in test_cases
            if self._is_positive_prediction(case["expected"])
        )

        results = [
            self._metric_result(
                use_case, "accuracy",
                self._calculate_accuracy(test_cases, predictions),
                len(test_cases),
            ),
            # Precision and recall are proportions over predicted and actual
            # positives respectively, so their intervals use those counts.
            self._metric_result(
                use_case, "precision",
                self._calculate_precision(test_cases, predictions),
                predicted_positives,
            ),
            self._metric_result(
                use_case, "recall",
                self._calculate_recall(test_cases, predictions),
                actual_positives,
            ),
        ]

        consistency = self._calculate_consistency(test_cases, predictions)
        results.append(ValidationResult(
            metric="consistency",
            value=consistency,
            target=0.9,
            passed=consistency >= 0.9,
            # Fixed heuristic band (not a statistical interval), kept in [0, 1].
            confidence_interval=(max(0.0, consistency - 0.1),
                                 min(1.0, consistency + 0.1)),
        ))

        demographic_parity = self._check_demographic_parity(test_cases, predictions)
        edge_case_failures = self._identify_edge_case_failures(test_cases, predictions)

        # NOTE(review): any single mismatch fails the run, so the effective
        # accuracy bar is 100%, and demographic parity is reported but does
        # not affect the verdict. Confirm this is the intended gate.
        overall_pass = all(r.passed for r in results) and not edge_case_failures

        return ValidationReport(
            timestamp=datetime.now(timezone.utc),
            model_version=self._get_model_version(),
            use_case=use_case,
            results=results,
            demographic_parity=demographic_parity,
            edge_case_failures=edge_case_failures,
            overall_pass=overall_pass,
        )

    def _metric_result(self, use_case: str, metric: str, value: float,
                       n: int) -> ValidationResult:
        """Build a ValidationResult for a metric that has a use-case target."""
        target = self._get_target(use_case, metric)
        return ValidationResult(
            metric=metric,
            value=value,
            target=target,
            passed=value >= target,
            confidence_interval=self._confidence_interval(value, n),
        )

    def _predictions_for(self, test_cases: List[Dict],
                         predictions: Optional[List[str]] = None) -> List[str]:
        """Return the given predictions, or run the model on each case.

        Raises:
            ValueError: if ``predictions`` does not align one-to-one with
                ``test_cases``.
        """
        if predictions is None:
            return [self._predict(case["input"]) for case in test_cases]
        if len(predictions) != len(test_cases):
            raise ValueError("predictions must align one-to-one with test_cases")
        return predictions

    def _calculate_accuracy(self, test_cases: List[Dict],
                            predictions: Optional[List[str]] = None) -> float:
        """Return the fraction of cases whose prediction matches the gold answer.

        Returns 0.0 when there are no cases.
        """
        preds = self._predictions_for(test_cases, predictions)
        if not test_cases:
            return 0.0
        correct = sum(
            1 for case, prediction in zip(test_cases, preds)
            if self._matches_gold(prediction, case["expected"])
        )
        return correct / len(test_cases)

    def _calculate_precision(self, test_cases: List[Dict],
                             predictions: Optional[List[str]] = None) -> float:
        """Return precision (positive predictive value).

        This is correct positive predictions divided by all positive
        predictions, or 0.0 when the model made no positive predictions.
        """
        preds = self._predictions_for(test_cases, predictions)
        flagged = [
            (case, prediction) for case, prediction in zip(test_cases, preds)
            if self._is_positive_prediction(prediction)
        ]
        if not flagged:
            return 0.0
        true_positives = sum(
            1 for case, prediction in flagged
            if self._matches_gold(prediction, case["expected"])
        )
        return true_positives / len(flagged)

    def _calculate_recall(self, test_cases: List[Dict],
                          predictions: Optional[List[str]] = None) -> float:
        """Return recall (sensitivity).

        This is correctly predicted positives divided by cases whose expected
        answer is positive, or 0.0 when no case is expected positive.
        """
        preds = self._predictions_for(test_cases, predictions)
        expected_positive = [
            (case, prediction) for case, prediction in zip(test_cases, preds)
            if self._is_positive_prediction(case["expected"])
        ]
        if not expected_positive:
            return 0.0
        true_positives = sum(
            1 for case, prediction in expected_positive
            if self._is_positive_prediction(prediction)
            and self._matches_gold(prediction, case["expected"])
        )
        return true_positives / len(expected_positive)

    def _calculate_consistency(self, test_cases: List[Dict],
                               predictions: Optional[List[str]] = None) -> float:
        """Return the fraction of equivalent prediction pairs for identical inputs.

        Cases sharing the same ``"input"`` are grouped, and every pair within
        a group is compared. Returns 1.0 when no input occurs more than once.
        """
        preds = self._predictions_for(test_cases, predictions)

        # Key on the input itself: hash() values can collide and would merge
        # predictions for different inputs into one group.
        groups: Dict[str, List[str]] = {}
        for case, prediction in zip(test_cases, preds):
            groups.setdefault(case["input"], []).append(prediction)

        consistent = 0
        total = 0
        for group in groups.values():
            for i, first in enumerate(group):
                for second in group[i + 1:]:
                    total += 1
                    if self._predictions_equivalent(first, second):
                        consistent += 1

        return consistent / total if total > 0 else 1.0

    def _check_demographic_parity(self, test_cases: List[Dict],
                                  predictions: Optional[List[str]] = None) -> Dict:
        """Return accuracy and sample size per demographic group.

        Cases without a ``"demographics"`` key are grouped under "unknown".
        """
        preds = self._predictions_for(test_cases, predictions)
        tallies: Dict = {}
        for case, prediction in zip(test_cases, preds):
            demo = case.get("demographics", "unknown")
            tally = tallies.setdefault(demo, {"correct": 0, "total": 0})
            tally["total"] += 1
            if self._matches_gold(prediction, case["expected"]):
                tally["correct"] += 1

        # Every tally has total >= 1, because it is created on first sight.
        return {
            demo: {
                "accuracy": tally["correct"] / tally["total"],
                "sample_size": tally["total"],
            }
            for demo, tally in tallies.items()
        }

    def _identify_edge_case_failures(self, test_cases: List[Dict],
                                     predictions: Optional[List[str]] = None) -> List[Dict]:
        """Return one record per case whose prediction does not match gold."""
        preds = self._predictions_for(test_cases, predictions)
        return [
            {
                "input": case["input"][:200],  # Truncate for logging
                "expected": case["expected"],
                "predicted": prediction,
                "category": case.get("category", "unknown"),
                "severity": self._assess_failure_severity(
                    case["expected"], prediction
                ),
            }
            for case, prediction in zip(test_cases, preds)
            if not self._matches_gold(prediction, case["expected"])
        ]

    def _get_target(self, use_case: str, metric: str) -> float:
        """Return the target for a use case and metric (0.85 if unlisted)."""
        return self._TARGETS.get(use_case, {}).get(metric, 0.85)

    def _confidence_interval(self, proportion: float, n: int,
                             confidence: float = 0.95) -> tuple:
        """Calculate the Wilson score confidence interval for a proportion.

        Args:
            proportion: Observed proportion in [0, 1].
            n: Number of trials the proportion was measured over.
            confidence: Two-sided confidence level, strictly between 0 and 1.

        Returns:
            ``(lower, upper)`` clamped to [0, 1]. When ``n <= 0`` this is
            ``(0.0, 1.0)``, since no data constrains the proportion.
        """
        if n <= 0:
            return (0.0, 1.0)
        z = NormalDist().inv_cdf(0.5 + confidence / 2)  # 1.96 for 95%
        z2 = z * z
        denominator = 1 + z2 / n
        center = (proportion + z2 / (2 * n)) / denominator
        margin = z * math.sqrt(
            (proportion * (1 - proportion) + z2 / (4 * n)) / n
        ) / denominator
        # Clamp to absorb floating-point overshoot at proportion 0 or 1.
        return (max(0.0, center - margin), min(1.0, center + margin))

    def _predict(self, input_data: str) -> str:
        """Run model prediction.

        NOTE(review): assumes the client's ``generate`` accepts the raw input
        and returns the output text; confirm against the client in use.
        """
        return self.ollama.generate(input_data)

    def _matches_gold(self, prediction: str, gold: str) -> bool:
        """Return True if either text contains the other (case-insensitive).

        Surrounding whitespace is ignored. Empty texts never match: an empty
        string is a substring of every string, so without this guard an empty
        model response would count as correct.
        """
        pred = prediction.strip().lower()
        ref = gold.strip().lower()
        if not pred or not ref:
            return False
        return ref in pred or pred in ref

    def _is_positive_prediction(self, prediction: str) -> bool:
        """Return True if the text contains the word "positive" or "yes"."""
        return self._POSITIVE_PATTERN.search(prediction) is not None

    def _predictions_equivalent(self, pred1: str, pred2: str) -> bool:
        """Return True if two predictions are identical or match each other.

        Identical outputs (including two empty ones) are always equivalent.
        """
        if pred1.strip().lower() == pred2.strip().lower():
            return True
        return self._matches_gold(pred1, pred2)

    def _assess_failure_severity(self, expected: str, predicted: str) -> str:
        """Return "high" when an expected "critical" finding is missed, else "moderate"."""
        if "critical" in expected.lower() and "critical" not in predicted.lower():
            return "high"
        return "moderate"

    def _get_model_version(self) -> str:
        """Get the current model version identifier, or "unknown".

        NOTE(review): assumes ``show`` returns a mapping with a "version" key;
        confirm against the client in use.
        """
        return self.ollama.show(self.model_name).get("version", "unknown")

    def _load_gold_standard(self, path: str) -> List[Dict]:
        """Load gold-standard test cases from a UTF-8 JSON file."""
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def generate_validation_report(self, report: ValidationReport) -> str:
        """Render a ValidationReport as a Markdown document."""
        lines = [
            "# Clinical AI Validation Report",
            "",
            f"**Timestamp**: {report.timestamp.isoformat()}",
            f"**Model Version**: {report.model_version}",
            f"**Use Case**: {report.use_case}",
            f"**Overall Result**: {'PASS' if report.overall_pass else 'FAIL'}",
            "",
            "## Validation Metrics",
        ]

        for result in report.results:
            status = "✓" if result.passed else "✗"
            low, high = result.confidence_interval
            lines.append(
                f"- {status} **{result.metric}**: {result.value:.2%} "
                f"(target: {result.target:.2%}, CI: {low:.2%}–{high:.2%})"
            )

        lines.extend([
            "",
            "## Demographic Parity",
        ])

        for demo, stats in report.demographic_parity.items():
            lines.append(f"- {demo}: {stats['accuracy']:.2%} (n={stats['sample_size']})")

        failures = report.edge_case_failures
        if failures:
            lines.extend([
                "",
                f"## Edge Case Failures ({len(failures)})",
            ])
            for failure in failures[:10]:  # Limit to first 10
                # Collapse whitespace so a multi-line note stays one list item.
                snippet = " ".join(str(failure["input"]).split())
                lines.append(f"- [{failure['severity']}] {failure['category']}: {snippet}")
            if len(failures) > 10:
                lines.append(f"- … and {len(failures) - 10} more")

        return "\n".join(lines)

Validation requires representative test data that covers the expected input distribution. Healthcare AI often fails on rare conditions or edge cases that aren't well-represented in general training data. Include diverse clinical scenarios in validation datasets, with particular attention to high-stakes scenarios.

EXERCISE

Build a validation pipeline for a clinical use case. Create 100 test cases with gold standard answers, run the validation suite, and document which cases fail and why. Identify patterns in failures.