08. Speculative Decoding for Reasoning

Chapter 8 of 18 · 20 min

Speculative decoding accelerates LLM inference by using a small "draft" model to suggest tokens, then verifying multiple tokens in parallel with the large "target" model. For reasoning workloads, this technique offers significant speedups with modest implementation complexity.

The Speculative Decoding Principle

Standard decoding generates tokens sequentially: each token depends on all previous tokens, so every new token costs one full forward pass of the target model. Speculative decoding keeps that dependency but sidesteps the sequential cost. A fast draft model proposes several candidate tokens, and the target model scores all of them in a single parallel pass.

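The key insight: if the draft model is right, you get multiple tokens for the cost of one verification pass. If a draft token is rejected, the same verification pass still supplies a replacement token at that position, so each step yields at least one token and wastes only the draft computation spent on the rejected tokens.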

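# Speculative decoding flow (one draft-then-verify step)
import random
import numpy as np

def speculative_decode(target_model, draft_model, prompt, gamma=4):
    """Run one speculative step and return the tokens to append to prompt.

    Assumes model.forward(tokens) returns one next-token probability
    vector per input position (softmax already applied) and that EOS
    is defined by the caller.
    """
    # Draft phase: sample up to gamma candidates, conditioning on the prompt
    draft_tokens, draft_probs = [], []
    for _ in range(gamma):
        q = draft_model.forward(prompt + draft_tokens)[-1]
        next_token = int(np.random.choice(len(q), p=q))
        draft_tokens.append(next_token)
        draft_probs.append(q)
        if next_token == EOS:
            break

    # Verify phase: one target pass scores every draft position in parallel
    target_probs = target_model.forward(prompt + draft_tokens)
    offset = len(prompt) - 1  # position whose output predicts draft_tokens[0]

    # Accept/reject each draft token in order
    accepted = []
    for i, draft_tok in enumerate(draft_tokens):
        p, q = target_probs[offset + i], draft_probs[i]
        # Accept with probability min(1, p(x) / q(x))
        if random.random() < min(1.0, p[draft_tok] / q[draft_tok]):
            accepted.append(draft_tok)
        else:
            # Reject: resample from the residual max(0, p - q), renormalized,
            # and discard the remaining draft tokens
            residual = np.maximum(p - q, 0.0)
            residual = residual / residual.sum()
            accepted.append(int(np.random.choice(len(p), p=residual)))
            return accepted

    # All drafts accepted: add one bonus token from the target's final position
    if accepted[-1] != EOS:
        bonus = target_probs[offset + len(draft_tokens)]
        accepted.append(int(np.random.choice(len(bonus), p=bonus)))
    return accepted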
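
The accept/reject rule matters because, combined with resampling from the residual distribution on rejection, it leaves the output distributed exactly as if the target model had sampled it directly. The toy check below illustrates this for a single position. The three-token distributions `p` and `q` and the helper names are made up for illustration and stand in for real model outputs.

```python
import numpy as np

def speculative_sample(p, q, rng):
    """Draw one token: propose from draft q, then accept or resample against target p."""
    x = rng.choice(len(q), p=q)
    # Accept the proposal with probability min(1, p(x) / q(x))
    if rng.random() < min(1.0, p[x] / q[x]):
        return x
    # On rejection, sample from the renormalized residual max(0, p - q)
    residual = np.maximum(p - q, 0.0)
    return rng.choice(len(p), p=residual / residual.sum())

def empirical_distribution(p, q, n_samples=50_000, seed=0):
    """Estimate the output distribution of speculative_sample by simulation."""
    rng = np.random.default_rng(seed)
    counts = np.zeros(len(p))
    for _ in range(n_samples):
        counts[speculative_sample(p, q, rng)] += 1
    return counts / n_samples

p = np.array([0.6, 0.3, 0.1])  # hypothetical target distribution
q = np.array([0.3, 0.5, 0.2])  # hypothetical draft distribution
print("target   :", p)
print("empirical:", empirical_distribution(p, q))
```

The empirical frequencies should land close to `p` even though every proposal came from `q`. A poor draft costs speed, not output quality.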

Applicability to Reasoning Models

Reasoning models have characteristics that both help and hurt speculative decoding:

  • Helpful: Reasoning chains often have predictable patterns (decomposition steps, verification steps). A draft model can learn these patterns.
  • Challenging: Exploratory stretches of a long reasoning chain are harder for the draft to predict, and a single mismatch discards every later token in that draft window, reducing acceptance rate.

For reasoning workloads, tune gamma based on chain predictability. Short, structured reasoning chains (math proofs) work well with gamma=4-8. Longer, exploratory chains (creative problem-solving) may need gamma=2-4.
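
One way to turn this guidance into a number is to estimate the speedup for each candidate gamma and pick the largest. The sketch below uses the usual simplification that each draft token is accepted independently with probability `alpha`, and that a verification pass costs the same as one ordinary target decoding step. The function names and the `cost_ratio` parameter (draft time per token divided by target time per token) are illustrative, not part of any library.

```python
def expected_tokens_per_step(alpha, gamma):
    """Expected tokens emitted per draft-then-verify step.

    Assumes each draft token is accepted independently with probability
    alpha; the step also yields one token from the target pass itself.
    """
    if alpha >= 1.0:
        return gamma + 1
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

def estimated_speedup(alpha, gamma, cost_ratio):
    """Speedup over plain decoding, with step cost gamma * cost_ratio + 1."""
    return expected_tokens_per_step(alpha, gamma) / (gamma * cost_ratio + 1)

def best_gamma(alpha, cost_ratio, max_gamma=16):
    """Gamma in 1..max_gamma that maximizes the estimated speedup."""
    return max(range(1, max_gamma + 1),
               key=lambda g: estimated_speedup(alpha, g, cost_ratio))

for alpha in (0.6, 0.75, 0.9):
    g = best_gamma(alpha, cost_ratio=0.05)
    s = estimated_speedup(alpha, g, cost_ratio=0.05)
    print(f"alpha={alpha:.2f}  best gamma={g}  estimated speedup={s:.2f}x")
```

Under these assumptions the best gamma grows with the acceptance rate, which matches the advice above: predictable chains can afford longer drafts, while exploratory ones are better served by shorter drafts.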

Draft Model Training

The draft model must be small enough to be fast but accurate enough to have high acceptance rate. Options:

  1. Quantized target model: Use INT4 version of R1 as draft; 93% acceptance rate typical
  2. Smaller pretrained model: Train a 1-3B model specifically for draft generation
  3. Thinner variant of the target: Train a draft that shares the target's architecture but uses fewer layers and heads

# Draft model selection criteria
draft_config = {
    "param_count": 3e9,  # 3B parameters for draft
    "layers": 24,  # Reduced from R1's 61
    "heads": 32,  # Reduced from R1's 128
    "quantization": "int4",
    "expected_acceptance_rate": 0.85,  # Target acceptance rate
}

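# Training: use target model outputs as training data
def train_draft(target_outputs, draft_model):
    """Train draft model on target model reasoning chains; return the mean loss"""
    losses = []
    for reasoning_chain in target_outputs:
        # train_step is assumed to run one update and return a scalar loss
        losses.append(draft_model.train_step(reasoning_chain))
    # Return None instead of failing when no training data was supplied
    return sum(losses) / len(losses) if losses else None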
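
To check whether a trained draft is accurate enough, the per-token acceptance probability can be computed directly from the two next-token distributions: a token drafted from `q` is accepted against target `p` with probability equal to the sum over the vocabulary of `min(p, q)`. The sketch below assumes you can collect paired probability vectors from both models at held-out positions. The helper names and example distributions are hypothetical.

```python
import numpy as np

def acceptance_probability(p, q):
    """Probability that a token sampled from draft q is accepted against target p.

    Follows from summing q(x) * min(1, p(x) / q(x)) over the vocabulary.
    """
    return float(np.minimum(p, q).sum())

def mean_acceptance(target_dists, draft_dists):
    """Average acceptance probability over paired held-out positions."""
    return float(np.mean([acceptance_probability(p, q)
                          for p, q in zip(target_dists, draft_dists)]))

p = np.array([0.6, 0.3, 0.1])              # hypothetical target distribution
close_draft = np.array([0.55, 0.33, 0.12])  # draft that tracks the target well
poor_draft = np.array([0.1, 0.2, 0.7])      # draft that disagrees with the target

print("close draft:", acceptance_probability(p, close_draft))
print("poor draft: ", acceptance_probability(p, poor_draft))
print("average:    ", mean_acceptance([p, p], [close_draft, poor_draft]))
```

Tracking this average on held-out reasoning chains during training gives a direct view of the quantity that drives speedup, which the training loss only approximates.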

Performance Expectations

Expected speedup depends on acceptance rate and draft model speed:

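# Speedup calculation (per draft-then-verify step)
target_time_per_token = 100  # ms
draft_time_per_token = 5  # ms (20x faster)
gamma = 4

# Every step drafts gamma tokens and runs one target pass, whether or not
# the drafts are accepted (assumes verification costs one target step)
step_time = draft_time_per_token * gamma + target_time_per_token  # 120 ms

# Best case: all draft tokens accepted, plus one token from the target pass
best_tokens = gamma + 1
speedup = best_tokens * target_time_per_token / step_time
# = 500 / 120 = 4.2x speedup

# Typical case: 85% per-token acceptance. The first rejection discards the
# rest of the draft, so the expected tokens per step is
# (1 - alpha**(gamma + 1)) / (1 - alpha), assuming independent acceptances
alpha = 0.85
expected_tokens = (1 - alpha ** (gamma + 1)) / (1 - alpha)  # about 3.7
speedup = expected_tokens * target_time_per_token / step_time
# = 371 / 120 = 3.1x speedup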

EXERCISE

Implement speculative decoding for a reasoning task. Measure acceptance rate for different gamma values and estimate the speedup. Identify what causes rejections and whether you can improve the draft model.