大模型加速-核心网络算子-Flash Attention V2

364 阅读3分钟

Flash Attention V2该论文旨在解决在扩展Transformer模型以处理更长序列时,由于注意力层的运行时间和内存消耗呈二次增长,导致性能下降的问题

FlashAttention V1通过利用GPU内存层次结构的不对称性,实现了线性内存节省和2-4倍的运行时加速,同时没有使用近似方法。然而,FlashAttention V1 不足之处是仍然比优化后的矩阵乘法操作慢得多,只能达到理论最大FLOPs/s的25-40%,具体可以参考同系列文章大模型加速-核心网络算子-Flash Attention V1

FlashAttention-2通过优化工作分区,解决了FlashAttention的低效率问题,其性能提升相对于FlashAttention V1的速度提高了约2倍,在A100上达到理论最大FLOPs/s的50-73%,接近于GEMM操作的效率。作者通过实验证明,当用于端到端训练GPT-style模型时,FlashAttention-2的训练速度可达每个A100 GPU的225 TFLOPs/s(72%模型FLOPs利用率), 其具体的优化方案为:

  • (1) 调整算法以减少非矩阵乘法FLOPs数量
  • (2) 并行计算注意力,即使是单个头部,也跨不同的线程块以增加占用率
  • (3) 每个线程块内,将工作分配给线程束以减少通过共享内存的通信

下面我们从以下几方面说明Flash Attention V2是如何在 V1 的基础上改进加速的

  • 标准 Attention 简洁说明
  • Flash Attention V1 解决方案和劣势不足
  • Flash Attention V2 具体算法

1.标准 Attention

Attention(Q,K,V)=softmax(QKdk)V\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_k}} \right) V

对于标准Attention 其核心计算分为4 步

企业微信截图_1f6e3b79-39e8-48da-a085-ccad86c58ef4.png

其核心链路涉及

  • 第一步计算 中间结果 S, 即 S=QKTS=QK^T 写回到 HBM
  • 第二部计算 中间结果 P, 即 P=softmax(S)P=softmax(S) 写回到 HBM
  • 第三步计算 最终输出结果O, 即 O=PVO=PV 写回到 HBM
  • 第四步 返回结算结果

标准Attention的中间结果S,P 通常需要通过高带宽内存(HBM)进行存取,两者所需内存空间复杂度为O(N2)O(N^2)

2.Flash Attention V1

针对标准 Attention 4 部当中的高 HBM 访存操作,Flash Attention V1 利用基于共享缓存的数据分块,局部 max 计算和前置计算局部max 补偿方案 降低 HBM 的访问次数进行提速

2.1 Flash Attention 核心算法如下:

image.png

V1 应用了两种已确立的技术(titling(分块),Recomputation(重计算降低内存存储))来克服在次二次 HBM 访问中计算精确注意力的技术挑战,流程示意图如下 企业微信截图_c44e313f-f789-4d87-8978-a6e3dcbcc8f7.png 图一:逐行分块计算注意力

2.2 Flash Attention 可以提升改进之处:

Flash Attention V1 如果外循环修改为基于 QjQ_j, 则每个warp 可以连续针对 OjO_j 连续处理,可以有效避免中间变量的 HBM 保存和加载,同时针对 O 最后进行归一化缩放,避免局部对 OjO_j 的每次 1D 缩放操作,具体示意参考下图:

企业微信截图_f174fa81-70c8-4f8a-809b-4b4a27364d5a.png

新的算法可以将 每次分块内循环OjO_{j} 的归一化补偿统一最后一次进行 原有公式

Oj=diag(Ljnew)1 (diag(Lj)emjmjnewOj+ e m~ijmjnew P~ijVj  )O_{j} = diag(L_j ^{new})^{-1}  * (diag(L_j)*e^{m_{j} - m_j^{new}}*O_{j}+   e^{ \tilde{ m}_{ij} - m_j^{new}} *  \tilde{P}_{ij}*V_j   )

