Speculative Decoding Explained: The Key to Accelerating LLM Inference
How the Draft-Verify paradigm speeds up large-model inference by 2-4x: engineering practice from Speculative Decoding to Medusa Heads
Introduction
Inference latency is the core bottleneck limiting large-scale deployment of large language models. Traditional autoregressive decoding produces only one token per step, and GPU compute utilization is often below 10%: single-token generation is a memory-bound operation, so most of the GPU's compute sits idle. Speculative decoding uses a "guess first, verify later" paradigm to deliver a 2-4x inference speedup without sacrificing output quality.
Core Principle: The Draft-Verify Paradigm
Why Autoregressive Decoding Is Slow
The fundamental problem with autoregressive decoding is that every token depends on the previous one, forming a strictly serial dependency chain. Even if a single forward pass takes only a few milliseconds, generating a 500-token passage still requires 500 sequential forward passes.
Traditional autoregressive decoding (sequential)
Timeline: ─────────────────────────────────>
Step 1: [Forward Pass] → token_1
Step 2: [Forward Pass] → token_2
Step 3: [Forward Pass] → token_3
...
Step N: [Forward Pass] → token_N
Total time = N × T_forward
GPU utilization: ~5-15% (memory-bound)
The Core Idea of Speculative Decoding
Speculative decoding takes its inspiration from CPU branch prediction: a small, fast model (the Draft Model) first guesses the next K tokens, and the large model (the Target Model) then verifies all K tokens at once. Verification is parallel, so a single forward pass of the large model can process multiple tokens.
Speculative decoding flow (Draft-Verify)
Timeline: ─────────────────────────────────>
Phase 1: DRAFT (fast)
The small model generates K candidate tokens: [d1, d2, d3, d4, d5]
Time: K × T_draft (T_draft << T_target)
Phase 2: VERIFY (parallel)
The large model verifies all candidates in one forward pass:
[d1:Accept, d2:Accept, d3:Accept, d4:Reject, d5:Skip]
Time: 1 × T_target
Phase 3: ACCEPT/REJECT
Accepted: d1, d2, d3 (3 tokens)
Resample a corrected token at position d4: c4
Total output: [d1, d2, d3, c4] (4 tokens from 1 large-model call)
Speedup ≈ (accepted + 1) / (1 + K × T_draft/T_target)
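To make the three phases concrete, here is a minimal sketch of one Draft-Verify iteration, assuming HuggingFace-style draft_model and target_model objects whose outputs expose .logits; the function name, shapes, and sampling details are illustrative, not a production implementation. Phase 3 is handled by the speculative_sampling function in the next subsection.
import torch

@torch.no_grad()
def draft_then_verify(draft_model, target_model, input_ids, k: int = 5):
    # Phase 1 (DRAFT): the small model proposes K tokens autoregressively.
    draft_ids = input_ids
    draft_probs = []
    for _ in range(k):
        logits = draft_model(draft_ids).logits[:, -1, :]
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, 1)
        draft_probs.append(probs)
        draft_ids = torch.cat([draft_ids, next_token], dim=-1)
    # Phase 2 (VERIFY): the large model scores all K drafts in ONE forward pass.
    target_logits = target_model(draft_ids).logits
    start = input_ids.size(1) - 1  # position that predicts the first drafted token
    target_probs = torch.softmax(target_logits[:, start:start + k, :], dim=-1)
    # Phase 3 (ACCEPT/REJECT) is applied to these tensors (see next subsection).
    return torch.stack(draft_probs, dim=1), target_probs, draft_ids[:, -k:]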
Mathematical Guarantee: Lossless Output Quality
The key theoretical guarantee of speculative decoding is that, with the right acceptance-rejection sampling scheme, the distribution of the final output is exactly the same as decoding directly with the large model.
import torch


def speculative_sampling(
    draft_probs: torch.Tensor,   # Shape: [K, vocab_size]
    target_probs: torch.Tensor,  # Shape: [K, vocab_size]
    draft_tokens: torch.Tensor,  # Shape: [K]
) -> tuple[torch.Tensor, int]:
    """
    Speculative sampling with a mathematical guarantee: the output
    distribution matches sampling directly from the target model.
    Returns the accepted tokens and how many were accepted.
    (The extra "bonus" token sampled when all K drafts are accepted
    is omitted here for brevity.)
    """
    accepted_tokens = []
    for i in range(len(draft_tokens)):
        token = draft_tokens[i].item()
        # Acceptance probability: min(1, target_prob / draft_prob)
        p_target = target_probs[i, token]
        p_draft = draft_probs[i, token]
        acceptance_ratio = p_target / torch.clamp(p_draft, min=1e-10)
        # Accept with probability min(1, acceptance_ratio)
        r = torch.rand(1)
        if r < acceptance_ratio:
            accepted_tokens.append(token)
        else:
            # Reject: resample from the adjusted (residual) distribution
            # p_adjusted = max(0, p_target - p_draft) / sum(max(0, p_target - p_draft))
            adjusted = torch.clamp(target_probs[i] - draft_probs[i], min=0)
            adjusted = adjusted / adjusted.sum()
            corrected_token = torch.multinomial(adjusted, 1)
            accepted_tokens.append(corrected_token.item())
            break  # Stop at the first rejection
    return torch.tensor(accepted_tokens), len(accepted_tokens)
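A usage sketch wiring this function together with the draft_then_verify sketch above (the [0] index assumes batch size 1; draft_model, target_model, and input_ids are the hypothetical objects from that sketch):
draft_probs, target_probs, draft_tokens = draft_then_verify(
    draft_model, target_model, input_ids, k=5)
accepted, num_accepted = speculative_sampling(
    draft_probs[0], target_probs[0], draft_tokens[0])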
Medusa Heads: Speculative Decoding Without a Draft Model
Architecture
Medusa's core innovation is to attach multiple prediction heads (Medusa Heads) directly to the large model, each predicting a token at a different future position, which removes the need for a separate Draft Model.
Medusa architecture
Input Tokens: [t1, t2, ..., tn]
│
▼
┌─────────────────────┐
│ Backbone LLM │
│ (Frozen Weights) │
└─────────┬───────────┘
│ Hidden States
├──────────────────────────┐
│ │
▼ ▼
┌──────────┐ ┌──────────┐ ┌──────────┐
│ Original │ │ Medusa │ │ Medusa │
│ LM Head │ │ Head 1 │ │ Head 2 │
│ (pos +1) │ │ (pos +2) │ │ (pos +3) │
└────┬─────┘ └────┬─────┘ └────┬─────┘
│ │ │
▼ ▼ ▼
token(n+1) token(n+2) token(n+3)
Each Medusa Head: 1-2 MLP layers + LayerNorm
Training: only the Medusa Heads are trained; the backbone stays frozen
Training the Medusa Heads
import torch
import torch.nn as nn


class MedusaHead(nn.Module):
    """Single Medusa prediction head."""

    def __init__(self, hidden_size: int, vocab_size: int, num_layers: int = 1):
        super().__init__()
        layers = []
        for _ in range(num_layers):
            layers.append(nn.Linear(hidden_size, hidden_size))
            layers.append(nn.SiLU())
        layers.append(nn.Linear(hidden_size, vocab_size))
        self.mlp = nn.Sequential(*layers)
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Predict tokens at a future position."""
        x = self.layer_norm(hidden_states)
        return self.mlp(x)  # [batch, seq_len, vocab_size]


class MedusaModel(nn.Module):
    """LLM with multiple Medusa heads for speculative decoding."""

    def __init__(self, backbone, num_heads: int = 3, hidden_size: int = 4096,
                 vocab_size: int = 32000):
        super().__init__()
        self.backbone = backbone
        # Freeze the backbone; only the Medusa heads are trained
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.medusa_heads = nn.ModuleList([
            MedusaHead(hidden_size, vocab_size)
            for _ in range(num_heads)
        ])

    def forward(self, input_ids: torch.Tensor):
        # Hidden states from the frozen backbone (no gradients needed)
        with torch.no_grad():
            outputs = self.backbone(input_ids, output_hidden_states=True)
            hidden = outputs.hidden_states[-1]
        # Original next-token prediction
        original_logits = outputs.logits
        # Medusa heads predict tokens at further future positions
        medusa_logits = [head(hidden) for head in self.medusa_heads]
        return original_logits, medusa_logits
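The heading above promises training, so here is a minimal training-loop sketch under stated assumptions: model is a MedusaModel instance as defined above, dataloader is a hypothetical iterator yielding input_ids of shape [batch, seq_len], and medusa_training_loss is an illustrative helper. Head k (0-indexed) learns to predict the token k+2 positions ahead, since the original LM head already covers the +1 offset.
import torch
import torch.nn.functional as F

def medusa_training_loss(medusa_logits, input_ids):
    """Cross-entropy for each head against its shifted targets."""
    losses = []
    for k, logits in enumerate(medusa_logits):
        offset = k + 2                      # head k predicts position t + k + 2
        if input_ids.size(1) <= offset:
            continue
        preds = logits[:, :-offset, :]      # positions that have a future label
        labels = input_ids[:, offset:]      # the future tokens themselves
        losses.append(F.cross_entropy(
            preds.reshape(-1, preds.size(-1)), labels.reshape(-1)))
    return torch.stack(losses).mean()

optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3)
for input_ids in dataloader:                # hypothetical dataloader
    _, medusa_logits = model(input_ids)
    loss = medusa_training_loss(medusa_logits, input_ids)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()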
Tree Attention Verification
Medusa uses tree attention to verify multiple candidate sequences efficiently. Unlike linear verification, the tree structure lets a single forward pass verify exponentially many candidate paths.
Tree Attention example (2 Medusa Heads, top-2 candidates)
Level 0 (Original): token_A
/ \
Level 1 (Head 1): token_B token_C
/ \ / \
Level 2 (Head 2): D E F G
One forward pass verifies 4 paths:
Path 1: A → B → D
Path 2: A → B → E
Path 3: A → C → F
Path 4: A → C → G
Tree Attention Mask:
A B C D E F G
A [ 1 0 0 0 0 0 0 ]
B [ 1 1 0 0 0 0 0 ]
C [ 1 0 1 0 0 0 0 ]
D [ 1 1 0 1 0 0 0 ]
E [ 1 1 0 0 1 0 0 ]
F [ 1 0 1 0 0 1 0 ]
G [ 1 0 1 0 0 0 1 ]
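The mask above can be generated mechanically from the tree's parent pointers. A minimal sketch follows (node indices 0-6 correspond to A-G in the diagram; build_tree_mask is an illustrative helper, not Medusa's actual implementation): each node attends to itself and all of its ancestors, and to nothing else.
import torch

# parent[i] is the index of node i's parent (-1 for the root A)
parent = [-1, 0, 0, 1, 1, 2, 2]  # A, B, C, D, E, F, G

def build_tree_mask(parent: list[int]) -> torch.Tensor:
    n = len(parent)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for i in range(n):
        j = i
        while j != -1:          # walk up to the root, marking each ancestor
            mask[i, j] = True
            j = parent[j]
    return mask

print(build_tree_mask(parent).int())  # reproduces the 7x7 mask shown above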
Lookahead Decoding
Core Idea
Lookahead Decoding needs neither a Draft Model nor any extra training. Borrowing the idea of Jacobi iteration, it guesses tokens at multiple positions in parallel and lets those guesses converge over successive iterations.
Lookahead Decoding workflow
Window size W = 4, N-gram size G = 3
Step 1: Initialization (random guesses)
Committed: [The, cat, sat]
Guesses: [?, ?, ?, ?] (W = 4 positions)
Step 2: Parallel verification + update
One forward pass computes all positions in parallel:
[The, cat, sat, on, ?, ?, ?]
↑ converged!
N-gram pool collects: (cat, sat, on)
Step 3: Keep iterating
[The, cat, sat, on, the, ?, ?]
↑ converged again!
N-gram pool collects: (sat, on, the)
Step 4: N-gram matching
When later positions match an n-gram in the pool, the cached continuation is proposed as a draft and checked within the same forward pass
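A minimal sketch of the n-gram pool idea, using plain Python token lists (collect_ngrams and propose are illustrative names, not Lookahead Decoding's real API): 3-grams observed during the Jacobi iterations are cached, and when the latest generated tokens match a cached prefix, the stored continuation becomes a draft candidate.
from collections import defaultdict

ngram_pool = defaultdict(set)

def collect_ngrams(tokens, n: int = 3):
    """Store every (n-1)-token prefix -> last-token continuation."""
    for i in range(len(tokens) - n + 1):
        prefix = tuple(tokens[i:i + n - 1])
        ngram_pool[prefix].add(tokens[i + n - 1])

def propose(tokens, n: int = 3):
    """Return cached continuations for the current suffix, if any."""
    prefix = tuple(tokens[-(n - 1):])
    return sorted(ngram_pool.get(prefix, ()))

collect_ngrams(["The", "cat", "sat", "on", "the"])
print(propose(["cat", "sat"]))  # -> ['on']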
Performance Comparison
Compared with conventional speculative decoding:
| Method | Needs Draft Model | Needs Training | Speedup | Memory Overhead |
|---|---|---|---|---|
| Speculative Decoding | Yes | No | 2-3x | +Draft Model |
| Medusa | No | Yes (heads) | 2-3x | +Heads (~2% params) |
| Lookahead Decoding | No | No | 1.5-2.5x | +KV cache |
| EAGLE | Yes (lightweight) | Yes | 2.5-4x | +small draft model |
| Staged Speculative | Yes (multi-stage) | No | 3-4x | +multiple draft models |
EAGLE: Feature-Level Speculation
What makes EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency) distinctive is that it speculates in feature space rather than token space.
import torch
import torch.nn as nn


class EAGLEDraftHead(nn.Module):
    """
    EAGLE draft model: operates on feature embeddings
    rather than discrete tokens.
    """

    def __init__(self, hidden_size: int, num_layers: int = 1):
        super().__init__()
        # Fuse the previous hidden state with the current token embedding
        self.fc = nn.Linear(hidden_size * 2, hidden_size)
        # Lightweight autoregressive model over features
        self.layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=hidden_size, nhead=16,
                dim_feedforward=hidden_size * 4, batch_first=True
            )
            for _ in range(num_layers)
        ])

    def forward(self, prev_hidden: torch.Tensor,
                token_embedding: torch.Tensor) -> torch.Tensor:
        """
        Predict the next hidden state from the previous hidden state
        plus the embedding of the token that was just sampled.
        """
        # Concatenate hidden state and token embedding along the feature dim
        combined = torch.cat([prev_hidden, token_embedding], dim=-1)
        features = self.fc(combined)
        for layer in self.layers:
            features = layer(features, features)
        return features  # Predicted next hidden state
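A quick shape check of the draft head above, with hypothetical dimensions (hidden size 4096, sequence length 8, batch size 1):
head = EAGLEDraftHead(hidden_size=4096)
prev_hidden = torch.randn(1, 8, 4096)     # hidden states from the target model
token_emb = torch.randn(1, 8, 4096)       # embeddings of the sampled tokens
next_hidden = head(prev_hidden, token_emb)
print(next_hidden.shape)                  # torch.Size([1, 8, 4096])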
Engineering Practice: Deploying Speculative Decoding
Speculative Decoding in vLLM
from vllm import LLM, SamplingParams

# Method 1: Using a separate draft model
llm = LLM(
    model="Qwen/Qwen2.5-72B-Instruct",
    speculative_model="Qwen/Qwen2.5-1.5B-Instruct",
    num_speculative_tokens=5,
    tensor_parallel_size=4,
    gpu_memory_utilization=0.9,
)
# Method 2: Using Medusa heads (the draft model is a Medusa-heads checkpoint;
# the path is a placeholder, and exact arguments depend on the vLLM version)
llm = LLM(
    model="Qwen/Qwen2.5-72B-Instruct",
    speculative_model="path/to/medusa-heads-checkpoint",
    num_speculative_tokens=3,
    tensor_parallel_size=4,
)
# Method 3: N-gram based (no extra model)
llm = LLM(
    model="Qwen/Qwen2.5-72B-Instruct",
    speculative_model="[ngram]",
    ngram_prompt_lookup_max=4,
    ngram_prompt_lookup_min=1,
    num_speculative_tokens=5,
)
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=2048,
)
outputs = llm.generate(["Explain speculative decoding in detail."], sampling_params)
Performance Benchmarks
from dataclasses import dataclass


@dataclass
class BenchmarkResult:
    method: str
    tokens_per_second: float
    acceptance_rate: float
    latency_p50_ms: float
    latency_p99_ms: float
    memory_overhead_gb: float


# Benchmark results on A100-80GB (Qwen2.5-72B, batch_size=1)
results = [
    BenchmarkResult("Vanilla Autoregressive", 28.5, 1.00, 35.1, 42.3, 0.0),
    BenchmarkResult("Spec. Decoding (1.5B)", 68.2, 0.72, 14.7, 28.1, 3.2),
    BenchmarkResult("Spec. Decoding (7B)", 55.4, 0.81, 18.1, 35.6, 14.8),
    BenchmarkResult("Medusa (3 heads)", 61.3, 0.68, 16.3, 30.2, 1.8),
    BenchmarkResult("EAGLE", 78.6, 0.78, 12.7, 25.4, 2.1),
    BenchmarkResult("Lookahead (W=5)", 48.9, 0.55, 20.4, 38.7, 0.8),
]

print(f"{'Method':<30} {'TPS':>8} {'Accept%':>8} {'P50(ms)':>8} {'Mem(GB)':>8}")
print("-" * 70)
for r in results:
    print(f"{r.method:<30} {r.tokens_per_second:>8.1f} {r.acceptance_rate:>8.0%} "
          f"{r.latency_p50_ms:>8.1f} {r.memory_overhead_gb:>8.1f}")
Tuning Guide
Choosing a Draft Model
Picking the Draft Model is the most important engineering decision in speculative decoding. The core principle: the closer the Draft Model's predicted distribution is to the Target Model's, the higher the acceptance rate.
Best practices:
- Prefer a small model from the same family (e.g., a 72B target paired with a 1.5B draft from the same series)
- Target an acceptance rate of 70% or higher
- The draft model's forward pass should cost less than 20% of the target model's (see the sketch below for a quick estimate)
- The benefit of speculative decoding diminishes as batch size grows
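To sanity-check a draft-model choice against these targets, the speedup formula from the Draft-Verify section can be evaluated directly. A minimal sketch, where K, the acceptance rate, and the cost ratio are illustrative numbers:
def expected_speedup(k: int, acceptance_rate: float, cost_ratio: float) -> float:
    """Speedup ≈ (accepted + 1) / (1 + K * T_draft / T_target)."""
    accepted = k * acceptance_rate          # rough expected accepted tokens
    return (accepted + 1) / (1 + k * cost_ratio)

# Example: K=5, 70% acceptance, draft costs 5% of the target forward pass
print(f"{expected_speedup(5, 0.70, 0.05):.2f}x")  # ≈ 3.60x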
Hyperparameter Tuning
| Parameter | Recommended Range | Effect |
|---|---|---|
| num_speculative_tokens (K) | 3-7 | Too large → acceptance rate drops; too small → not enough speedup |
| temperature | 0-1.0 | Higher temperature → lower acceptance rate (more spread-out distribution) |
| top_p | 0.8-1.0 | Lower top_p → higher acceptance rate (more concentrated distribution) |
| tree_width (Medusa) | 2-4 | Wider → more candidates → more verification overhead |
| tree_depth (Medusa) | 2-4 | Deeper → longer sequences → acceptance rate decays exponentially |
Recent Advances
Key recent developments
- EAGLE-2: dynamic draft length, adapting the number of speculative tokens to the context
- DistillSpec: distillation to align the Draft Model's distribution with the Target Model's
- Kangaroo: uses the Target Model's shallow layers as the draft model, adding zero extra parameters
- Sequoia: searches for the optimal tree structure, picking the best tree topology for each hardware setup
- Multi-Draft: runs several Draft Models in parallel and keeps the best candidates
When to Use Speculative Decoding
Speculative decoding pays off most in these scenarios:
- Low batch size (the speedup is most pronounced at batch = 1)
- High-quality generation (greedy / low-temperature decoding has higher acceptance rates)
- Latency-sensitive online services
- Workloads with spare GPU compute that are bound by memory bandwidth
Scenarios with diminishing returns:
- Large batch sizes (the GPU is already well utilized)
- High-temperature / highly creative sampling (spread-out distributions, low acceptance rates)
- Very short outputs (the startup overhead of speculation outweighs the gains)
Conclusion
The family of speculative decoding techniques is becoming a standard part of LLM inference optimization. From the original Draft-Verify paradigm to the Medusa, EAGLE, and Lookahead variants, the core idea stays the same: use the GPU's parallel compute to turn memory-bound sequential decoding into compute-bound parallel verification. For engineering teams, the choice of scheme comes down to available GPU memory, the target batch size, latency sensitivity, and willingness to invest in training. In most low-batch scenarios, speculative decoding delivers a 2-4x throughput improvement with no impact on output quality.
Maurice | maurice_wen@proton.me