09. Attention Visualization

Chapter 9 of 18 · 25 min

Attention visualization turns raw attention weights into displays people can inspect. This chapter covers techniques for extracting, analyzing, and presenting attention data to get a clearer picture of which inputs a model draws on when it produces an output.

Understanding Attention Mechanisms

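Transformer models use attention to decide, for each position, how strongly to draw on every other position when building that position's representation. Within one head, the weights for a given query position are a softmax over key positions, so they are non-negative and sum to 1. Attention weights therefore show where information is routed, which is a useful but imperfect proxy for how much each input token "contributes" to each output token.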
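To make the weights concrete, here is a minimal single-head scaled dot-product attention in NumPy with an optional causal mask. The function name `attention_weights` and the random toy inputs are illustrative assumptions, and queries and keys are taken directly from the inputs, skipping the learned projections a real model applies. The point is that each row of the result is a probability distribution over key positions, which is the kind of matrix the heatmaps later in this chapter display.

```python
import numpy as np

def attention_weights(queries, keys, causal=True):
    """Return softmax(Q K^T / sqrt(d)) for one attention head.

    queries, keys: arrays of shape (seq_len, d). Row i of the result holds
    how strongly position i attends to each position j.
    """
    d = queries.shape[-1]
    scores = queries @ keys.T / np.sqrt(d)
    if causal:
        # Decoder-only models mask future positions: entries with j > i become -inf
        mask = np.triu(np.ones_like(scores, dtype=bool), k=1)
        scores = np.where(mask, -np.inf, scores)
    # Numerically stable softmax over the key dimension
    scores = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(scores)
    return exp / exp.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 8))   # 5 toy tokens with 8-dim hidden states
w = attention_weights(x, x)
print(w.shape)                            # (5, 5)
print(np.allclose(w.sum(axis=-1), 1.0))   # True: every row sums to 1
print(w[0])                               # position 0 can only attend to itself
```

Under the causal mask, the first row always puts all of its weight on position 0, which is worth remembering when interpreting "concentrated" attention near the start of a sequence.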

In safety analysis, unusual attention patterns can point to problematic behaviors worth investigating:

# Attention pattern indicating potential manipulation
# Output position strongly attends to injection markers
# rather than legitimate content

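def detect_anomalous_attention(attentions, input_tokens, threshold=0.5):
    """Flag positions whose head-averaged attention concentrates on one token.

    attentions: one tensor per layer shaped (batch, heads, seq_len, seq_len),
    e.g. the `attentions` returned by a Hugging Face model called with
    output_attentions=True. input_tokens: token strings for those positions.
    """

    anomalies = []

    for layer_idx, layer_attention in enumerate(attentions):
        # First example in the batch, averaged across heads -> (seq_len, seq_len)
        avg_weights = layer_attention[0].mean(dim=0)

        # Start at 1: in a causal model position 0 can only attend to itself,
        # so its concentration is always 1.0
        for out_idx in range(1, avg_weights.shape[0]):
            # Row out_idx: how this position distributes attention over inputs
            attention_vector = avg_weights[out_idx]

            # High concentration on few tokens suggests targeting. Many models
            # also concentrate on the first token for ordinary inputs, so
            # calibrate the threshold against baseline prompts.
            concentration = attention_vector.max().item()

            if concentration > threshold:  # Arbitrary default threshold
                max_pos = attention_vector.argmax().item()

                anomalies.append({
                    "layer": layer_idx,
                    "output_position": out_idx,
                    "focused_on": input_tokens[max_pos],
                    "concentration": concentration
                })

    return anomalies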
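The maximum weight looks only at the single strongest entry in each row. A common alternative, swapped in here rather than taken from the detector above, is the Shannon entropy of each attention row: low entropy means the weight sits on few tokens. This NumPy sketch assumes a head-averaged `(seq_len, seq_len)` matrix from a causal model; the function name is hypothetical. Because row i of a causal model can only spread over i + 1 positions, the sketch divides by log(i + 1) so short early rows are not flagged merely for being short.

```python
import numpy as np

def normalized_attention_entropy(weights):
    """Shannon entropy of each attention row, scaled to [0, 1].

    weights: (seq_len, seq_len) array whose rows are probability
    distributions under a causal mask (row i covers positions 0..i),
    e.g. one layer's attention averaged across heads.
    0 means all weight on a single token; 1 means uniform over positions 0..i.
    """
    seq_len = weights.shape[0]
    # p * log(p), with 0 * log(0) treated as 0 (eps avoids log(0) warnings)
    plogp = np.where(weights > 0, weights * np.log(weights + 1e-12), 0.0)
    entropy = -plogp.sum(axis=-1)
    # Largest possible entropy for row i is log(i + 1); row 0 has only one choice
    max_entropy = np.log(np.arange(1, seq_len + 1))
    normalized = np.zeros(seq_len)
    normalized[1:] = entropy[1:] / max_entropy[1:]
    # Clip tiny numerical overshoot caused by eps
    return np.clip(normalized, 0.0, 1.0)
```

Low values play the same role as a high maximum weight, but entropy also separates attention split across two or three tokens from attention spread over the whole context. For a bidirectional encoder, every row can reach all positions, so log(seq_len) would be the right normalizer instead.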

Visualization Techniques

Effective visualization transforms attention tensors into human-interpretable displays:

Heatmaps show an attention weight matrix, with color intensity representing weight magnitude. Row labels show output (query) positions; column labels show input (key) positions.

# attention_visualizer.py
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

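def visualize_attention_matrix(attention_weights, input_tokens, output_tokens,
                               layer_idx, save_path=None):
    """Create a heatmap of one attention matrix.

    attention_weights: 2D array shaped (len(output_tokens), len(input_tokens)),
    e.g. a single head or a head average converted to NumPy. For self-attention
    in a decoder-only model, input_tokens and output_tokens are the same list.
    """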
    
    fig, ax = plt.subplots(figsize=(12, 10))
    
    # Create heatmap
    sns.heatmap(
        attention_weights,
        xticklabels=input_tokens,
        yticklabels=output_tokens,
        cmap="YlOrRd",
        annot=False,
        square=False,
        ax=ax
    )
    
    ax.set_xlabel("Input Tokens")
    ax.set_ylabel("Output Tokens")
    ax.set_title(f"Attention Weights - Layer {layer_idx}")
    
    plt.xticks(rotation=90, fontsize=8)
    plt.yticks(fontsize=8)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150)
        plt.close()
    else:
        plt.show()

def visualize_attention_flow(attentions_by_layer, tokens, save_path=None):
    """Plot one head-averaged attention heatmap per layer, side by side.

    attentions_by_layer: tensors shaped (batch, heads, seq_len, seq_len).
    Each panel is 4 inches wide, so deep models produce very wide figures.
    """

    num_layers = len(attentions_by_layer)
    fig, axes = plt.subplots(1, num_layers, figsize=(4*num_layers, 6))

    if num_layers == 1:
        axes = [axes]

    for layer_idx, attention in enumerate(attentions_by_layer):
        # First example in the batch, averaged across heads -> (seq_len, seq_len)
        avg_attention = attention[0].mean(dim=0).detach().cpu().numpy()

        sns.heatmap(
            avg_attention,
            xticklabels=tokens,
            yticklabels=tokens,
            cmap="Blues",
            vmin=0.0,  # Shared color scale so the single colorbar
            vmax=1.0,  # is valid for every panel
            ax=axes[layer_idx],
            cbar=layer_idx == num_layers - 1
        )

        axes[layer_idx].set_title(f"Layer {layer_idx}")
        axes[layer_idx].tick_params(axis='x', rotation=90, labelsize=6)
        axes[layer_idx].tick_params(axis='y', labelsize=6)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150)
        plt.close()
    else:
        plt.show()
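Both plotting functions need 2D matrices, so a choice has to be made between looking at one head and averaging across heads. The hypothetical helper below makes that choice explicit for a NumPy array shaped like one layer of Hugging Face attention output, `(batch, heads, seq_len, seq_len)`; a PyTorch tensor can be converted with `.detach().cpu().numpy()` first. It can also trim both axes to a window so long prompts stay legible.

```python
import numpy as np

def select_attention_matrix(layer_attention, batch_idx=0, head=None, window=None):
    """Reduce one layer's attention to a 2D (query, key) matrix for plotting.

    layer_attention: array shaped (batch, heads, seq_len, seq_len).
    head: an int selects that head; None averages across all heads.
    window: optional (start, stop) slice applied to both axes. After trimming,
        rows no longer sum to 1 because keys before `start` are dropped.
    """
    per_example = np.asarray(layer_attention)[batch_idx]  # (heads, seq, seq)
    if head is None:
        matrix = per_example.mean(axis=0)
    else:
        matrix = per_example[head]
    if window is not None:
        start, stop = window
        matrix = matrix[start:stop, start:stop]
    return matrix
```

For example, `visualize_attention_matrix(select_attention_matrix(attn_np[5], head=3), tokens, tokens, layer_idx=5)` would plot head 3 of layer 5, where `attn_np` is a hypothetical list of per-layer NumPy arrays. If you pass a window, slice the token labels the same way. Averaging across heads is convenient, but it can hide a single head with a sharp, distinctive pattern, so it is worth checking individual heads as well.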

Practical Analysis Patterns

Certain attention patterns can indicate specific safety-relevant behaviors:

Safety boundary attention shows a model attending to markers of harmful content at the positions where it decides to refuse. This is appropriate if the marker reliably indicates actual harm; if it does not, the model may be refusing on superficial cues.

Injection susceptibility shows attention focusing on injected markers rather than on the legitimate prompt structure. This may indicate vulnerability to prompt injection, especially when it coincides with the model following the injected instruction.

Defensive patterns show attention to system prompt boundaries and instruction markers. Strong attention to these elements is consistent with dependable instruction following, though attention alone does not guarantee it.
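The injection and defensive patterns both come down to where attention mass goes. A minimal way to compare them, assuming you already know the token index ranges of the system prompt and of the suspected injected text, is to average how much of each later position's attention falls on each span. The function name and span arguments are hypothetical, this is not an established metric, and the toy weights below are constructed by hand rather than taken from a model.

```python
import numpy as np

def span_attention_share(weights, span, query_start):
    """Mean fraction of attention that query positions >= query_start give to span.

    weights: (heads, seq_len, seq_len) attention for one layer and one example.
    span: (start, stop) key positions, stop exclusive.
    """
    start, stop = span
    # Total weight on the key span for each head and query row, then averaged
    mass = weights[:, query_start:, start:stop].sum(axis=-1)
    return float(mass.mean())

# Toy example: 2 heads, 6 tokens. Positions 0-1 stand in for a system prompt,
# 2-3 for injected text, 4-5 for the positions whose behavior we care about.
w = np.zeros((2, 6, 6))
w[:, 4:, 0:2] = 0.1   # 0.2 of each row's weight on the "system prompt"
w[:, 4:, 2:4] = 0.4   # 0.8 of each row's weight on the "injected" span
system_share = span_attention_share(w, (0, 2), query_start=4)
injected_share = span_attention_share(w, (2, 4), query_start=4)
print(system_share, injected_share)
```

A higher share on the injected span than on the system prompt, repeated across several layers, is the kind of signal the injection-susceptibility pattern describes; on its own it is suggestive rather than conclusive.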

def analyze_safety_patterns(attentions, tokens,
                            boundary_tokens=("[INST]", "[/INST]", "Instructions:", "System:"),
                            injection_markers=("Ignore", "Previous instructions", "[SYSTEM]")):
    """Measure attention paid to boundary tokens and injection markers.

    attentions: one tensor per layer shaped (batch, heads, seq_len, seq_len).
    tokens: token strings aligned with those positions, e.g.
        tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]).
    Markers are compared with single tokens, so a marker the tokenizer splits
    into several pieces (common for "[INST]" or "Previous instructions")
    will not be found.
    """

    def clean(tok):
        # Drop word-boundary prefixes used by GPT-2 ("Ġ") and SentencePiece ("▁")
        return tok.lstrip("Ġ▁")

    def attention_to(weights, idx):
        # Mean over heads and over later query positions; in a causal model,
        # positions before idx cannot attend to it at all
        later = weights[:, idx + 1:, idx]
        return later.mean().item() if later.numel() > 0 else 0.0

    patterns = {
        "boundary_attention": [],
        "injection_indicators": []
    }
    marker_groups = {
        "boundary_attention": boundary_tokens,
        "injection_indicators": injection_markers
    }

    for layer_idx, attention in enumerate(attentions):
        weights = attention[0]  # Remove batch dimension -> (heads, seq, seq)

        for kind, markers in marker_groups.items():
            for token_idx, tok in enumerate(tokens):
                if clean(tok) in markers:
                    # Record how strongly subsequent tokens attend to this marker
                    patterns[kind].append({
                        "layer": layer_idx,
                        "token": clean(tok),
                        "position": token_idx,
                        "avg_attention": attention_to(weights, token_idx)
                    })

    return patterns
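Exact token matching misses markers that the tokenizer splits into several pieces. A tokenizer-agnostic workaround, sketched here as an assumption rather than a library feature, is to join the token strings (turning GPT-2's "Ġ" and SentencePiece's "▁" word-boundary prefixes into spaces), search the joined text for the marker, and map the character offsets back to token indices. The function name and the case-insensitive matching are illustrative choices.

```python
def find_marker_spans(tokens, marker, boundary_chars="Ġ▁"):
    """Return (start, stop) token spans whose joined text contains marker.

    tokens: token strings aligned with attention positions.
    Matching is case-insensitive; stop is exclusive.
    """
    pieces, starts = [], []
    offset = 0
    for tok in tokens:
        # Replace word-boundary prefix characters with spaces
        text = "".join(" " if ch in boundary_chars else ch for ch in tok)
        starts.append(offset)
        pieces.append(text)
        offset += len(text)
    joined = "".join(pieces).lower()
    target = marker.lower()

    spans = []
    search_from = 0
    while (hit := joined.find(target, search_from)) != -1:
        hit_end = hit + len(target)
        # First token starting at or before the match, last token starting inside it
        first = max(i for i, s in enumerate(starts) if s <= hit)
        last = max(i for i, s in enumerate(starts) if s < hit_end)
        spans.append((first, last + 1))
        search_from = hit + 1
    return spans

tokens = ["<s>", "▁Please", "▁ignore", "▁previous", "▁instru", "ctions", "."]
print(find_marker_spans(tokens, "Previous instructions"))  # [(3, 6)]
```

Each returned span can then be treated as a block of key positions when measuring how much later positions attend to the marker, instead of relying on a single token index.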

Integrating Visualization into Workflow

Effective interpretability requires integrating visualization into regular workflows:

  1. Baseline capture establishes normal attention patterns for expected inputs
  2. Anomaly detection flags attention patterns that deviate from baselines
  3. Correlation analysis links attention anomalies to output behaviors
  4. Root cause investigation uses visualizations to understand why anomalies occur
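
import logging
import torch

logger = logging.getLogger(__name__)

class AttentionMonitor:
    """Integrate attention monitoring into production systems.

    Assumes a Hugging Face-style model/tokenizer; load the model with
    attn_implementation="eager" if it returns no attention weights."""

    def __init__(self, model, tokenizer, baseline_inputs, z_threshold=3.0):  # illustrative threshold
        self.model = model.eval()
        self.tokenizer = tokenizer
        self.z_threshold = z_threshold
        self.baseline_patterns = self._capture_baselines(baseline_inputs)

    def _extract_attentions(self, input_text):
        inputs = self.tokenizer(input_text, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs, output_attentions=True)
        return outputs.attentions  # One (batch, heads, seq, seq) tensor per layer

    def _summarize(self, attentions):
        # Per layer: mean over heads and positions of each row's peak weight
        return torch.stack([a[0].max(dim=-1).values.mean() for a in attentions])

    def _capture_baselines(self, inputs):
        """Establish baseline statistics (needs at least two inputs for a std)"""
        stats = torch.stack([self._summarize(self._extract_attentions(t)) for t in inputs])
        return stats.mean(dim=0), stats.std(dim=0)

    def _detect_anomalies(self, attentions):
        mean, std = self.baseline_patterns
        z = ((self._summarize(attentions) - mean) / (std + 1e-6)).tolist()
        return [{"layer": i, "z_score": s} for i, s in enumerate(z) if abs(s) > self.z_threshold]

    def _alert_and_log(self, input_text, anomalies):
        logger.warning("Attention anomalies for %r: %s", input_text[:80], anomalies)

    def monitor(self, input_text):
        """Check input against baselines, flag anomalies"""
        attentions = self._extract_attentions(input_text)
        anomalies = self._detect_anomalies(attentions)
        if anomalies:
            self._alert_and_log(input_text, anomalies)
        return anomalies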
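Step 3 of the workflow, correlation analysis, can start as simply as comparing how often a bad outcome follows flagged versus unflagged inputs. The record format here (a `flagged` boolean from the monitor and a `bad_output` boolean from whatever output review you use) and the function name are assumptions for illustration.

```python
def outcome_rates_by_flag(records):
    """Rate of bad outputs among flagged vs. unflagged inputs.

    records: iterable of dicts with boolean "flagged" and "bad_output" keys.
    Returns {True: (bad, total, rate), False: (bad, total, rate)}, where the
    key is the flag value and rate is None when a group is empty.
    """
    counts = {True: [0, 0], False: [0, 0]}  # flagged -> [bad_count, total]
    for rec in records:
        group = counts[bool(rec["flagged"])]
        group[0] += int(bool(rec["bad_output"]))
        group[1] += 1
    return {
        flagged: (bad, total, bad / total if total else None)
        for flagged, (bad, total) in counts.items()
    }

records = [
    {"flagged": True, "bad_output": True},
    {"flagged": True, "bad_output": False},
    {"flagged": False, "bad_output": False},
    {"flagged": False, "bad_output": False},
]
print(outcome_rates_by_flag(records))  # {True: (1, 2, 0.5), False: (0, 2, 0.0)}
```

A large gap between the two rates suggests the flags carry signal worth a root cause investigation (step 4); with small samples, treat the gap as a prompt for closer inspection rather than as evidence.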
EXERCISE

Implement attention visualization for a local model processing test prompts including normal queries, potential jailbreaks, and injection attempts. Compare the attention patterns across these cases and identify distinguishing features.