SpargeAttn(稀疏注意力算子) 如何让大模型推理跑起来?
- 研究表明,SpargeAttn 是一种有效的稀疏注意力机制,可以显著加速大型模型的推理过程,同时几乎不损失性能。
- 它通过两阶段在线过滤器减少计算量,结合 Hilbert 曲线置换和 8-bit 量化进一步提升效率。
- 实验数据显示,在 Llama3.1、CogvideoX 和 Mochi 等模型上,SpargeAttn 实现了 1.83 倍到 5 倍的加速,稀疏性高达 0.54。
简介
SpargeAttn 是由清华大学研究团队提出的最新研究成果,旨在解决大型 AI 模型在推理阶段高计算成本的问题。推理是模型实际应用时的关键阶段,但传统注意力机制的计算复杂度随序列长度平方增长,导致延迟显著增加。SpargeAttn 通过引入稀疏和量化的注意力机制,显著降低了计算资源消耗,同时保持模型性能。
技术细节
SpargeAttn 的核心是两阶段在线过滤器:
- 第一阶段:稀疏块在线预测 通过选择性压缩相似 token 块,快速预测注意力图中不重要的区域,跳过不必要的矩阵乘法计算。
- 第二阶段:稀疏 Warp 在线 Softmax 进一步识别注意力值足够小的部分,跳过相关计算,无需额外开销。
此外,SpargeAttn 还采用了:
- Hilbert 曲线置换:重新排列 3D 视觉 token,增加块内相似性,提升稀疏性。
- 8-bit 量化:与 SageAttention 框架结合,进一步加速计算。
实验结果
实验在多个模型和任务上验证了 SpargeAttn 的效果:
- 在 Llama3.1(128K 序列长度)上,SpargeAttn 稀疏性达 0.54,速度提升至 708.1 TOPS,性能指标几乎无损。
- 在 CogvideoX 和 Mochi 视频生成模型上,分别实现了 2.5 倍到 5 倍的加速,视频质量保持稳定。
详细报告
引言:注意力机制与计算挑战
在大型 AI 模型中,注意力机制是处理自然语言处理、图像生成和视频生成等任务的核心组件。它允许模型聚焦于输入数据中的关键部分,从而理解上下文并生成准确输出。然而,传统注意力机制的计算复杂度为 (O(n^2)),其中 (n) 是序列长度。随着序列长度增加(如视频生成和语言模型中达到 45K-128K),计算成本呈平方级增长,推理延迟显著增加,尤其在资源受限的环境中成为瓶颈。
SpargeAttn 的解决方案
SpargeAttn 是一种通用的稀疏和量化注意力机制,旨在加速各类模型的推理过程。其核心创新在于两阶段在线过滤器,结合 Hilbert 曲线置换和 8-bit 量化,实现了速度与性能的双赢。
两阶段在线过滤器
-
稀疏块在线预测(Sparse Block Online Prediction)
- 步骤:首先计算 Query (Q) 和 Key (K) 矩阵中每个块内 token 的相似度(使用均值余弦相似度)。如果块内相似度高于阈值 (\theta),则将该块压缩为一个 token,计算块内 token 的均值。
- 效果:压缩后快速构建注意力图 (P̂),选择 Top-τ 比例的块作为重要区域,在稀疏掩码 (M_g) 中标记为 1,跳过其余计算。
- 开销:实验显示,对于 128K 序列长度,预测开销仅为 8.764 毫秒,相对于全注意力计算的 1696.2 毫秒,占比仅 0.516%。
-
稀疏 Warp 在线 Softmax
- 步骤:在 FlashAttention 过程中,识别注意力图中足够小的值。如果 (P̃ {ij}) 中的所有值接近于零,则跳过 (P̃{ij}V_j) 的计算。
- 效果:通过比较局部和全局最大值,设置阈值 (\lambda < 0),进一步减少矩阵乘法运算,无需额外开销。
技术细节:Hilbert 曲线置换和 8-bit 量化
-
Hilbert 曲线置换:针对 3D 视觉 token (Q, K, V \in \mathbb{R}^{T \times H \times W \times d}),使用 Hilbert 曲线填充 3D 空间后展平为 (\mathbb{R}^{L \times d})((L = T \times H \times W))。这保持局部性,增加相邻 token 的相似性,提升稀疏性。实验对比了 Random、Rowmajor、Timemajor 和 HilbertCurve,HilbertCurve 在 CogvideoX 和 Mochi 上分别提升了 Sim-q 和 Sim-k,稀疏性达 0.265 和 0.392。
-
表 4:Hilbert 曲线效果(Mochi)
方法 Sim-q ↑ Sim-k ↑ L1 ↓ 稀疏性 ↑ Random 0.321 0.019 0.0414 0.048 Rowmajor 0.551 0.390 0.0307 0.363 Timemajor 0.514 0.367 0.0342 0.338 HilbertCurve 0.572 0.479 0.0389 0.392
-
-
8-bit 量化:SpargeAttn 集成到 SageAttention 框架中,利用 8-bit 量化加速计算。实验显示,在 CogvideoX 上,SpargeAttn 比 SageAttn 单独使用减少了 15 秒生成延迟(53 秒 vs 68 秒)。
实验数据与性能
研究团队在文本、图像和视频生成模型上进行了广泛实验,验证 SpargeAttn 的加速效果和性能保持:
-
表 1:端到端指标和性能(部分数据,完整见论文)
模型 (序列长度) 注意力 (稀疏性) 速度 (TOPS) ↑ 指标 (任务特定) Llama3.1 (128K) SpargeAttn (0.54) 708.1 WikiText (Ppl.): 6.020, Longbench: 39.058, NIAH: 0.909 CogvideoX (17K) SpargeAttn (0.46) 507.9 CLIPSIM: 0.1798, VQA-a: 78.276, FScore: 5.030 Mochi (22K) SpargeAttn (0.47) 582.4 CLIPSIM: 0.1720, VQA-a: 54.179, FScore: 1.807 -
加速效果:在 Mochi 模型上实现了 1.83 倍加速(L40 GPU),在各种生成任务中比现有密集和稀疏注意力模型快 2.5 倍到 5 倍。
-
性能保持:与 Full-Attention 相比,SpargeAttn 在各种模型中几乎没有造成最终性能指标的损失。
-
稀疏性分析:在 Llama3.1 的 NeedleInAHaystack 任务中,SpargeAttn 稀疏性为 0.54,结合 (M_g) 和 (M_{pv}) 掩码达到 54%。
-
表 2:端到端生成延迟
模型 GPU Original (s) SageAttn (s) SpargeAttn (s) CogvideoX RTX4090 87 68 53 Mochi L40 1897 1544 1037 -
表 6:稀疏性分析(Llama3.1, 128K, NeedleInAHaystack)
策略 稀疏性 only (M_{g}) 51.2% only (M_{pv}) 27.7% (M {g}) + (M{pv}) 54%
应用前景与未来发展
SpargeAttn 的通用性使其在多个领域具有广阔应用前景:
- 加速大型模型推理:降低计算成本,便于部署到实际应用中。
- 支持更长序列长度:减少计算量后,可处理更长的输入序列,提升性能。
- 移动端部署:高效性使其有可能在移动设备上运行大型模型,实现更强大的移动 AI 应用。
未来发展方向包括:
- 自适应超参数选择:根据模型和数据集自动调整超参数。
- 更有效的 Token 压缩方法:提高压缩率同时保证信息完整。
- 与其他加速技术结合:如内核优化和量化,进一步提升效率。
结论
SpargeAttn 通过两阶段在线过滤器、Hilbert 曲线置换和 8-bit 量化,显著加速了语言、图像和视频生成模型的推理过程,同时几乎不损失性能。其在实验中展现出的 1.83 倍到 5 倍加速效果,特别是在长序列任务中的高稀疏性(高达 0.54),为高效 AI 应用提供了重要工具。