🏎️ 投机采样 (Speculative Decoding):大模型推理的并行外挂

4 阅读3分钟

在自回归生成(Autoregressive Generation)模式下,大模型必须“逐字输出”,这导致了巨大的计算浪费。投机采样通过博弈论与并行计算的巧思,打破了这一瓶颈,实现了不损失精度的推理加速。


一、 核心痛点:为什么大模型“慢”?

大模型推理属于典型的 内存受限 (Memory-Bound) 任务:

  • 算力过剩:GPU 的计算能力(TFLOPS)极强,但在推理时,大部分时间都在等待从显存中读取数百 GB 的模型参数。
  • 低效率搬运:每生成一个 Token,都需要把全量参数从显存搬运到计算单元。对于 70B 模型,搬运一次参数仅为了算一个词,硬件利用率极低。

二、 核心哲学:先盲猜,后验证

投机采样的基本思想是:“用极小的代价预测未来,用极大的代价验证对错。”

1. 角色分配

  • 草稿模型 (Draft Model):一个小而快的模型(如 1B 或更小)。它的任务是快速“盲猜”接下来的 KK 个词。
  • 目标模型 (Target Model):你的主力大模型(如 70B)。它的任务是“审查”草稿模型的猜测。

2. 执行流程

  1. 草稿迭代:草稿模型连续运行 KK 次,生成一段建议序列(如:“北京是中国的首都”)。
  2. 并行验证:目标模型一次性将这段序列读入。利用 GPU 的闲置算力,目标模型可以在一个推理周期内判断这 KK 个词是否符合自己的逻辑。
  3. 接受与修正
    • 目标模型对比自己的概率分布。如果猜对了前 3 个词,它会接受这 3 个词。
    • 在猜错的第 4 个词位置,目标模型会给出自己的正确输出。
    • 废弃剩余无效猜测,开始下一轮投机。

三、 数学保证:无损生成的奥秘

投机采样不仅仅是加速,它在数学上是完全无损的。它使用了拒绝采样 (Rejection Sampling) 的变体:

如果草稿模型的预测概率为 q(x)q(x),目标模型的概率为 p(x)p(x)

  • p(x)q(x)p(x) \geq q(x),则 100% 接受该 Token。
  • p(x)<q(x)p(x) < q(x),则以 p(x)q(x)\frac{p(x)}{q(x)} 的概率接受。
  • 这种机制确保了最终生成的分布与直接运行目标模型得到的分布在统计上完全一致

四、 工业界主流变体

为了进一步优化,业界演进出了多种无需额外小模型的方案:

  1. Medusa (美杜莎): 在大模型顶层增加多个“解码头(Heads)”,每个头分别预测未来第 1, 2, ..., N 个位置的词。
    • 优点:无需加载额外的草稿模型,节省显存。
  2. Prompt Lookup Decoding: 直接从输入的 Context(如 RAG 提供的文档)中寻找匹配的片段作为猜测。
    • 优点:在处理文档摘要、翻译等任务时速度奇快。
  3. Eagle: 一种更强的投机方案,小模型不仅学习词序,还学习大模型的隐藏层特征(Hidden States),猜测准确率极高。

五、 技术总结与对比

维度传统自回归解码投机采样解码
计算模式串行(一个接一个)局部并行(批量验证)
显存带宽利用极低(搬运多,算得少)高(单次搬运验证多个词)
输出质量标准完全一致(无损)
加速比1x1.5x - 3.5x (取决于猜测准确率)

💡 核心洞察

投机采样的成功取决于 “猜测命中率”

  • 如果小模型和老师模型“心有灵犀”,生成速度会产生质的飞跃。
  • 在当前工业界(vLLM, TensorRT-LLM),这已成为实现毫秒级首字响应的核心黑科技。