论文笔记:Native Sparse Attention,效果比 MHA 还要好的稀疏注意力

147 阅读2分钟

留下阅读 (2025) Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention 的痕迹。

DeepSeek 根本停不下来,这次又发布了论文介绍一个新的注意力架构:NSA(Native Sparse Attention)。比 MHA(Multi Head Attention)成本更低、效果更好。

总览

NSA 不进行 token 层面的 attention,而是用三种压缩方法获得三组 kv 矩阵 进而获得三组注意力结果,然后加权求和。

  • 压缩。把 tokens 分段为块,各块压缩为单个 token
  • 选择。以块为单位找出重要的 tokens,只对这些 tokens 进行注意力
  • 滑动窗口。局部注意力

具体方法

对于长度为 tt 的序列,传统注意力机制:

ot=Attn(qt,k:t,v:t)\bold{o}_t=\text{Attn}(\bold{q}_t,\bold{k}_{:t},\bold{v}_{:t})
  • tt 是序列长度
Attn(qt,k:t,v:t)=i=1tαt,ivij=1tαt,j, αt,i=eqtTkidk\text{Attn}(\bold{q}_t,\bold{k}_{:t},\bold{v}_{:t})=\sum^t_{i=1}\frac{\alpha_{t,i}\bold{v}_i}{\sum^t_{j=1}\alpha_{t,j}},\ \alpha_{t,i}=e^{\frac{q_t^Tk_i}{\sqrt{d_k}}}
  • dkd_k 是 key 的维度数
  • αt,i\alpha_{t,i} 代表 qt\bold{q}_tki\bold{k}_i 之间的注意力权重

论文引入了映射 fKf_KfVf_V 用来构造新的 kv:

K~t=fK(k:t)\tilde{K}_t=f_K(\bold{k}_{:t})
V~t=fV(v:t)\tilde{V}_t=f_V(\bold{v}_{:t})

原文公式是像是 V~t=fV(qt,k:t,v:t)\tilde{V}_t=f_V(\bold{q}_t,\bold{k}_{:t},\bold{v}_{:t}) 这样的形式,但实际上并没有用上 qt\bold{q}_tv:t\bold{v}_{:t},可能是笔误。

定义三种映射方法 C={cmp, slc, win}C=\{\text{cmp, slc, win}\}(分别对应压缩、选择和滑动窗口,后文会详细说明),构造出三种 K~tc\tilde{K}^c_t V~tc\tilde{V}^c_t。新的注意力输出会是这三种映射方法获得的注意力输出的加权和:

ot=cCgtcAttn(qt,K~tc,V~tc)\bold{o}^*_t=\sum_{c\in C}g^c_t\cdot \text{Attn}(\bold{q}_t,\tilde{K}^c_t,\tilde{V}^c_t)

其中 gtc[0,1]g^c_t\in [0,1] 是门限值,通过让输入特征经过一个 MLP + Sigmoid 结构获得。

NtN_t 记为三种重映射后总共的 key/value 序列长度。为了达成高度稀疏,需要让重映射总长度远小于原序列长度,即 NttN_t\ll t

三种映射

Token Compression

这一步会把 tokens 分段为块,各块压缩为单个 token。

以 key 矩阵为例。映射 fKcmpf^{\text{cmp}}_K 定义如下:

K~tcmp=fKcmp(k:t)={φ(kid+1:id+l)1itld}\tilde{K}^{\text{cmp}}_t =f^{\text{cmp}}_K(k_{:t}) =\left\{ \varphi(\bold{k}_{id+1:id+l}) \bigg | 1\le i\le \left\lfloor \frac{t-l}{d}\right\rfloor \right\}
  • φ\varphi 是 MLP,将块内 tokens 压缩为单个 token
    • 具有块内位置编码
  • ll 是块长度
  • dd 是滑动窗口步长
    • 通常 d<ld<l,让相邻块重合

Token Selection

以块为单位找出重要的 tokens,只对这些 tokens 进行注意力。

为了减少找块的计算量,重复利用 Compression 步骤结果 K~tcmp\tilde{K}^{\text{cmp}}_t

若 Selection 步骤的块长度 ll 和滑动步长 dd 与 Compression 步骤的一样,则可以直接通过这个式子获得 q 与 k 间的注意力分数:

ptslc=ptcmp=Softmax(qtTK~tcmp)\bold{p}^{\text{slc}}_t =\bold{p}^{\text{cmp}}_t =\text{Softmax}(\bold{q}^T_t\tilde{K}^{\text{cmp}}_t)

论文公式 (9) 提到在 lll'\neq l 的情况该怎么换算。没看明白,感觉条件很苛刻,似乎还额外需要 l=dl=dll'dd 的整倍数。忽略。

对于 HH 个头的多头注意力,还要累加一下分数:

ptslc=h=1Hptslc,(h)\bold{p}^{\text{slc}'}_t =\sum^H_{h=1}\bold{p}^{\text{slc},(h)}_t

获得 ptslc\bold{p}^{\text{slc}'}_t 后,选取分数最高的 nn 个分块,从 k:t\bold{k}_{:t} 中取出并拼接,作为 K~tslc\tilde{K}^{\text{slc}}_t

Sliding Window

K~twin=ktw:t\tilde{K}^{\text{win}}_t=\bold{k}_{t-w:t},其中 ww 为窗口大小。

可见是往左取的窗口。

超参

一些论文中提到的训练超参。以下是论文在验证结果时用到的 NSA 参数。

参数
ll 压缩块长度32
ll' 选择块长度64
dd 滑动步长16
ww 滑动窗口大小512

其他超参。

参数
layers30
hidden dimension2560
heads64
GQA groups4
routed experts72
shared experts2
experts top-k6
dq,dk,dvd_q, d_k, d_v192, 192, 128

碎碎念

DeepSeek 才算有 SOTA 团队该有的的样子。保持活力尝试新东西,勤于分享新成果新发现。这次的 NSA 看起来非常诱人,要是上游的预训练模型慢慢都开始采用 NSA 就太好了。

也许是因为团队在赶时间,我读的这一版 arxiv 论文写得有些粗糙。就像是没删掉大语言模型的废话,或是说从头写到尾后没有根据后文修改对应的前文。不过内容是很真诚的,实现方式实验结果样样不少。

继续加油哇。我好像能看到 AGI 的尾巴了(