SpargeAttn(稀疏注意力算子) 如何让大模型推理跑起来?

293 阅读6分钟

SpargeAttn(稀疏注意力算子) 如何让大模型推理跑起来?

  • 研究表明,SpargeAttn 是一种有效的稀疏注意力机制,可以显著加速大型模型的推理过程,同时几乎不损失性能。
  • 它通过两阶段在线过滤器减少计算量,结合 Hilbert 曲线置换和 8-bit 量化进一步提升效率。
  • 实验数据显示,在 Llama3.1、CogvideoX 和 Mochi 等模型上,SpargeAttn 实现了 1.83 倍到 5 倍的加速,稀疏性高达 0.54。

简介

SpargeAttn 是由清华大学研究团队提出的最新研究成果,旨在解决大型 AI 模型在推理阶段高计算成本的问题。推理是模型实际应用时的关键阶段,但传统注意力机制的计算复杂度随序列长度平方增长,导致延迟显著增加。SpargeAttn 通过引入稀疏和量化的注意力机制,显著降低了计算资源消耗,同时保持模型性能。

技术细节

SpargeAttn 的核心是两阶段在线过滤器:

  • 第一阶段:稀疏块在线预测 通过选择性压缩相似 token 块,快速预测注意力图中不重要的区域,跳过不必要的矩阵乘法计算。
  • 第二阶段:稀疏 Warp 在线 Softmax 进一步识别注意力值足够小的部分,跳过相关计算,无需额外开销。

此外,SpargeAttn 还采用了:

  • Hilbert 曲线置换:重新排列 3D 视觉 token,增加块内相似性,提升稀疏性。
  • 8-bit 量化:与 SageAttention 框架结合,进一步加速计算。

实验结果

实验在多个模型和任务上验证了 SpargeAttn 的效果:

  • 在 Llama3.1(128K 序列长度)上,SpargeAttn 稀疏性达 0.54,速度提升至 708.1 TOPS,性能指标几乎无损。
  • 在 CogvideoX 和 Mochi 视频生成模型上,分别实现了 2.5 倍到 5 倍的加速,视频质量保持稳定。

详细报告

引言:注意力机制与计算挑战

在大型 AI 模型中,注意力机制是处理自然语言处理、图像生成和视频生成等任务的核心组件。它允许模型聚焦于输入数据中的关键部分,从而理解上下文并生成准确输出。然而,传统注意力机制的计算复杂度为 (O(n^2)),其中 (n) 是序列长度。随着序列长度增加(如视频生成和语言模型中达到 45K-128K),计算成本呈平方级增长,推理延迟显著增加,尤其在资源受限的环境中成为瓶颈。

SpargeAttn 的解决方案

SpargeAttn 是一种通用的稀疏和量化注意力机制,旨在加速各类模型的推理过程。其核心创新在于两阶段在线过滤器,结合 Hilbert 曲线置换和 8-bit 量化,实现了速度与性能的双赢。

两阶段在线过滤器
  1. 稀疏块在线预测(Sparse Block Online Prediction)

    • 步骤:首先计算 Query (Q) 和 Key (K) 矩阵中每个块内 token 的相似度(使用均值余弦相似度)。如果块内相似度高于阈值 (\theta),则将该块压缩为一个 token,计算块内 token 的均值。
    • 效果:压缩后快速构建注意力图 (P̂),选择 Top-τ 比例的块作为重要区域,在稀疏掩码 (M_g) 中标记为 1,跳过其余计算。
    • 开销:实验显示,对于 128K 序列长度,预测开销仅为 8.764 毫秒,相对于全注意力计算的 1696.2 毫秒,占比仅 0.516%。
  2. 稀疏 Warp 在线 Softmax

    • 步骤:在 FlashAttention 过程中,识别注意力图中足够小的值。如果 (P̃ {ij}) 中的所有值接近于零,则跳过 (P̃{ij}V_j) 的计算。
    • 效果:通过比较局部和全局最大值,设置阈值 (\lambda < 0),进一步减少矩阵乘法运算,无需额外开销。
