07. Speculative Decoding
Standard autoregressive decoding generates one token at a time, executing the full model for each step. This sequential dependency limits parallelism: the entire 70B-parameter model runs to produce a single token. Speculative decoding relieves this bottleneck by having a small "draft" model propose several candidate tokens, which the larger "target" model then verifies in parallel.
The algorithm:
- Draft model generates k candidate tokens (typically 4-8)
- Target model evaluates all k candidates in a single forward pass
- Accepted tokens proceed; rejected tokens trigger resampling
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class SpeculativeDecoder:
    def __init__(self, target_model, draft_model, draft_tokens=4):
        self.target = target_model
        self.draft = draft_model
        self.draft_tokens = draft_tokens

    @torch.no_grad()
    def decode(self, input_ids, max_new_tokens):
        # Sketch for batch size 1; no KV cache, so every call re-encodes the prefix
        generated = input_ids.clone()
        device = generated.device
        max_len = input_ids.shape[1] + max_new_tokens
        while generated.shape[1] < max_len:
            n = generated.shape[1]
            k = min(self.draft_tokens, max_len - n)

            # Draft model generates k candidates autoregressively
            draft_input = generated
            draft_probs = []
            for _ in range(k):
                logits = self.draft(draft_input).logits[0, -1]
                q = torch.softmax(logits.float(), dim=-1).to(device)
                token = torch.multinomial(q, num_samples=1)
                draft_probs.append(q)
                draft_input = torch.cat([draft_input, token.view(1, 1)], dim=-1)

            # Target model evaluates all k candidates in a single forward pass
            target_logits = self.target(draft_input).logits[0, n - 1:]
            # Row i is the target distribution for the token at position n + i
            target_probs = torch.softmax(target_logits.float(), dim=-1).to(device)

            # Accept/reject tokens: accept candidate x with probability min(1, p(x) / q(x))
            accepted = 0
            for i in range(k):
                token = draft_input[0, n + i]
                ratio = target_probs[i, token] / draft_probs[i][token]
                if torch.rand(1).item() < ratio.item():
                    accepted += 1
                else:
                    break
            generated = draft_input[:, :n + accepted]

            if generated.shape[1] < max_len:
                if accepted < k:
                    # Rejected: resample from the residual max(0, p - q), renormalized
                    residual = torch.clamp(target_probs[accepted] - draft_probs[accepted], min=0)
                    next_probs = residual / residual.sum()
                else:
                    # All accepted: sample one extra token from the target distribution
                    next_probs = target_probs[k]
                new_token = torch.multinomial(next_probs, num_samples=1)
                generated = torch.cat([generated, new_token.view(1, 1)], dim=-1)
        return generated
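The accept/reject step is what keeps the output distributed as if the target model had sampled it alone. The following is a minimal sketch on a toy three-token vocabulary, using only the standard library. The distributions `p` and `q` and the helper names `speculative_sample` and `empirical_distribution` are made up for illustration. It applies the standard rule: accept a drafted token x with probability min(1, p(x)/q(x)), otherwise resample from the residual max(0, p - q).

```python
import random
from collections import Counter


def speculative_sample(p, q, rng):
    """Draw one token: propose from draft q, verify against target p.

    p and q are lists of probabilities over the same vocabulary.
    Returns (token, accepted).
    """
    vocab = range(len(p))
    x = rng.choices(vocab, weights=q)[0]        # draft proposal
    if rng.random() < min(1.0, p[x] / q[x]):    # accept with prob min(1, p/q)
        return x, True
    # Rejected: sample from the residual max(0, p - q).
    # random.choices normalizes the weights for us.
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    return rng.choices(vocab, weights=residual)[0], False


def empirical_distribution(p, q, n_samples=200_000, seed=0):
    """Estimate the output distribution and acceptance rate by simulation."""
    rng = random.Random(seed)
    counts = Counter()
    accepted = 0
    for _ in range(n_samples):
        token, ok = speculative_sample(p, q, rng)
        counts[token] += 1
        accepted += ok
    freqs = [counts[i] / n_samples for i in range(len(p))]
    return freqs, accepted / n_samples


p = [0.6, 0.3, 0.1]  # toy target distribution (illustrative values)
q = [0.3, 0.5, 0.2]  # toy draft distribution (illustrative values)
freqs, accept_rate = empirical_distribution(p, q)
print("empirical output distribution:", [round(f, 3) for f in freqs])
print("acceptance rate:", round(accept_rate, 3))
```

The empirical frequencies should land close to `p` even though every proposal came from `q`, and the acceptance rate should approach the overlap between the two distributions, sum(min(p, q)), which is 0.7 for these toy values.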
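Speedup depends on how often the target accepts the draft's tokens. A draft that matches the target distribution exactly has every candidate accepted, so the ideal speedup is roughly k×, before subtracting the cost of running the draft. With a per-token acceptance rate of 70% and 4-token speculation, and assuming each candidate is accepted independently, the expected yield is (1 - 0.7^5) / (1 - 0.7) ≈ 2.8 tokens per target pass, so the realized speedup is somewhat below 3×.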
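To put rough numbers on this, the sketch below uses a common simplified model, not a measurement. It assumes each drafted token is accepted independently with probability `alpha`, that verifying k candidates costs the target about as much as one ordinary decoding step, and that `cost_ratio` (draft step time divided by target step time) is known. The function names are illustrative.

```python
def expected_tokens_per_pass(alpha: float, k: int) -> float:
    """Expected tokens emitted per target forward pass.

    Assumes each drafted token is accepted independently with
    probability alpha, and that the target always contributes one
    token of its own (a resample after a rejection, or an extra
    token when all k candidates are accepted).
    """
    if alpha >= 1.0:
        return float(k + 1)
    return (1.0 - alpha ** (k + 1)) / (1.0 - alpha)


def estimated_speedup(alpha: float, k: int, cost_ratio: float) -> float:
    """Estimated speedup over plain autoregressive decoding.

    cost_ratio is (draft step time) / (target step time). One round
    costs k draft steps plus one target pass.
    """
    return expected_tokens_per_pass(alpha, k) / (k * cost_ratio + 1.0)


for k in (2, 4, 8):
    tokens = expected_tokens_per_pass(0.7, k)
    speedup = estimated_speedup(0.7, k, cost_ratio=0.1)
    print(f"k={k}: {tokens:.2f} tokens/pass, estimated speedup {speedup:.2f}x")
```

Under these assumptions a longer draft is not always better: every extra drafted token adds draft cost, while the chance that it survives verification shrinks geometrically.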
Draft model selection criteria:
- Smaller than target: 1-7B parameters
- Same model family as target: a shared tokenizer and similar training data raise the acceptance rate
- High baseline capability: bad drafts reduce acceptance rate
- Fast sampling: speculative overhead must not exceed parallelization gains
For Llama-family targets, Llama-2-7B works well as a draft for Llama-2-70B. For coding models, CodeLlama-7B-Python can draft for CodeLlama-70B.
# Example configuration for 70B + 7B speculative decoding
decoder = SpeculativeDecoder(
    target_model=AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-70b-hf",
        torch_dtype=torch.float16,
        device_map="auto"
    ),
    draft_model=AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        torch_dtype=torch.float16,  # half precision, matching the target
        device_map="auto"
    ),
    draft_tokens=4
)
Failure mode: if the draft and target models use different tokenizers, the same token ID maps to different text in each model and the acceptance rate collapses. Use a draft model that shares the target's tokenizer.
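# Verify tokenizer compatibility
target_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-hf")
draft_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
assert decoder.target.config.vocab_size == decoder.draft.config.vocab_size
assert target_tokenizer.get_vocab() == draft_tokenizer.get_vocab()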
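A bare assertion only reports that the vocabularies differ, not where. As a possible aid for debugging, the sketch below compares two token-to-ID mappings of the kind returned by `tokenizer.get_vocab()` and lists a few offending entries. The helper name `vocab_mismatches` and the toy vocabularies are hypothetical.

```python
def vocab_mismatches(target_vocab: dict, draft_vocab: dict, limit: int = 5) -> dict:
    """Summarize differences between two token -> id mappings.

    Returns a report with a 'compatible' flag and up to `limit`
    example tokens per kind of mismatch.
    """
    missing_in_draft = [t for t in target_vocab if t not in draft_vocab]
    missing_in_target = [t for t in draft_vocab if t not in target_vocab]
    different_id = [
        t for t, idx in target_vocab.items()
        if t in draft_vocab and draft_vocab[t] != idx
    ]
    return {
        "compatible": not (missing_in_draft or missing_in_target or different_id),
        "missing_in_draft": missing_in_draft[:limit],
        "missing_in_target": missing_in_target[:limit],
        "different_id": different_id[:limit],
    }


# Toy vocabularies standing in for tokenizer.get_vocab() output
target_vocab = {"<s>": 0, "hello": 1, "world": 2}
draft_vocab = {"<s>": 0, "hello": 2, "there": 1}
report = vocab_mismatches(target_vocab, draft_vocab)
print(report)
```

With real tokenizers the call would look like `vocab_mismatches(target_tokenizer.get_vocab(), draft_tokenizer.get_vocab())`.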
Implement speculative decoding with your target and draft models. Vary draft_tokens from 2 to 8. Plot acceptance rate and tokens-per-second for each configuration.
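One way to organize the measurements is sketched below. It assumes a hypothetical callable `run_fn(k)` that runs one generation with `draft_tokens=k` and returns the counts `(accepted, drafted, new_tokens)`. The `SpeculativeDecoder` class above does not track these counts, so you would need to add that bookkeeping yourself. `fake_run` is a stand-in used only to show the call shape.

```python
import time


def sweep_draft_tokens(run_fn, k_values=range(2, 9)):
    """Collect acceptance rate and throughput for each draft length.

    run_fn(k) is a hypothetical callable returning
    (accepted, drafted, new_tokens) for one generation run.
    """
    results = []
    for k in k_values:
        start = time.perf_counter()
        accepted, drafted, new_tokens = run_fn(k)
        elapsed = time.perf_counter() - start
        results.append({
            "draft_tokens": k,
            "acceptance_rate": accepted / drafted if drafted else 0.0,
            "tokens_per_second": new_tokens / elapsed if elapsed > 0 else float("inf"),
        })
    return results


def fake_run(k):
    # Stand-in for a real decoding run: pretends 70% of drafts were accepted
    return 7 * k, 10 * k, 100


results = sweep_draft_tokens(fake_run)
for row in results:
    print(row)
```

When timing real GPU runs, call `torch.cuda.synchronize()` before reading the clock so that queued kernels are included in the measurement. The resulting list of rows can be passed to any plotting library.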