新的简化可以省略每次归一化的还原系数 diag(Lj)diag(Lj) 和最新累计和的归一化操作 diag(Ljnew)1diag(L_j ^{new})^{-1}

Oj=emjmjnewOj+ e m~ijmjnew P~ijVj  O_{j} = e^{m_{j} - m_j^{new}}*O_{j}+   e^{ \tilde{ m}_{ij} - m_j^{new}} *  \tilde{P}_{ij}*V_j   

修改为最后一次基于最新的diag(Ljnew)1diag(L_j ^{new})^{-1} 统一归一化

另外由于 v1 当中以下公式计算

Oj=emjmjnewOj+ e m~ijmjnew P~ijVj  O_{j} = e^{m_{j} - m_j^{new}}*O_{j}+   e^{ \tilde{ m}_{ij} - m_j^{new}} *  \tilde{P}_{ij}*V_j   

当中的 P 计算由于一个 warp 已经可以完成 S 的一整行计算,因此最新的 P 不再基于局部mijm_ij,而是 mnewm^{new} 计算,因此很多1D 算力计算可以减少

Oj=emjmjnewOj+P~ijVj  O_{j} = e^{m_{j} - m_j^{new}}*O_{j}+ \tilde{P}_{ij}*V_j   

同理 整行累计和 第二个局部快计算时

L0new=em01m0newL0 +em~01m0newL~01L^{new}_{0}=e^{m_{01}-m_0^{new}} *L_0  + e^{\tilde{m}_{01} - m_0^{new}} * \tilde{L}_{01}

基于整行 max 可以简化为一下公式,节省 1D 算力:

L0new=em01m0newL0 +L~01 L^{new}_{0}=e^{m_{01}-m_0^{new}} *L_0  + \tilde{L}_{01}

3.Flash Attention V2 具体算法

V2 调整了 FlashAttention 算法,以减少非矩阵乘运算的 FLOPs 数量。这是因为现代 GPU(例如 Nvidia GPU 上的张量核心)专门用于加速矩阵乘运算。例如,A100 GPU 的最大理论吞吐量为每秒 312 TFLOPs/s 的 FP16/BF16 矩阵乘运算,但非矩阵乘运算的 FP32 仅为 19.5 TFLOPs/s。另一种思考方式是,每个非矩阵乘运算的 FLOPs 比矩阵乘运算的 FLOPs 贵 16 × 。为了保持高吞吐量(例如超过最大理论 TFLOPs/s 的 50%),我们希望尽可能多地使用矩阵乘运算的 FLOPs,为了更好的理解,我们先看一下

V1 的算法计算两次 OiO_i,即Oi(1)O_i^{(1)} Oi(2)O_i^{(2)} 如下

image.png

3.1 Flash Attention V2 1D 算力节省

移除每次分块 Oi(1)Oi(2)O_i^{(1)} O_i^{(2)} 每次的归一化操作,修改为最后一次归一化操作一次,由于行内 max 值可以在一个 warp 内完成计算且保存在集群器或共享缓存中,因此所有依赖局部 max 值的均可修改为基于当前整行 max 的最大值计算,针对v1 版本当前步的系数可以移除

Oj=diag(Ljnew)1 (diag(Lj)emjmjnewOj+ e m~ijmjnew P~ijVj  )O_{j} = diag(L_j ^{new})^{-1}  * (diag(L_j)*e^{m_{j} - m_j^{new}}*O_{j}+   e^{ \tilde{ m}_{ij} - m_j^{new}} *  \tilde{P}_{ij}*V_j   )

即 v1后半部分

 e m~ijmjnew P~ijVj 变更为e m~newmjnew P~ijVj=P~ijVj  e^{ \tilde{ m}_{ij} - m_j^{new}} *  \tilde{P}_{ij}*V_j 变更为 e^{ \tilde{ m}_{new} - m_j^{new}} *  \tilde{P}_{ij}*V_j = \tilde{P}_{ij}*V_j