技术细节:Hilbert 曲线置换和 8-bit 量化
  • Hilbert 曲线置换:针对 3D 视觉 token (Q, K, V \in \mathbb{R}^{T \times H \times W \times d}),使用 Hilbert 曲线填充 3D 空间后展平为 (\mathbb{R}^{L \times d})((L = T \times H \times W))。这保持局部性,增加相邻 token 的相似性,提升稀疏性。实验对比了 Random、Rowmajor、Timemajor 和 HilbertCurve,HilbertCurve 在 CogvideoX 和 Mochi 上分别提升了 Sim-q 和 Sim-k,稀疏性达 0.265 和 0.392。

    • 表 4:Hilbert 曲线效果(Mochi)

      方法Sim-q ↑Sim-k ↑L1 ↓稀疏性 ↑
      Random0.3210.0190.04140.048
      Rowmajor0.5510.3900.03070.363
      Timemajor0.5140.3670.03420.338
      HilbertCurve0.5720.4790.03890.392
  • 8-bit 量化:SpargeAttn 集成到 SageAttention 框架中,利用 8-bit 量化加速计算。实验显示,在 CogvideoX 上,SpargeAttn 比 SageAttn 单独使用减少了 15 秒生成延迟(53 秒 vs 68 秒)。

实验数据与性能

研究团队在文本、图像和视频生成模型上进行了广泛实验,验证 SpargeAttn 的加速效果和性能保持:

  • 表 1:端到端指标和性能(部分数据,完整见论文)

    模型 (序列长度)注意力 (稀疏性)速度 (TOPS) ↑指标 (任务特定)
    Llama3.1 (128K)SpargeAttn (0.54)708.1WikiText (Ppl.): 6.020, Longbench: 39.058, NIAH: 0.909
    CogvideoX (17K)SpargeAttn (0.46)507.9CLIPSIM: 0.1798, VQA-a: 78.276, FScore: 5.030
    Mochi (22K)SpargeAttn (0.47)582.4CLIPSIM: 0.1720, VQA-a: 54.179, FScore: 1.807
  • 加速效果:在 Mochi 模型上实现了 1.83 倍加速(L40 GPU),在各种生成任务中比现有密集和稀疏注意力模型快 2.5 倍到 5 倍。

  • 性能保持:与 Full-Attention 相比,SpargeAttn 在各种模型中几乎没有造成最终性能指标的损失。

  • 稀疏性分析:在 Llama3.1 的 NeedleInAHaystack 任务中,SpargeAttn 稀疏性为 0.54,结合 (M_g) 和 (M_{pv}) 掩码达到 54%。

  • 表 2:端到端生成延迟

    模型GPUOriginal (s)SageAttn (s)SpargeAttn (s)
    CogvideoXRTX4090876853
    MochiL40189715441037
  • 表 6:稀疏性分析(Llama3.1, 128K, NeedleInAHaystack)

    策略稀疏性
    only (M_{g})51.2%
    only (M_{pv})27.7%
    (M {g}) + (M{pv})54%
应用前景与未来发展

SpargeAttn 的通用性使其在多个领域具有广阔应用前景:

  • 加速大型模型推理:降低计算成本,便于部署到实际应用中。
  • 支持更长序列长度:减少计算量后,可处理更长的输入序列,提升性能。
  • 移动端部署:高效性使其有可能在移动设备上运行大型模型,实现更强大的移动 AI 应用。

未来发展方向包括:

  • 自适应超参数选择:根据模型和数据集自动调整超参数。
  • 更有效的 Token 压缩方法:提高压缩率同时保证信息完整。
  • 与其他加速技术结合:如内核优化和量化,进一步提升效率。
结论

SpargeAttn 通过两阶段在线过滤器、Hilbert 曲线置换和 8-bit 量化,显著加速了语言、图像和视频生成模型的推理过程,同时几乎不损失性能。其在实验中展现出的 1.83 倍到 5 倍加速效果,特别是在长序列任务中的高稀疏性(高达 0.54),为高效 AI 应用提供了重要工具。


引用