17. Clinical Validation
Local AI deployment shifts validation responsibility to the implementing organization. Cloud services provide vendor validation and FDA clearance where applicable; self-hosted models require internal validation processes that ensure clinical safety and efficacy.
Validation encompasses multiple dimensions: accuracy compared to clinical truth, consistency across similar inputs, bias detection across demographic groups, and resilience to adversarial inputs.
# clinical_validation.py
import json
import math
import re
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timezone
from itertools import combinations
from statistics import NormalDist, mean, stdev
from typing import List, Dict, Optional
@dataclass
class ValidationResult:
    """Outcome of a single validation metric compared with its target."""

    # Name of the metric, e.g. "accuracy" or "recall".
    metric: str
    # Measured value of the metric.
    value: float
    # Threshold the measured value is compared against.
    target: float
    # Whether the metric was judged to have met its target.
    passed: bool
    # (lower, upper) bounds reported alongside the value.
    confidence_interval: tuple
@dataclass
class ValidationReport:
    """Aggregated outcome of one validation run for a single use case."""

    # When the report was produced.
    timestamp: datetime
    # Identifier of the model that was validated.
    model_version: str
    # Use case the validation targets were selected for.
    use_case: str
    # One entry per evaluated metric.
    results: List[ValidationResult]
    # Per-demographic statistics keyed by demographic label.
    demographic_parity: Dict
    # One dict per test case the model got wrong.
    edge_case_failures: List[Dict]
    # Final verdict for the whole run.
    overall_pass: bool