其他计算同理,v2整体逻辑如下 image.png

整体前向算法如下

image.png

3.2 因果遮蔽

注意力的一个常见应用场景是在自回归语言建模中,需要在注意力矩阵 𝐒 上应用因果掩码(即,对于 𝐒i​j 中的任何条目,如果 j>i ,则将其设置为 −∞ )

  • 于 FlashAttention 和 FlashAttention-2 已经以块的形式运行,对于所有列索引都大于行索引的块(大约一半的块,特别是对于长序列),可以跳过该块的计算。这相对于不使用因果掩码的注意力操作大约可以提高 1.7-1.8 倍的速度
  • 不需要为行索引保证严格小于列索引的块应用因果掩码。这意味着对于每一行,只需要对 1 个块应用因果掩码(假设块是正方形的)

3.3 反向传播优化

FlashAttention-2 的反向传播过程几乎与 FlashAttention 相同。我们进行了一项小的调整,仅使用行内对数求和 L 代替 softmax 中的行内最大值和行内指数求和,具体算法如下:

image.png

3.4 并行处理

Flash Attention V1 并行处理批大小和头部数量。使用 1 个线程块来处理一个注意力头部,整体任务总共需要 batch size*number of heads 个线程块。每个线程块被安排在流多处理器(SM)上运行,例如,A100 GPU 上有 108 个这样的 SM。当这个数字很大(例如 ≥80 )时,这种调度方式是高效的,因为我们几乎可以充分利用 GPU 上的所有计算资源

  • Flash Attention V2 在长序列的情况下(通常意味着小批量或少量头),为了更好地利用 GPU 上的多处理器,通过序列长度维度并行化,显著提高了速度

  • 交换循环顺序(外层循环遍历行块,内层循环遍历列块,而不是原始 FlashAttention 论文中的相反顺序)具体机制原理收益可以参考 2.2 针对 v1 的可改进优化分析

  • 在前向传播(左侧),我们将 Worker(线程block块)并行化,每个worker负责注意力矩阵中的一组行。

image.png

  • block 内 warp之间任务分配优化,具体也可以参考 2.2,下面我们说一下细节

在 FlashAttention-2 中,将 𝐐 分为 4 个warp,同时保持 𝐊 和 𝐕 可被所有warp访问。在每个warp执行矩阵乘法以获取 𝐐𝐊𝐐𝐊^⊤ 的一个切片后,它们只需要与它们共享的 𝐕 的切片相乘,以获取对应输出的切片。warp之间不需要通信。共享内存的读写减少带来了加速

即外循环可以完成完整的一次逐行计算,不需要通信

image.png

4.总结

通过外循环修改为 Q,可以很好的,避免 warp 之间的通信,利用高速缓存,进一步降低 HBM 的访问次数,同时有效的调整归一化补偿算法,节省 1D 算力 官方收益数据如下:

FlashAttention-2 在不同序列长度下的运行时间,并将其与 PyTorch 中的标准实现、FlashAttention 和 Triton 中的 FlashAttention 进行比较。我们确认 FlashAttention-2 比 FlashAttention 快 1.7-3.0 倍 × ,比 Triton 中的 FlashAttention 快 1.3-2.5 倍 × ,比标准注意力实现快 3-10 倍 × 。FlashAttention-2 达到了 230 TFLOPs/s,A100 GPU 上的理论最大 TFLOPs/s 的 73%

端到端训练速度 当使用端到端方法训练大小为 1.3B 和 2.7B 的 GPT 风格模型,序列长度为 2k 或 8k 时,FlashAttention-2 相较于 FlashAttention,提供了高达 1.3 × 的加速比,相较于没有 FlashAttention 的基线,提供了 2.8 × 的加速比。FlashAttention-2 在每个 A100 GPU 上达到最高 225 TFLOPs/s(模型 FLOPs 利用率为 72%)

同系列文章: 大模型加速-核心网络算子-Flash Attention V1

参考资料:arxiv.org/pdf/2307.08…