08. Speculative Decoding for Reasoning
Speculative decoding accelerates LLM inference by letting a small "draft" model propose several tokens, which the large "target" model then verifies in a single parallel pass. For reasoning workloads, which generate long outputs, this can give a substantial speedup for modest implementation complexity, and with the standard accept/reject rule the output distribution still matches the target model's.
The Speculative Decoding Principle
Standard decoding generates tokens sequentially: each token depends on all previous tokens, so the target model runs one forward pass per token. Speculative decoding does not remove this dependency. It sidesteps the cost by letting a fast draft model guess the next few tokens, then scoring all of the guesses in one target pass.
The key insight: when the draft model is right, you get several tokens for the cost of one verification pass. When it is wrong, the target's pass still yields one valid token at the rejection point, so the only waste is the draft computation and the discarded draft tokens.
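# Speculative decoding flow (one draft-and-verify step)
import numpy as np

def softmax(logits):
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def speculative_decode(target_model, draft_model, prompt, gamma=4, rng=None):
    """Run one speculative step with gamma draft tokens; return the new tokens.

    Assumes model.forward(tokens) returns logits of shape (len(tokens), vocab),
    where row j scores the token that follows tokens[:j + 1]. EOS is the
    end-of-sequence token id, defined elsewhere.
    """
    rng = rng or np.random.default_rng()
    # Draft phase: sample candidates one at a time, conditioned on the prompt
    draft_tokens, draft_probs = [], []
    for _ in range(gamma):
        q = softmax(draft_model.forward(prompt + draft_tokens)[-1])
        next_token = int(rng.choice(len(q), p=q))
        draft_tokens.append(next_token)
        draft_probs.append(q)
        if next_token == EOS:
            break
    # Verify phase: one target pass scores every draft position in parallel
    target_probs = softmax(target_model.forward(prompt + draft_tokens))
    first = len(prompt) - 1  # row that predicts the first draft token
    accepted = []
    for i, draft_tok in enumerate(draft_tokens):
        p, q = target_probs[first + i], draft_probs[i]
        # Accept with probability min(1, p(tok) / q(tok))
        if rng.random() < min(1.0, p[draft_tok] / q[draft_tok]):
            accepted.append(draft_tok)
        else:
            # Reject: resample from the residual max(0, p - q), then stop
            residual = np.maximum(p - q, 0.0)
            accepted.append(int(rng.choice(len(p), p=residual / residual.sum())))
            return accepted
    # Every draft token passed: the same target pass supplies one bonus token
    if accepted[-1] != EOS:
        bonus = target_probs[first + len(draft_tokens)]
        accepted.append(int(rng.choice(len(bonus), p=bonus)))
    return accepted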
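The accept/reject rule is easiest to trust once you have seen it reproduce the target distribution. The sketch below is a toy check, not part of any real serving stack: the three-token distributions `p` and `q` and the helper names `speculative_sample` and `empirical_check` are made up for illustration. It proposes from the draft, accepts with probability min(1, p/q), and otherwise resamples from the renormalized residual max(0, p - q).

```python
import numpy as np

def speculative_sample(p, q, rng):
    """Draw one token proposed by draft q and corrected toward target p."""
    tok = int(rng.choice(len(q), p=q))        # draft proposes a token
    if rng.random() < min(1.0, p[tok] / q[tok]):
        return tok, True                      # proposal accepted unchanged
    residual = np.maximum(p - q, 0.0)         # mass the draft under-proposed
    residual = residual / residual.sum()
    return int(rng.choice(len(p), p=residual)), False

def empirical_check(p, q, n=50_000, seed=0):
    """Return (token frequencies, fraction of proposals accepted)."""
    rng = np.random.default_rng(seed)
    counts = np.zeros(len(p))
    accepted = 0
    for _ in range(n):
        tok, ok = speculative_sample(p, q, rng)
        counts[tok] += 1
        accepted += ok
    return counts / n, accepted / n

p = np.array([0.6, 0.3, 0.1])  # toy target distribution (illustrative)
q = np.array([0.3, 0.5, 0.2])  # toy draft distribution (illustrative)
freq, accept_rate = empirical_check(p, q)
print("token frequencies:", freq)
print("acceptance rate:", accept_rate)
```

The frequencies should land close to `p` even though every proposal came from `q`, which is why a weak draft model costs speed but not output quality.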
Applicability to Reasoning Models
Reasoning models have characteristics that both help and hurt speculative decoding:
- Helpful: Reasoning chains often have predictable patterns (decomposition steps, verification steps). A draft model can learn these patterns.
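- Challenging: Exploratory stretches of a long chain are hard to predict, so the draft diverges from the target more often. Because the first rejection discards every later draft token in that step, a lower acceptance rate makes a large gamma wasteful.
For reasoning workloads, tune gamma based on chain predictability. As a starting point, short, structured reasoning chains (math proofs) work well with gamma of 4 to 8, while longer, exploratory chains (creative problem-solving) may need gamma of 2 to 4. Measure the acceptance rate on your own workload before settling on a value.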
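One way to turn this tuning advice into numbers is a simple cost model. The sketch below assumes each draft token is accepted independently with probability `alpha`, and that a draft pass costs `cost_ratio` times a target pass; both are simplifications, and the function names are hypothetical. Under those assumptions a step yields (1 - alpha^(gamma+1)) / (1 - alpha) tokens on average and costs gamma * cost_ratio + 1 target-pass equivalents.

```python
def expected_tokens(alpha, gamma):
    """Expected tokens per step, assuming independent per-token acceptance."""
    if alpha >= 1.0:
        return gamma + 1
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

def expected_speedup(alpha, gamma, cost_ratio):
    """Speedup over plain decoding; cost_ratio = draft pass time / target pass time."""
    return expected_tokens(alpha, gamma) / (gamma * cost_ratio + 1)

def best_gamma(alpha, cost_ratio, max_gamma=16):
    """Draft length in 1..max_gamma with the highest modeled speedup."""
    return max(range(1, max_gamma + 1),
               key=lambda g: expected_speedup(alpha, g, cost_ratio))

# Compare a predictable chain with an exploratory one (illustrative rates)
for alpha in (0.85, 0.6):
    g = best_gamma(alpha, cost_ratio=0.05)
    print(alpha, g, round(expected_speedup(alpha, g, 0.05), 2))
```

The model ignores batching, memory limits, and acceptance rates that vary along the chain, so treat its output as a starting point for measurement rather than a recommendation.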
Draft Model Training
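The draft model must be small enough to be fast but accurate enough to achieve a high acceptance rate. It must also share the target's tokenizer so that the two models' token probabilities can be compared. Options:
- Quantized target model: Use an INT4 version of R1 as the draft. Roughly 93% acceptance rate is typical, though the draft is only as fast as quantization makes it, so check the cost ratio on your hardware.
- Smaller pretrained model: Train a 1-3B model specifically for draft generation.
- Thinner variant of the target: Train a draft that keeps the target's architecture but uses fewer layers and heads.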
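When comparing these options, it helps to estimate acceptance before wiring up a full decoding loop. For a single position, the probability that a token drawn from the draft distribution q passes the min(1, p/q) test is the overlap sum over tokens of min(p, q). The sketch below computes that overlap with NumPy; the function names and the toy arrays are illustrative, and it assumes you can dump per-position probability vectors from both models.

```python
import numpy as np

def acceptance_probability(p, q):
    """Chance that a token sampled from q is accepted against target p."""
    return float(np.minimum(p, q).sum())

def mean_acceptance(target_dists, draft_dists):
    """Average acceptance over positions; inputs have shape (positions, vocab)."""
    overlap = np.minimum(target_dists, draft_dists).sum(axis=-1)
    return float(overlap.mean())

# Toy example: two positions over a three-token vocabulary
target = np.array([[0.6, 0.3, 0.1],
                   [0.2, 0.7, 0.1]])
draft = np.array([[0.3, 0.5, 0.2],
                  [0.2, 0.7, 0.1]])
print(acceptance_probability(target[0], draft[0]))
print(mean_acceptance(target, draft))
```

A draft whose distribution matches the target exactly scores 1.0, which makes this a convenient yardstick for ranking candidate drafts on held-out reasoning chains.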
# Draft model selection criteria
draft_config = {
    "param_count": 3e9,                # 3B parameters for draft
    "layers": 24,                      # Reduced from R1's 61
    "heads": 32,                       # Reduced from R1's 128
    "quantization": "int4",
    "expected_acceptance_rate": 0.85,  # Goal for acceptance rate
}
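# Training: use target model outputs as training data
def train_draft(target_outputs, draft_model):
    """Train draft model on target model reasoning chains; return the mean loss"""
    losses = []
    for reasoning_chain in target_outputs:
        # train_step is assumed to run one optimizer update and return a scalar loss
        losses.append(draft_model.train_step(reasoning_chain))
    # Average over all chains instead of reporting only the last one
    return sum(losses) / len(losses) if losses else None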
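What a training step minimizes is left open above. One common choice, shown here purely as an assumption rather than as the method this chapter prescribes, is to distill on the target's full next-token distributions by minimizing KL(target || draft) at each position. The sketch uses NumPy to compute the loss value only; a real training loop would use an autodiff framework, and the names `log_softmax` and `distillation_loss` are illustrative.

```python
import numpy as np

def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def distillation_loss(target_probs, draft_logits, eps=1e-12):
    """Mean KL(target || draft) over positions; inputs are (positions, vocab)."""
    draft_logp = log_softmax(draft_logits)
    target_logp = np.log(target_probs + eps)  # eps guards against log(0)
    kl = (target_probs * (target_logp - draft_logp)).sum(axis=-1)
    return float(kl.mean())

# Toy check: a draft that matches the target has near-zero loss
target_probs = np.array([[0.6, 0.3, 0.1], [0.2, 0.7, 0.1]])
matching_logits = np.log(target_probs)
uniform_logits = np.zeros_like(target_probs)
print(distillation_loss(target_probs, matching_logits))
print(distillation_loss(target_probs, uniform_logits))
```

Pulling the draft's distribution toward the target's is what raises acceptance, since the accept test compares the two probabilities token by token.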
Performance Expectations
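Expected speedup depends on the acceptance rate, the draft length gamma, and the draft model's speed relative to the target:
# Speedup calculation
target_time_per_token = 100  # ms per target forward pass
draft_time_per_token = 5     # ms per draft forward pass (20x faster)
gamma = 4
# Every step drafts gamma tokens and runs one target pass, whatever gets accepted.
# Assumes verifying gamma + 1 positions costs about the same as one target pass.
step_time = draft_time_per_token * gamma + target_time_per_token  # 120 ms
# Best case: all draft tokens accepted, plus one token from the target pass itself
best_tokens = gamma + 1
speedup = best_tokens * target_time_per_token / step_time
# = 500 / 120 = 4.2x speedup
# Typical case: 85% per-token acceptance, treated as independent across tokens.
# The first rejection discards the rest of the draft, so the expected yield is
# (1 - a^(gamma + 1)) / (1 - a) tokens per step, not simply 85% of gamma.
acceptance = 0.85
expected_tokens = (1 - acceptance ** (gamma + 1)) / (1 - acceptance)  # about 3.7
speedup = expected_tokens * target_time_per_token / step_time
# = 371 / 120 = 3.1x speedup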
Implement speculative decoding for a reasoning task. Measure acceptance rate for different gamma values and estimate the speedup. Identify what causes rejections and whether you can improve the draft model.
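As a starting point for the measurement part of this exercise, the helper below estimates the per-token acceptance rate from a log of how many draft tokens were accepted at each step. It is a sketch under two assumptions: every step drafted exactly `gamma` tokens (steps cut short by EOS would need separate handling), and acceptance is roughly constant across positions. The function name and the sample log are made up.

```python
def estimate_acceptance(accepted_per_step, gamma):
    """Estimate per-token acceptance from accepted-draft counts per step.

    Each accepted token is one successful trial. A step that accepted fewer
    than gamma tokens also contains exactly one failed trial (the rejection).
    """
    accepted = sum(accepted_per_step)
    rejections = sum(1 for n in accepted_per_step if n < gamma)
    trials = accepted + rejections
    return accepted / trials if trials else float("nan")

# Hypothetical log: accepted draft tokens in five steps with gamma = 4
log = [4, 4, 2, 0, 3]
print(estimate_acceptance(log, gamma=4))  # 13 accepted out of 16 trials
```

Dividing accepted tokens by all drafted tokens instead would understate the rate, because tokens after the first rejection are never actually tested.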