class ClinicalValidator:
    """Validate a clinical AI model against gold-standard test cases.

    The model is queried once per test case; accuracy, precision, recall,
    output consistency, per-demographic accuracy and the list of failing
    cases are all derived from that single set of predictions so that every
    metric describes the same model outputs.

    Each test case is a dict with ``"input"`` and ``"expected"`` string
    entries; ``"demographics"`` and ``"category"`` are optional.
    """

    # Model tag passed to ``ollama_client.show`` unless overridden in __init__.
    model_name = "llama3.2"

    # Minimum acceptable value per use case and metric.
    _TARGETS = {
        "clinical_note_processing": {
            "accuracy": 0.95,
            "precision": 0.90,
            "recall": 0.90,
        },
        "decision_support": {
            "accuracy": 0.98,
            "precision": 0.95,
            "recall": 0.95,
        },
        "medical_coding": {
            "accuracy": 0.90,
            "precision": 0.85,
            "recall": 0.85,
        },
    }
    # Target used for unknown use cases or metrics.
    _DEFAULT_TARGET = 0.85
    # Fixed target for the consistency metric, independent of use case.
    _CONSISTENCY_TARGET = 0.9

    # "yes" must be a whole word so that e.g. "eyes" is not read as positive;
    # "positive" stays a substring match so compounds such as "seropositive"
    # still count. Negations ("not positive") are NOT detected.
    _POSITIVE_PATTERN = re.compile(r"\byes\b|positive", re.IGNORECASE)

    def __init__(self, ollama_client, gold_standard_dataset_path: str,
                 model_name: Optional[str] = None):
        """Store the model client and load the gold-standard dataset.

        Args:
            ollama_client: Object exposing ``generate(input)`` and
                ``show(model_name)``.
            gold_standard_dataset_path: Path to a JSON file of test cases.
            model_name: Model tag used when looking up the model version;
                defaults to the class-level ``model_name``.
        """
        self.ollama = ollama_client
        if model_name is not None:
            self.model_name = model_name
        self.gold_standard = self._load_gold_standard(gold_standard_dataset_path)

    def validate_use_case(self, use_case: str,
                          test_cases: Optional[List[Dict]] = None) -> "ValidationReport":
        """Run the full validation suite for a specific use case.

        Args:
            use_case: Key selecting the metric targets (see ``_get_target``).
            test_cases: Cases to evaluate; when omitted, the gold-standard
                dataset loaded at construction time is used.

        Returns:
            A ValidationReport. ``overall_pass`` is True only when every
            metric meets its target AND no test case failed.
        """
        if test_cases is None:
            test_cases = self.gold_standard

        # One model call per case; every metric below reuses these outputs.
        predictions = self._collect_predictions(test_cases)

        # Precision and recall are ratios over these subsets, so their
        # confidence intervals must use the subset sizes, not len(test_cases).
        predicted_positives = sum(
            1 for pred in predictions if self._is_positive_prediction(pred)
        )
        actual_positives = sum(
            1 for case in test_cases
            if self._is_positive_prediction(case["expected"])
        )

        results = [
            self._metric_result(
                use_case, "accuracy",
                self._calculate_accuracy(test_cases, predictions),
                len(test_cases),
            ),
            self._metric_result(
                use_case, "precision",
                self._calculate_precision(test_cases, predictions),
                predicted_positives,
            ),
            self._metric_result(
                use_case, "recall",
                self._calculate_recall(test_cases, predictions),
                actual_positives,
            ),
        ]

        consistency = self._calculate_consistency(test_cases, predictions)
        results.append(ValidationResult(
            metric="consistency",
            value=consistency,
            target=self._CONSISTENCY_TARGET,
            passed=consistency >= self._CONSISTENCY_TARGET,
            # Fixed +/-0.1 band, clamped because consistency is a proportion.
            confidence_interval=(
                max(0.0, consistency - 0.1),
                min(1.0, consistency + 0.1),
            ),
        ))

        demographic_parity = self._check_demographic_parity(test_cases, predictions)
        edge_case_failures = self._identify_edge_case_failures(test_cases, predictions)

        overall_pass = all(r.passed for r in results) and not edge_case_failures

        return ValidationReport(
            # Timezone-aware UTC; datetime.utcnow() is deprecated and naive.
            timestamp=datetime.now(timezone.utc),
            model_version=self._get_model_version(),
            use_case=use_case,
            results=results,
            demographic_parity=demographic_parity,
            edge_case_failures=edge_case_failures,
            overall_pass=overall_pass,
        )

    def _metric_result(self, use_case: str, metric: str, value: float,
                       sample_size: int) -> "ValidationResult":
        """Build the ValidationResult for a targeted metric."""
        target = self._get_target(use_case, metric)
        return ValidationResult(
            metric=metric,
            value=value,
            target=target,
            passed=value >= target,
            confidence_interval=self._confidence_interval(value, sample_size),
        )

    def _collect_predictions(self, test_cases: List[Dict]) -> List[str]:
        """Query the model once per test case, preserving case order."""
        return [self._predict(case["input"]) for case in test_cases]

    def _resolve_predictions(self, test_cases: List[Dict],
                             predictions: Optional[List[str]]) -> List[str]:
        """Return ``predictions``, generating them when none were supplied.

        Raises:
            ValueError: If supplied predictions do not align with test_cases.
        """
        if predictions is None:
            return self._collect_predictions(test_cases)
        if len(predictions) != len(test_cases):
            raise ValueError("predictions must contain one entry per test case")
        return predictions

    def _calculate_accuracy(self, test_cases: List[Dict],
                            predictions: Optional[List[str]] = None) -> float:
        """Return the fraction of cases whose prediction matches the gold answer.

        Returns 0.0 for an empty test set.
        """
        if not test_cases:
            return 0.0
        predictions = self._resolve_predictions(test_cases, predictions)
        correct = sum(
            1 for case, pred in zip(test_cases, predictions)
            if self._matches_gold(pred, case["expected"])
        )
        return correct / len(test_cases)

    def _calculate_precision(self, test_cases: List[Dict],
                             predictions: Optional[List[str]] = None) -> float:
        """Return precision (positive predictive value).

        Of the predictions read as positive, the share that match the gold
        answer. Returns 0.0 when nothing was predicted positive.
        """
        predictions = self._resolve_predictions(test_cases, predictions)
        true_positives = 0
        predicted_positives = 0
        for case, pred in zip(test_cases, predictions):
            if self._is_positive_prediction(pred):
                predicted_positives += 1
                if self._matches_gold(pred, case["expected"]):
                    true_positives += 1
        return true_positives / predicted_positives if predicted_positives > 0 else 0.0

    def _calculate_recall(self, test_cases: List[Dict],
                          predictions: Optional[List[str]] = None) -> float:
        """Return recall (sensitivity).

        Of the cases whose gold answer reads as positive, the share whose
        prediction is positive and matches the gold answer. Returns 0.0 when
        there are no positive gold answers.
        """
        predictions = self._resolve_predictions(test_cases, predictions)
        true_positives = 0
        actual_positives = 0
        for case, pred in zip(test_cases, predictions):
            if self._is_positive_prediction(case["expected"]):
                actual_positives += 1
                if self._is_positive_prediction(pred) and \
                        self._matches_gold(pred, case["expected"]):
                    true_positives += 1
        return true_positives / actual_positives if actual_positives > 0 else 0.0

    def _calculate_consistency(self, test_cases: List[Dict],
                               predictions: Optional[List[str]] = None) -> float:
        """Return the share of equivalent prediction pairs for repeated inputs.

        Cases with the same ``"input"`` are grouped and every pair of
        predictions in a group is compared. Returns 1.0 when no input is
        repeated (there is nothing to disagree on).
        """
        predictions = self._resolve_predictions(test_cases, predictions)

        # Key on the input itself rather than hash(input): a hash collision
        # would silently merge two different inputs into one group.
        groups = defaultdict(list)
        for case, pred in zip(test_cases, predictions):
            groups[case["input"]].append(pred)

        consistent = 0
        total = 0
        for group in groups.values():
            for first, second in combinations(group, 2):
                total += 1
                if self._predictions_equivalent(first, second):
                    consistent += 1
        return consistent / total if total > 0 else 1.0

    def _check_demographic_parity(self, test_cases: List[Dict],
                                  predictions: Optional[List[str]] = None) -> Dict:
        """Return accuracy and sample size per demographic group.

        Cases without a ``"demographics"`` entry are grouped under
        ``"unknown"``. The result maps each group to
        ``{"accuracy": float, "sample_size": int}``.
        """
        predictions = self._resolve_predictions(test_cases, predictions)

        tallies = {}
        for case, pred in zip(test_cases, predictions):
            demo = case.get("demographics", "unknown")
            tally = tallies.setdefault(demo, {"correct": 0, "total": 0})
            tally["total"] += 1
            if self._matches_gold(pred, case["expected"]):
                tally["correct"] += 1

        return {
            demo: {
                "accuracy": tally["correct"] / tally["total"] if tally["total"] > 0 else 0,
                "sample_size": tally["total"],
            }
            for demo, tally in tallies.items()
        }

    def _identify_edge_case_failures(self, test_cases: List[Dict],
                                     predictions: Optional[List[str]] = None) -> List[Dict]:
        """Return one descriptive dict per case the model got wrong."""
        predictions = self._resolve_predictions(test_cases, predictions)
        failures = []
        for case, pred in zip(test_cases, predictions):
            if self._matches_gold(pred, case["expected"]):
                continue
            failures.append({
                "input": case["input"][:200],  # Truncate for logging
                "expected": case["expected"],
                "predicted": pred,
                "category": case.get("category", "unknown"),
                "severity": self._assess_failure_severity(case["expected"], pred),
            })
        return failures

    def _get_target(self, use_case: str, metric: str) -> float:
        """Return the target for a use case and metric (0.85 if unknown)."""
        return self._TARGETS.get(use_case, {}).get(metric, self._DEFAULT_TARGET)

    def _confidence_interval(self, proportion: float, n: int,
                             confidence: float = 0.95) -> tuple:
        """Return the Wilson score interval for a proportion.

        Args:
            proportion: Observed proportion in [0, 1].
            n: Number of observations the proportion is based on.
            confidence: Two-sided confidence level, strictly between 0 and 1.

        Returns:
            ``(lower, upper)`` clamped to [0, 1]. With no observations
            (``n <= 0``) nothing is known, so the full range (0.0, 1.0) is
            returned instead of dividing by zero.

        Raises:
            ValueError: If ``confidence`` is not strictly between 0 and 1.
        """
        if not 0.0 < confidence < 1.0:
            raise ValueError("confidence must be strictly between 0 and 1")
        if n <= 0:
            return (0.0, 1.0)

        # Two-sided critical value of the standard normal (about 1.96 at 95%).
        z = NormalDist().inv_cdf(0.5 + confidence / 2.0)
        z_squared = z * z
        denominator = 1 + z_squared / n
        center = (proportion + z_squared / (2 * n)) / denominator
        margin = z * math.sqrt(
            (proportion * (1 - proportion) + z_squared / (4 * n)) / n
        ) / denominator
        return (max(0.0, center - margin), min(1.0, center + margin))

    def _predict(self, input_data: str) -> str:
        """Return the model output for ``input_data``."""
        return self.ollama.generate(input_data)

    def _matches_gold(self, prediction: str, gold: str) -> bool:
        """Return True if prediction and gold answer contain one another.

        Matching is case-insensitive and ignores surrounding whitespace. An
        empty string is a substring of every string, so empty values are
        handled explicitly: an empty side only matches another empty side.
        """
        pred_norm = (prediction or "").strip().lower()
        gold_norm = (gold or "").strip().lower()
        if not pred_norm or not gold_norm:
            return pred_norm == gold_norm
        # Flexible matching for clinical text
        return gold_norm in pred_norm or pred_norm in gold_norm

    def _is_positive_prediction(self, prediction: str) -> bool:
        """Return True if the text contains "positive" or the word "yes"."""
        return bool(self._POSITIVE_PATTERN.search(prediction or ""))

    def _predictions_equivalent(self, pred1: str, pred2: str) -> bool:
        """Return True if two predictions match under the gold-matching rule."""
        return self._matches_gold(pred1, pred2)

    def _assess_failure_severity(self, expected: str, predicted: str) -> str:
        """Return "high" when a critical expected answer was missed, else "moderate"."""
        if "critical" in expected.lower() and "critical" not in predicted.lower():
            return "high"
        return "moderate"

    def _get_model_version(self) -> str:
        """Return the version reported for ``self.model_name`` ("unknown" if absent)."""
        return self.ollama.show(self.model_name).get("version", "unknown")

    def _load_gold_standard(self, path: str) -> List[Dict]:
        """Load gold-standard test cases from a UTF-8 JSON file."""
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def generate_validation_report(self, report: "ValidationReport") -> str:
        """Render a validation report as Markdown text.

        Only the first 10 edge-case failures are listed; the section heading
        still shows the total count.
        """
        lines = [
            "# Clinical AI Validation Report",
            "",
            f"**Timestamp**: {report.timestamp.isoformat()}",
            f"**Model Version**: {report.model_version}",
            f"**Use Case**: {report.use_case}",
            f"**Overall Result**: {'PASS' if report.overall_pass else 'FAIL'}",
            "",
            "## Validation Metrics",
        ]
        for result in report.results:
            status = "✓" if result.passed else "✗"
            lines.append(
                f"- {status} **{result.metric}**: {result.value:.2%} "
                f"(target: {result.target:.2%})"
            )

        lines.extend([
            "",
            "## Demographic Parity",
        ])
        for demo, stats in report.demographic_parity.items():
            lines.append(f"- {demo}: {stats['accuracy']:.2%} (n={stats['sample_size']})")

        if report.edge_case_failures:
            lines.extend([
                "",
                f"## Edge Case Failures ({len(report.edge_case_failures)})",
            ])
            for failure in report.edge_case_failures[:10]:  # Limit to first 10
                lines.append(
                    f"- [{failure['severity']}] {failure['category']}: {failure['input']}"
                )
        return "\n".join(lines)
Validation requires representative test data that covers the expected input distribution. Healthcare AI often fails on rare conditions or edge cases that aren't well-represented in general training data. Include diverse clinical scenarios in validation datasets, with particular attention to high-stakes scenarios.
Build a validation pipeline for a clinical use case. Create 100 test cases with gold standard answers, run the validation suite, and document which cases fail and why. Identify patterns in failures.