DeepSeek王炸来袭!原生稀疏注意力 Native Sparse Attention助力64K长上下文前向飙升9倍、反向极速6倍加速!

561 阅读3分钟

Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention 本论文的核心创新点在于提出了一种名为 NSA(Native Sparse Attention)的原生可训练稀疏注意力架构,它在保证模型性能的同时,大幅降低了长上下文建模的计算成本。其主要创新与核心概念可总结如下:

  1. 原生可训练的稀疏注意力
    与传统仅在推理阶段施加稀疏性的方法不同,NSA 将稀疏注意力直接融入训练过程,实现端到端优化,避免了后置稀疏可能带来的性能退化问题。这种设计使得模型在预训练时就能够学到最优的稀疏模式,从而提高了在长上下文任务中的表现。

  2. 分层 Token 建模策略
    NSA 采用了动态分层的稀疏策略,将原始的键值对经过三个映射过程处理:

    • Token 压缩:将连续的 token 聚合为块级表示,从而捕捉全局语义信息并大幅降低计算量。
    • Token 选择:在块内进行精细筛选,仅保留最重要的 token,确保关键信息不丢失,同时进一步减少计算量。
    • 滑动窗口:专门处理局部上下文,保证在进行全局压缩和选择的同时,不忽略局部细粒度信息。
      这种分层机制实现了全局感知与局部精度的平衡。
  3. 硬件对齐的优化设计
    NSA 在设计时充分考虑了现代 GPU 的硬件特性,通过块级内存访问、合理的循环调度以及专门在 Triton 上实现的核函数设计,最大化了 Tensor Core 的利用率,减少了内存访问瓶颈,进而在解码、前向和反向传播阶段均实现了显著的加速(例如在 64k 长上下文下,前向传播可达 9 倍、反向传播可达 6 倍加速)。

image.png 4. 高效的端到端训练能力
通过引入原生稀疏注意力,NSA 能够在预训练阶段直接优化稀疏模式,降低了长序列训练的计算成本,同时保证了模型在通用任务与长上下文任务上的性能,解决了现有稀疏方法主要关注推理而忽略训练效率的问题。

image.png 核心概念解释:

  • Token 压缩:将一段连续的 token 聚合为一个块级表示,通过一个可学习的映射(如 MLP),提取出该块的全局信息,减少需要计算的键值对数量,从而降低计算复杂度。
  • Token 选择:在每个块内计算重要性得分,仅保留排名靠前的 token,用以捕捉关键信息。这种策略既能保持模型对关键信息的敏感性,又能避免不必要的计算。
  • 滑动窗口注意力:为了不遗漏局部上下文信息,采用固定窗口机制确保当前查询能关注到其附近的 token,从而在全局信息压缩的同时,保留局部精细结构。
  • 硬件对齐优化:通过设计与现代 GPU 架构(如 Tensor Core 和高速 SRAM 访问)相匹配的核函数和数据加载策略,实现计算与内存访问的最优平衡,从而大幅提升训练与推理速度。

image.png 总体来看,NSA 通过在稀疏注意力设计上实现原生可训练、分层建模和硬件对齐优化,成功兼顾了模型性能与计算效率,为长上下文语言模型提供了一条高效且实用的路径

同系列: