Transform 注意力机制:多头注意力、KV Cache、PagedAttention、FlashAttention

6 阅读1分钟

一、 注意力机制:多维语义的并行构建

注意力机制是 Transformer 的灵魂,其核心任务是在海量序列中实现高价值特征的选择性聚合。

1.1 核心公式与参数定义

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

  • QQ (Query)[1,dhead][1, d_{head}],当前生成 Token 的查询需求。
  • KK (Key)[N,dhead][N, d_{head}],历史 Token 的检索索引。
  • VV (Value)[N,dhead][N, d_{head}],历史 Token 的语义内容。
  • dk\sqrt{d_k}:缩放因子。当维度 dkd_k 增大时,点积结果方差随之增大,可能导致 Softmax 进入梯度饱和区。缩放可维持数值稳定性。

1.2 多头注意力 (MHA) 的语义演进

  • 多头逻辑:通过 hh 组不同的投影矩阵,模型在不同的子空间并行捕捉信息。
  • 语义涌现:头的语义由训练自发演化。底层头常捕捉语法和邻近关系,深层头则负责全局逻辑与长程指代。

1.3 结构精简:MHA \to MQA \to GQA

为了平衡性能,GQA (Grouped-Query Attention) 成为主流:

  • 原理:Query 头分组,每组共享一对 KV 头。
  • 收益:在保留 MHA 多维度特征提取能力的同时,将 KV Cache 显存占用降低为原来的 1/G1/GGG 为分组数)。

二、 KV Cache:自回归生成的动力源

在 LLM 逐字生成(Decoding)过程中,KV Cache 是将推理时间复杂度从 O(N2)O(N^2) 压低至 O(N)O(N) 的核心技术。

2.1 为什么缓存的是 K 和 V?

  • Q 的即时性:当前时刻生成的 QnowQ_{now} 仅用于查询过去。一旦当前词推理完成,QnowQ_{now} 随之失效,下一时刻会产生全新的 QnextQ_{next}
  • KV 的持久性:对于已经生成的 Token,其对应的 KK(用于被匹配相关性)和 VV(用于被提取特征)是其固有的语义特征,在后续所有推理步中保持不变。
  • 计算重用:缓存 K,VK, V 避免了每步推理都要重新执行线性映射(Wk,WvW_k, W_v),极大降低了计算冗余。

三、 PagedAttention:显存管理的“虚拟化”革命

PagedAttention 借鉴了操作系统的虚拟内存逻辑,彻底解决了长文本推理中的显存碎片化绝症。

3.1 核心机制:物理与逻辑解耦

  • 逻辑连续,物理离散:KV Cache 被切分为固定大小的 Blocks(如每 16 个 Token 一个分块)。
  • 块表 (Block Table):记录逻辑块索引与显存中不连续物理地址的映射关系。
  • 显存池化:所有物理块在空闲块池(Free Block Pool)中统一动态管理。

3.2 关键特性

  1. 按量分配:仅在生成新 Token 且当前块满时申请新块,显存浪费率从传统预分配的 60% 降至 4% 以下。
  2. 写时复制 (Copy-on-Write):多用户共享相同系统提示词(System Prompt)时指向同一物理块。仅在某个请求开始差异化生成时才执行物理复制。
  3. 换入换出 (Swapping):支持将非活跃缓存块置换到 CPU RAM(内存),赋予系统处理超大规模并发的弹性。

四、 FlashAttention:硬件感知的算子重构

FlashAttention 的本质是以计算冗余换取 I/O 节省,旨在突破 GPU 显存带宽(Memory Wall)的限制。

4.1 核心矛盾:HBM vs SRAM

  • HBM (显存):容量大但带宽窄(搬运慢)。
  • SRAM (高速缓存):带宽极宽但容量极小(就在计算单元旁边)。
  • 优化逻辑:将 Q,K,VQ, K, V 切块(Tiling)后放入 SRAM,使注意力矩阵的中间结果在 SRAM 内部生成并消耗,避免在慢速 HBM 中反复读写。

4.2 流式 Softmax (Online Softmax) 对比

Softmax 归一化依赖全局最大值,流式计算是实现硬件分块的前提:

  • 非流式 (Standard Three-pass):需三次遍历数据(找最大值 mm \to 求指数和 dd \to 归一化结果)。
  • 流式 (Online)增量修正。每读入一个新块,动态修正全局最大值 mm 和累加和 dddnew=doldemoldmnew+dblockd_{new} = d_{old} \cdot e^{m_{old} - m_{new}} + d_{block} 数据只需从显存搬运一次,计算在 SRAM 内部实时完成修正。

4.3 训练重计算 (Recomputation)

在反向传播时放弃存储巨大的注意力矩阵,转而现场重新计算。在现代硬件环境下,现场重算的耗时远低于从显存读取几十 GB 数据的时间


五、 全链路总结:推理工程的三大支柱

  1. GQA (结构层):精简 KV 数量,从模型架构层面减少内存足迹。
  2. PagedAttention (调度层):消除显存碎片,支持高并发与长上下文的弹性资源管理。
  3. FlashAttention (执行层):通过分块计算与流式算子,将 GPU 的有效吞吐量压榨到极限。