08. Interpretability Overview
Interpretability provides visibility into how AI models produce outputs. For operators, this understanding enables better safety assessments, targeted improvements, and trust calibration.
Why Interpretability Matters for Safety
Safety evaluations require understanding model behavior, not just observing inputs and outputs. A model that refuses harmful requests might do so for the right reasons or because of a spurious correlation in its training data. Without interpretability, operators cannot reliably distinguish these cases.
Consider a safety-critical decision:
# Model suggests refusing a request
# Is it refusing because:
# 1. It correctly identified harmful content? (Good)
# 2. It detected a specific trigger word unrelated to actual harm? (Fragile)
# 3. It learned that requests containing a particular word correlate with human review? (Manipulable)
# Interpretability reveals which mechanism is active
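A cheap first check for case 2 is to replace one word at a time and see whether the refusal decision flips. The sketch below is a behavioral test under stated assumptions, not a view into the model's internals: `refuses` is a hypothetical callable that wraps whatever model is being evaluated, and `keyword_refuser` is a toy stand-in used only to make the example runnable.

```python
from typing import Callable, List


def trigger_word_sensitivity(
    refuses: Callable[[str], bool],
    prompt: str,
    neutral: str = "[REDACTED]",
) -> List[str]:
    """Return the words whose replacement alone flips the refusal decision."""
    words = prompt.split()
    baseline = refuses(prompt)
    flipping = []
    for i, word in enumerate(words):
        # Swap exactly one word for a neutral placeholder
        variant = " ".join(words[:i] + [neutral] + words[i + 1:])
        if refuses(variant) != baseline:
            flipping.append(word)
    return flipping


def keyword_refuser(prompt: str) -> bool:
    """Toy stand-in for a model: refuses whenever the word 'attack' appears."""
    return "attack" in prompt.lower().split()


# A benign prompt that the toy model refuses because of one word
flips = trigger_word_sensitivity(keyword_refuser, "How do I treat a panic attack safely")
print(flips)
```

If a single word unrelated to actual harm flips the decision on a benign prompt, the refusal is likely keyed to that word. Confirming which internal mechanism is responsible still requires the techniques described in the rest of this section.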
Interpretability also enables targeted improvements. Instead of guessing why a model misbehaves, operators can examine the specific circuits involved and, in some cases, intervene on them directly.
Levels of Interpretability
Interpretability exists at multiple granularities:
Token-level analysis examines how models process and generate individual tokens. Which input tokens most influence each output token? How do attention patterns flow through the network?
Layer-level analysis studies the representations learned by different network layers. As a rough tendency, early layers capture surface and syntactic features, while later layers capture more semantic ones. Understanding this progression helps locate where safety-relevant information is processed.
Circuit-level analysis identifies specific subgraphs that implement particular behaviors. A circuit might detect harmful intent, trigger refusal, or suppress safety responses.
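# Token attribution example
import torch

def token_attribution(model, tokenizer, input_text, output_start_idx, num_positions=5):
    """Attribute predicted tokens to input tokens by occluding one input token at a time.

    Assumes a Hugging Face-style causal LM and its tokenizer.
    """
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    seq_len = input_ids.shape[1]
    # Replace tokens instead of deleting them so sequence positions stay aligned
    mask_id = tokenizer.pad_token_id
    if mask_id is None:
        mask_id = tokenizer.eos_token_id
    positions = list(range(output_start_idx, min(output_start_idx + num_positions, seq_len)))
    with torch.no_grad():
        # Get model predictions
        baseline_logits = model(input_ids).logits[0]
        # Predicted token at each analyzed position
        output_tokens = baseline_logits[positions].argmax(dim=-1)
        baseline_scores = baseline_logits[positions, output_tokens]
        # Measure importance of each input token: one forward pass per occlusion
        drops = []
        for input_pos in range(seq_len):
            masked_ids = input_ids.clone()
            masked_ids[0, input_pos] = mask_id
            masked_logits = model(masked_ids).logits[0]
            # Difference reveals importance
            drops.append(baseline_scores - masked_logits[positions, output_tokens])
    drops = torch.stack(drops)  # shape: [input position, analyzed position]
    return [
        {
            "output_token": tokenizer.decode([output_tokens[i].item()]),
            "input_attributions": drops[:, i].tolist(),
        }
        for i in range(len(positions))
    ]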
Interpretability Methods
Several established techniques provide model insights:
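Attention analysis studies attention weights to understand which input tokens influence which output tokens. High attention from an output token to a specific input token suggests, but does not prove, strong influence, so treat attention weights as a starting hypothesis to confirm with attribution or ablation: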
import torch

def analyze_attention(model, tokenizer, input_text, threshold=0.1):
    """Extract and analyze attention patterns"""
    inputs = tokenizer(input_text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    # One tensor per layer, each [batch, heads, seq_len_out, seq_len_in]
    attentions = outputs.attentions
    # Stack to [layers, batch, heads, seq, seq], then average across layers and heads
    avg_attention = torch.stack(attentions).mean(dim=(0, 2))
    # Analyze strongest connections
    top_connections = []
    seq_len = avg_attention.shape[-1]
    for out_pos in range(seq_len):
        for in_pos in range(seq_len):
            weight = avg_attention[0, out_pos, in_pos].item()
            if weight > threshold:  # Heuristic cutoff for significance
                top_connections.append({
                    "from": in_pos,
                    "to": out_pos,
                    "weight": weight
                })
    return top_connections
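Probing classifiers train small classifiers on internal representations to detect specific features. If a classifier detects harmful intent more accurately from layer 15 representations than from layer 5, that suggests the relevant information becomes accessible somewhere between those layers. High probe accuracy shows that the information is present in the representation, not that the model actually uses it.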
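The layer comparison can be illustrated with a minimal sketch. It uses a difference-of-means linear probe as a simple stand-in for a trained classifier, and it runs on synthetic activations rather than a real model: `synthetic_layer`, the layer names, and the signal strengths are all assumptions chosen so that the "later" layer encodes the label and the "earlier" one does not.

```python
import random


def mean_vector(rows):
    """Column-wise mean of a list of equal-length vectors."""
    return [sum(col) / len(rows) for col in zip(*rows)]


def fit_probe(features, labels):
    """Fit a difference-of-means linear probe; returns (direction, threshold)."""
    pos = mean_vector([f for f, y in zip(features, labels) if y == 1])
    neg = mean_vector([f for f, y in zip(features, labels) if y == 0])
    direction = [p - n for p, n in zip(pos, neg)]
    midpoint = [(p + n) / 2 for p, n in zip(pos, neg)]
    threshold = sum(d * m for d, m in zip(direction, midpoint))
    return direction, threshold


def probe_accuracy(probe, features, labels):
    """Fraction of examples the probe classifies correctly."""
    direction, threshold = probe
    correct = 0
    for f, y in zip(features, labels):
        score = sum(d * x for d, x in zip(direction, f))
        correct += int((score > threshold) == (y == 1))
    return correct / len(labels)


def synthetic_layer(labels, signal, rng, dim=4):
    """Fake activations: the label shifts dimension 0 by +/- signal, the rest is noise."""
    rows = []
    for y in labels:
        row = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        row[0] += signal if y == 1 else -signal
        rows.append(row)
    return rows


rng = random.Random(0)
train_labels = [i % 2 for i in range(200)]
test_labels = [i % 2 for i in range(200)]

accuracies = {}
for name, signal in [("layer_5", 0.0), ("layer_15", 3.0)]:
    probe = fit_probe(synthetic_layer(train_labels, signal, rng), train_labels)
    held_out = synthetic_layer(test_labels, signal, rng)
    accuracies[name] = probe_accuracy(probe, held_out, test_labels)

print(accuracies)
```

With a real model, the synthetic rows would be replaced by hidden states captured at each layer for labeled prompts. Accuracy is measured on held-out examples so that the probe cannot score well simply by memorizing its training set.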
Feature ablation systematically removes features to measure their contribution. Compare model behavior with and without specific neurons, attention heads, or layers.
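As a minimal sketch of the with-and-without comparison, the example below zero-ablates each hidden unit of a tiny hand-written network and records how much the output changes. The network, its weights, and the helper names are illustrative assumptions. Zeroing is only one ablation choice; replacing a unit with its mean activation is a common alternative.

```python
def relu(values):
    """Element-wise ReLU."""
    return [max(0.0, v) for v in values]


def forward(x, w1, w2, ablate=None):
    """Tiny two-layer network; optionally zero out one hidden unit."""
    hidden = relu([sum(w * xi for w, xi in zip(row, x)) for row in w1])
    if ablate is not None:
        hidden[ablate] = 0.0
    return sum(w * h for w, h in zip(w2, hidden))


def ablation_effects(x, w1, w2):
    """Output change caused by removing each hidden unit in turn."""
    baseline = forward(x, w1, w2)
    return [baseline - forward(x, w1, w2, ablate=i) for i in range(len(w2))]


# Hand-picked weights: 3 hidden units, 2 inputs
W1 = [[1.0, 0.0], [0.0, 1.0], [1.0, -1.0]]
W2 = [0.5, 2.0, 3.0]

effects = ablation_effects([1.0, 2.0], W1, W2)
print(effects)  # → [0.5, 4.0, 0.0]
```

The third unit has the largest output weight but contributes nothing on this input because its pre-activation is negative. Ablation measures a unit's contribution on the inputs tested, so conclusions should be drawn from a representative set of inputs rather than a single example.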
Implement token attribution for a local model processing a test input. Identify which input tokens have the strongest influence on the output. Analyze whether this attribution matches expected behavior.