Speculative Decoding for Fast Generation
Speculative decoding speeds up the token-generation (decode) phase by having a smaller, faster draft model propose several future tokens, which the main model then verifies in a single forward pass. Instead of producing exactly one token per main-model forward pass, the main model can confirm several tokens per pass, which reduces decode latency by roughly 30-50% depending on how often the draft tokens are accepted. The technique is reported to be used in production by large providers such as OpenAI and Google (Gemini), but it requires careful tuning and is only worth implementing when decode latency is your bottleneck.
The Decode Bottleneck
Recall from earlier articles: during decode, the model generates one new token per forward pass. Each forward pass requires loading the model weights and computing attention. On a 7B-parameter model on an A100 GPU, this takes 8-12 milliseconds per token. For a 1000-token response, decode takes 8-12 seconds alone (longer than prefill). This is the decode latency bottleneck.
The insight: decode is memory-bound, not compute-bound. The model cannot fully utilize GPU compute while waiting for weights to be loaded from VRAM. If you could predict the next 4-8 tokens speculatively and validate them in one forward pass, you would:
- Amortize weight-loading cost across multiple tokens.
- Batch compute across multiple token predictions.
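- Reduce per-token latency: in the best case several tokens are emitted per main-model pass, although realized gains are usually closer to the 30-50% range once draft cost and rejected tokens are counted.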
Speculative decoding does exactly this.
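A back-of-the-envelope calculation makes the memory-bound claim concrete. The sketch below assumes FP16 weights (2 bytes per parameter) and a peak memory bandwidth of about 1,555 GB/s, the published figure for the 40 GB A100; neither number appears in the text above, and the function name is illustrative. It estimates only the time to read every weight once, so it is a lower bound, not a prediction.

```python
def min_decode_ms_per_token(params_billion: float,
                            bytes_per_param: float = 2.0,
                            bandwidth_gb_s: float = 1555.0) -> float:
    """Lower bound on per-token decode latency if every weight is read once per pass.

    Assumptions (not from the article): FP16 weights and A100-40GB peak bandwidth.
    Ignores KV-cache reads, kernel launch overhead, and compute time.
    """
    weight_gb = params_billion * bytes_per_param  # 1e9 params * bytes/param = GB
    return weight_gb / bandwidth_gb_s * 1000.0    # seconds -> milliseconds


if __name__ == "__main__":
    print(f"7B model: at least {min_decode_ms_per_token(7):.1f} ms per token")
```

For a 7B model this bound comes out at about 9 ms, which is consistent with the 8-12 ms per token quoted earlier and shows why emitting several tokens per weight read is attractive.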
How Speculative Decoding Works
The algorithm has two phases: speculation and validation.
Speculation Phase
A smaller, faster model (the "draft model") generates K candidate next tokens (e.g., 4-8 tokens ahead). The draft model is much smaller (1B-3B) than the main model (7B-70B) and runs much faster (~2ms per token vs. 10ms). The draft model does not need to be high-quality; it just needs to predict plausible continuations.
Time 0ms: Draft model predicts tokens [T1, T2, T3, T4]
Time 8ms: Draft model finishes (4 tokens * 2ms each)
Validation Phase
The main model receives the draft tokens and evaluates whether each one was correct. It does this by:
- Processing the original context + all K draft tokens in a single forward pass.
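- Comparing the main model's logits (probability distribution) at each position against the token the draft model proposed there.
- Accepting draft tokens, left to right, for as long as the main model's top-1 prediction matches the draft token. This exact-match rule applies to greedy decoding; sampled decoding uses a probabilistic acceptance test instead.
- At the first mismatch, emitting the main model's own token and discarding every later draft token, since those were conditioned on the rejected one.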
Draft predictions: [T1, T2, T3, T4]
Main model output: [M1, M2, M3, M4] (with confidences)
Comparison:
Position 1: Draft=T1, Main=M1 → Match? Yes, accept T1
Position 2: Draft=T2, Main=M2 → Match? Yes, accept T2
Position 3: Draft=T3, Main=M3 → Match? No, use M3
Position 4: Draft=T4 → discarded (it was conditioned on the rejected T3; respeculate from M3)
Output: [T1, T2, M3]
Next iteration: Draft speculates from [T1, T2, M3]...
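The key advantage: even if only 2 of the 4 draft tokens are accepted, as in the trace above, the iteration still yields 3 tokens (2 accepted plus the main model's correction) for the cost of one main-model forward pass plus the cheap draft passes, instead of 3 separate main-model passes.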
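The accept/reject rule from the trace can be sketched in a few lines. This is a minimal illustration for greedy decoding only, assuming the main model's logits for the K draft positions are already available as a NumPy array; the function name and array layout are illustrative and not any library's API.

```python
import numpy as np


def verify_greedy(draft_tokens, main_logits):
    """Return the tokens to emit for one speculation round (greedy decoding).

    draft_tokens: list of K token ids proposed by the draft model.
    main_logits:  array of shape (K, vocab_size); row i holds the main model's
                  logits for the position that draft token i occupies.
    """
    main_choice = np.argmax(main_logits, axis=-1)  # main model's top-1 per position
    output = []
    for draft_t, main_t in zip(draft_tokens, main_choice):
        if draft_t == int(main_t):
            output.append(draft_t)       # match: accept the draft token
        else:
            output.append(int(main_t))   # mismatch: take the main model's token
            break                        # later draft tokens are discarded
    return output


if __name__ == "__main__":
    # Toy vocabulary of 5 tokens; the main model prefers [1, 2, 0, 4].
    logits = np.full((4, 5), -1.0)
    for pos, tok in enumerate([1, 2, 0, 4]):
        logits[pos, tok] = 1.0
    print(verify_greedy([1, 2, 3, 4], logits))
```

In the toy run the first two draft tokens match, the third does not, and the fourth is never examined, which mirrors the [T1, T2, M3] outcome in the trace.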
Implementation: Draft and Main Model Setup
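Below is a simplified example built on vLLM's offline LLM class. It illustrates the draft, accept, and reject control flow only: it runs the two models as separate engines and lets the main model decode the drafted positions itself rather than scoring them in one forward pass, so it will not reproduce the speedup. Production engines, vLLM included, ship built-in speculative decoding, which should be preferred over a hand-rolled loop.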
    import time

    from vllm import LLM, SamplingParams


    class SpeculativeDecoder:
        """
        Illustrates the accept/reject logic of speculative decoding.
        Wraps a main model and a smaller draft model that must share a tokenizer,
        because token ids from the two models are compared directly.
        """

        def __init__(self, main_model_name: str, draft_model_name: str):
            self.main_llm = LLM(model=main_model_name, gpu_memory_utilization=0.6)
            self.draft_llm = LLM(model=draft_model_name, gpu_memory_utilization=0.3)
            self.tokenizer = self.main_llm.get_tokenizer()

        def speculative_generate(self, prompt: str, max_tokens: int = 256,
                                 num_draft_tokens: int = 4):
            """Generate tokens using draft speculation and main-model validation."""
            generated_tokens = []
            current_prompt = prompt

            while len(generated_tokens) < max_tokens:
                remaining = max_tokens - len(generated_tokens)

                # Phase 1: Speculation.
                # Greedy (temperature 0) so drafts can be compared token-for-token.
                t0 = time.perf_counter()
                draft_params = SamplingParams(
                    max_tokens=min(num_draft_tokens, remaining),
                    temperature=0.0,
                )
                draft_output = self.draft_llm.generate(
                    prompts=[current_prompt],
                    sampling_params=draft_params,
                )
                draft_tokens = list(draft_output[0].outputs[0].token_ids)
                draft_time_ms = (time.perf_counter() - t0) * 1000
                if not draft_tokens:
                    break

                # Phase 2: Validation by the main model.
                # A real implementation scores context + draft tokens in ONE
                # forward pass and reads the logits at each draft position.
                # Simplified: the main model greedily decodes the same positions
                # from the same context, which gives the same accept/reject
                # decisions but costs one pass per token.
                t0 = time.perf_counter()
                main_params = SamplingParams(
                    max_tokens=len(draft_tokens),
                    temperature=0.0,
                )
                main_output = self.main_llm.generate(
                    prompts=[current_prompt],
                    sampling_params=main_params,
                )
                main_tokens = list(main_output[0].outputs[0].token_ids)
                main_time_ms = (time.perf_counter() - t0) * 1000

                # Accept the matching prefix; at the first mismatch, keep the
                # main model's token and stop.
                accepted = 0
                new_tokens = []
                for draft_t, main_t in zip(draft_tokens, main_tokens):
                    if draft_t == main_t:
                        new_tokens.append(draft_t)
                        accepted += 1
                    else:
                        new_tokens.append(main_t)
                        break  # Stop accepting; respeculate from here
                if not new_tokens:
                    break
                generated_tokens.extend(new_tokens)

                print(f"Draft: {draft_time_ms:.1f}ms, "
                      f"Validate: {main_time_ms:.1f}ms, "
                      f"Accepted: {accepted}/{len(draft_tokens)} tokens")

                if new_tokens[-1] == self.tokenizer.eos_token_id:
                    break

                # Rebuild the prompt from everything emitted so far, including a
                # corrected token. Re-tokenizing text can shift token boundaries;
                # a real implementation passes token ids instead.
                current_prompt = prompt + self.tokenizer.decode(generated_tokens)

            return generated_tokens


    # Example usage
    decoder = SpeculativeDecoder(
        main_model_name="meta-llama/Llama-2-7b-hf",
        # Llama 2 has no 3B checkpoint; TinyLlama 1.1B is used here because it
        # shares the Llama 2 tokenizer.
        draft_model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    )
    prompt = "Explain machine learning in 200 words."
    tokens = decoder.speculative_generate(prompt, max_tokens=200, num_draft_tokens=4)
    print(f"Generated {len(tokens)} tokens.")
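Illustrative output for a 7B main model and a smaller draft model, using idealized single-pass validation timings (the simplified wrapper above reports slower validation, because it decodes the drafted positions instead of scoring them in one pass):
Draft: 6.0ms, Validate: 12.0ms, Accepted: 3/4 tokens
Draft: 6.0ms, Validate: 12.0ms, Accepted: 2/4 tokens
Draft: 6.0ms, Validate: 12.0ms, Accepted: 4/4 tokens
...
Generated 200 tokens.
At roughly 18 ms per iteration and about 3 tokens per iteration, 200 tokens take on the order of 1.2 seconds, versus about 2.4 seconds at 12 ms per token without speculation.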
Speculative Decoding Trade-offs
Speculative decoding is powerful but comes with challenges:
| Aspect | Pro | Con |
|---|---|---|
| Latency reduction | 30-50% decode speedup | Complex implementation |
| Memory cost | Draft model is small (1-3B) | 2 models in VRAM simultaneously |
| Accuracy | Validation is by main model; accuracy unchanged | Draft quality affects acceptance rate |
| Speculation success | Works best when draft model is a smaller version of main | Fails if draft is too weak (low acceptance rate) |
Speculative decoding is only worth implementing if:
- Decode is your bottleneck (not prefill, and not VRAM capacity). Profile first.
- You can afford two models in VRAM (main + draft). A 7B + 3B pair needs about 20 GB for FP16 weights alone; budget roughly 30 GB once KV cache and activations are included.
- The draft model is a scaled-down version of the main model. Using an unrelated, weaker model gives poor acceptance rates.
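The VRAM condition can be checked with simple arithmetic before any model is loaded. The sketch below assumes FP16 weights (2 bytes per parameter) and a flat 30% allowance on top of the weights for KV cache and activations; that allowance is a placeholder assumption, not a measured figure, and the function names are illustrative.

```python
def weights_gb(params_billion: float, bytes_per_param: float = 2.0) -> float:
    """Weight memory in GB (1e9 params * bytes per param = GB)."""
    return params_billion * bytes_per_param


def fits_in_vram(main_billion: float, draft_billion: float, vram_gb: float,
                 overhead_fraction: float = 0.3):
    """Return (estimated GB needed, whether it fits).

    overhead_fraction is an assumed allowance for KV cache and activations;
    real overhead depends on batch size and context length.
    """
    needed = (weights_gb(main_billion) + weights_gb(draft_billion)) * (1 + overhead_fraction)
    return needed, needed <= vram_gb


if __name__ == "__main__":
    for vram in (24, 40, 80):
        needed, ok = fits_in_vram(7, 3, vram)
        print(f"{vram} GB GPU: need ~{needed:.0f} GB -> {'fits' if ok else 'does not fit'}")
```

Under these assumptions a 7B + 3B pair does not fit on a 24 GB card but does fit on a 40 GB one; quantizing either model changes bytes_per_param and therefore the answer.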
When to Use Speculative Decoding
- Long-form generation (blog posts, summaries, code): decode dominates, large speedup possible.
- Latency-critical applications (streaming chat, real-time): every 10ms reduction matters.
- Cost-sensitive workloads at scale: less GPU time per request means fewer total GPU hours and lower cost.
Do NOT use for:
- Short responses (< 100 tokens): prefill and overhead dominate.
- Memory-constrained setups: two models will OOM.
- Interactive single-request inference on edge devices: draft model overhead not worth it.
Measuring Speculation Effectiveness
Track these metrics to validate speculative decoding is helping:
    import time

    from vllm import SamplingParams


    def measure_speculation_impact(decoder, prompt: str, max_tokens: int,
                                   num_draft_tokens: int = 4) -> float:
        """
        Measure generation latency without vs. with speculation.
        `decoder` is a SpeculativeDecoder instance as defined earlier.
        Both timings include prefill, so use a long max_tokens to isolate decode.
        """
        # Without speculation: the main model decodes on its own
        t0 = time.perf_counter()
        decoder.main_llm.generate(
            [prompt],
            SamplingParams(max_tokens=max_tokens, temperature=0.0)
        )
        time_baseline = (time.perf_counter() - t0) * 1000

        # With speculation
        t0 = time.perf_counter()
        decoder.speculative_generate(
            prompt,
            max_tokens=max_tokens,
            num_draft_tokens=num_draft_tokens
        )
        time_speculative = (time.perf_counter() - t0) * 1000

        speedup = time_baseline / time_speculative
        print(f"Baseline: {time_baseline:.0f}ms")
        print(f"Speculative: {time_speculative:.0f}ms")
        print(f"Speedup: {speedup:.2f}x")
        return speedup


    def acceptance_rate(total_accepted: int, total_draft: int) -> float:
        """
        Fraction of drafted tokens that were accepted.
        Both counters must be accumulated across every iteration of the
        speculative loop; they are not computed here.
        """
        return total_accepted / total_draft if total_draft > 0 else 0.0
Target: 70%+ acceptance rate for 30%+ speedup.
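To compare against that target, the accepted and drafted counts have to be accumulated once per speculation round. The sketch below is one hypothetical way to do it; the class and method names are illustrative and not part of vLLM or any other library.

```python
from dataclasses import dataclass


@dataclass
class SpeculationStats:
    """Running totals across speculation rounds (illustrative helper)."""
    total_draft: int = 0
    total_accepted: int = 0
    iterations: int = 0

    def record(self, accepted: int, drafted: int) -> None:
        """Call once per round with the accepted and drafted token counts."""
        self.total_accepted += accepted
        self.total_draft += drafted
        self.iterations += 1

    @property
    def acceptance_rate(self) -> float:
        return self.total_accepted / self.total_draft if self.total_draft else 0.0

    @property
    def mean_accepted_per_round(self) -> float:
        return self.total_accepted / self.iterations if self.iterations else 0.0


if __name__ == "__main__":
    stats = SpeculationStats()
    for accepted, drafted in [(3, 4), (2, 4), (4, 4)]:
        stats.record(accepted, drafted)
    print(f"Acceptance rate: {stats.acceptance_rate * 100:.1f}%")
```

Fed the three rounds 3/4, 2/4, and 4/4, the tracker reports 9 of 12 tokens accepted, a 75% acceptance rate.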
Key Takeaways
- Speculative decoding parallelizes decode using a smaller draft model to predict multiple tokens ahead.
- 30-50% latency reduction for long-form generation (main bottleneck is decode).
- Requires two models in VRAM: draft (1-3B) + main (7B+); complex infrastructure.
- Acceptance rate critical: 70%+ acceptance required for positive ROI.
- Profile decode latency first: only worthwhile if decode time >> prefill time.
Frequently Asked Questions
What is the best draft model for a given main model?
Ideally, a smaller model from the same family that shares the main model's tokenizer (e.g., 7B main → 3B draft, where such a variant exists). Some model families ship several sizes, which makes this easy. For proprietary models without a smaller variant, either distill one (expensive) or use a general-purpose small model (a 3B-class model, for example) and accept a lower acceptance rate.
Can speculative decoding degrade output quality?
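No, provided the acceptance rule is implemented correctly, because the main model validates every token. With greedy decoding and exact-match acceptance, the output is the same sequence the main model would have produced on its own. Draft predictions are only a hint; the main model always determines the final output. Speed improves; quality is unchanged.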
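The guarantee is easy to see for greedy decoding. For sampled decoding (temperature above zero), the usual way to keep it is a rejection-sampling acceptance test: accept draft token x with probability min(1, p(x)/q(x)), where p is the main model's distribution and q the draft's, and on rejection resample from the normalized positive part of p - q. The sketch below shows that rule for a single position with NumPy; the function name is illustrative, and it assumes both distributions are given as full probability vectors.

```python
import numpy as np


def accept_or_resample(draft_token: int, p_main: np.ndarray,
                       q_draft: np.ndarray, rng: np.random.Generator):
    """Return (token, accepted) for one position.

    p_main, q_draft: probability vectors over the same vocabulary.
    draft_token must have been sampled from q_draft, so q_draft[draft_token] > 0.
    """
    ratio = p_main[draft_token] / q_draft[draft_token]
    if rng.random() < min(1.0, ratio):
        return draft_token, True
    # Rejected: sample from the residual distribution max(p - q, 0), renormalized.
    residual = np.maximum(p_main - q_draft, 0.0)
    residual = residual / residual.sum()
    return int(rng.choice(len(residual), p=residual)), False


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    p = np.array([0.5, 0.3, 0.2])   # main model (toy numbers)
    q = np.array([0.2, 0.2, 0.6])   # draft model (toy numbers)
    counts = np.zeros(3)
    for _ in range(20000):
        drafted = int(rng.choice(3, p=q))
        token, _ = accept_or_resample(drafted, p, q, rng)
        counts[token] += 1
    print(counts / counts.sum())  # empirical frequencies should approximate p
```

The emitted tokens follow p regardless of how poor q is; a weak draft lowers only the acceptance rate, and with it the speedup.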
What draft/main model size ratio is optimal?
As a starting point, a 1:2 to 1:3 ratio (e.g., 3B draft for 7B main) can give good acceptance rates (~70-80%) while keeping draft generation fast, though the best ratio depends on the model family and workload. Smaller draft = faster but lower acceptance; larger draft = slower but higher acceptance. Tune for your latency target.
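The trade-off between draft speed and acceptance rate can be estimated with a simple analytical model before running any benchmark. The sketch below assumes each draft token is accepted independently with probability alpha, that the main model always contributes one token per round (a correction or a bonus token), and that verifying K tokens costs about the same as one ordinary decode step. These are modeling assumptions, and the function names are illustrative.

```python
def expected_tokens_per_round(alpha: float, k: int) -> float:
    """Expected tokens emitted per main-model pass with k draft tokens.

    Sum of alpha**i for i in 0..k: accepted drafts plus the one token the
    main model supplies itself.
    """
    if alpha >= 1.0:
        return float(k + 1)
    return (1.0 - alpha ** (k + 1)) / (1.0 - alpha)


def estimated_speedup(alpha: float, k: int, draft_cost_ratio: float) -> float:
    """Speedup over plain decoding.

    draft_cost_ratio: draft time per token divided by main time per token.
    One round costs k draft steps plus one main-model pass.
    """
    round_cost = k * draft_cost_ratio + 1.0
    return expected_tokens_per_round(alpha, k) / round_cost


if __name__ == "__main__":
    # Hypothetical settings: (label, acceptance rate, draft/main cost ratio)
    for label, alpha, ratio in [("small draft", 0.6, 0.1),
                                ("mid draft", 0.7, 0.2),
                                ("large draft", 0.8, 0.4)]:
        print(f"{label}: {estimated_speedup(alpha, 4, ratio):.2f}x")
```

With the figures used earlier in this article (2 ms draft vs. 10 ms main, so a ratio of 0.2, K = 4, and 70% acceptance), the model gives about 1.5x, which sits inside the 30-50% latency-reduction range quoted above. A larger draft with higher acceptance can still come out slower overall if its per-token cost rises faster than its acceptance rate.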
Does speculative decoding work with KV caching?
Yes. Both the draft and the main model use KV caching, and each keeps its own cache. The main model cannot reuse the draft model's cache: during validation it computes its own keys and values for the draft positions, and the entries for rejected positions must be discarded before the next round.