大模型加速-核心网络算子-Flash Attention V3

1,504 阅读13分钟

Attention作为 Transformer 架构的核心层,一直是大型语言模型和长上下文应用的瓶颈。

FlashAttention V1 和 FlashAttention V2 提出了一种通过减少内存读写来加速 GPU 上Attention计算的方法。然而,它尚未充分利用最近硬件提供的新能力,比如 FlashAttention-2 在 H100 GPU 上的利用率仅为 35%。Flash Attention V3它贡献了三个新的想法,以进一步提高在最新 GPU 架构上的性能:

  1. 利用TensorCore和 TMA 的异步性,通过专门的 WarpGroup来实现数据搬运和计算的并行(pingpong),以及矩阵块乘法与 softmax 操作在 K,V维度上的进一步并行和数据拆分, 来达到整体计算和数据移动二者的时间重叠

  2. 生产者-消费者异步性:定义了一种基于异步执行数据搬运和TensorCore的 warp 级别软件流水线方案,通过将数据的生产者和消费者分别划分到不同的WarpGroup中,以此扩展算法隐藏内存和指令发射延迟的能力

  3. 引入 FP8 低精度量化(需要 H100 支持),利用硬件对 FP8 低精度的支持进行块量化和无相关处理

FlashAttention-3 基于 H100 硬件新特性设计的优化方案,在 H100 GPU 上FP16精度实现 1.5-2.0 倍的加速,达到 740 TFLOPs/s(利用率 75%),FP8精度 达到接近 1.2 PFLOPs/s。通过实验验证,FP8 FlashAttention-3 的数值误差比基线 FP8 注意力低 2.6倍

下面我们基于以下几方面说明相关背景知识以及Flash Attention V3 的性能优化方案

  • Flash Attention V3 使用新特性说明
  • Flash V3如何利用硬件相关新特性完成性能提升
  • 基于源码级别阐述具体实现(后续补充)

因为 Flash Attention v3 是针对 H100 的新特性进行的性能提升,需要针对h100相关新特性进行了解,才能深刻理解 V3。

1. Flash Attention V3 使用的新特性说明

1.1 TMA 张量记忆加速器(硬件实现)

从功能上来说,其实就是矩阵加载器,本质上是通过专用数据搬运IP(硬件单独功能模块) 和 计算单元(例如英伟达的 SM 流处理器) 利用多个专用的数据传输通道 完成规模比较大的数据块的数据读取和写入, 在H100 TMA 出现之前,GPU 是通过指令完成数据加载的,而非特殊的硬件单元。目前,H100上, TMA和指令搬运这两种方式都可以完成数据加载,下面针对这两种数据读取和写入方式进行说明,本节包含以下两个方面的内容:

  • TMA - 基于硬件的专用数据异步搬运单元
  • 专用指令完成数据搬运

1.1.1 第一种:TMA 基于硬件的专用数据异步搬运单元

TMA介绍

如下图右侧 TMA 是 SM 的一个硬件功能模块实现, 对于矩形数据搬运任务,TMA 会基于任务配置自动生成地址进行数据搬运, 配置可以参考 blockbaseAddress+(block_height,block_width)的类似形式。 image.png

下图是2D(矩形)数据搬运示意图, TMA需要配置矩阵的起始地址和矩阵的宽和高,因此矩阵数据可以表达为:blockbaseAddress+(block_height,block_width) ,注意blockbaseAddress 可以由相对 Tensor 的起始位置的偏移索引决定。 image.png

上述模式的2D 数据块搬运已经是国内神经网络专用 AI 加速器的标配了,一方面是由于 cache 硬件和算法落地的效果距离英伟达还有一定的差距,另外一方面是目前编译器 处于起步阶段,这个阶段算子的 kernel函数的优化会使用大量的 ptx(GPU高级汇编)指令。

TMA的异步数据搬运

我们基于Attention算子当中的第一个Query block和第一个Key block的attention分数的计算过程,即S00=Q0K0TS_{00}=Q_0*K_0^T来说明一下为什么TMA的异步数据搬运会对 Flash Attention 并行加速。我们首先看默认的矩阵分块加速计算示意图。

image.png 其中S00=Q00K00+Q01K10+Q02K20+Q03K30 S_{00} = Q_{00}*K_{00}+ Q_{01}*K_{10}+Q_{02}*K_{20}+ Q_{03}*K_{30}

这里每一个Qij,KijQ_{ij},K_{ij} 分别对应 HBM 上 Q 和 K 两个超大张量内部的一个矩阵块。因此S00S_{00}的计算,需要循环从L3上搬运Qij,KijQ_{ij},K_{ij}, 并进行GEMM计算。假设S00S_{00}由1个warp负责计算,那么该warp的数据搬运和计算顺序如下:

加载(Q00,K00Q_{00},K_{00}) -> 计算Q00K00Q_{00}*K_{00} -> 加载(Q01,K10Q_{01},K_{10}) -> 计算Q01K10Q_{01}*K_{10}....

由于 H100 之前的 GPU 最小调度单元是warp, 其调度执行是采用分时复用的方式进行执行的,kenel内部的数据传输和计算并不能并行起来,上述数据搬运和计算是依序执行的,即如下:

image.png

基于 异步 TMA 可以将上图数据加载数据存储计算异步并行执行,如下图所示: image.png

image.png 基于上图数据加载计算异步并行执行的前提,我们便可以进一步提升性能,将计算的时间和数据搬运进行时间重叠,即进行第一次计算时,同时进行下一次数据计算依赖的数据加载。

综上,TMA 可以让数据加载和数据计算异步并行执行,这本身是国内 AI 加速器针对自己以前底层调度为非分时复用的解决方案,目前由于大模型的大火,为了更好的进行 Pipeline,H100的终于借鉴2D 矩形的异步搬运。

H100两种数据搬运调度方式总结如下:

  • 基于 TMA 的异步调度
  • 分时复用调度(H100 之前的 GPU 仅仅支持的调度方式)

1.1.2 第二种: 基于专用指令的数据搬运(H100 之前的数据搬运方式)

下面是搬运指令的说明,指令详情请参考附录清单 image.png

1.2 pingpong

这个概念最早来自于硬件底层的功能模块中外部接口的设计,其目的是为了提高数据传输的效率, 基于A,B双 buffer 管理和存储数据。

image.png

pingpong 即通过申请所需计算输入2 倍的共享内存存储空间,在第一次数据加载完成开始计算时,异步执行下一次计算所依赖的数据搬运,实现当前计算和下一次数据搬运的时间重叠,流程如下:

image.png 因为第一次计算时,第二次计算所依赖的数据已经开始异步搬运,因此需要申请 2 倍的存储空间。

1.3 WGMMA 指令

WGMMA(Warp-Level General Matrix-Matrix Multiplication Acceleration) NVIDIA 在 Hopper 架构(如 H100 GPU)中引入的新一代矩阵乘法加速指令或 API,可以异步执行。它旨在进一步提升矩阵计算的性能,特别是针对深度学习和高性能计算领域的大规模矩阵运算,该指令具备以下特点:

  • Warp 级别的矩阵运算:WGMMA 允许在 Warp 层面上执行矩阵乘法操作,提高了并行度和计算效率。

  • 更大的数据吞吐量:相比前一代的 WMMA(Warp Matrix Multiply-Accumulate),WGMMA 提供了更高的计算吞吐量,充分利用了 H100 上改进的 Tensor Core 性能。

  • 广泛的数据类型支持:支持多种数据类型,包括 FP8、FP16、BF16、TF32 和 INT8,满足不同应用对精度和性能的需求。

  • 优化的内存访问:结合 TMA(Tensor Memory Accelerator),WGMMA 能够更高效地加载和存储数据,减少内存瓶颈。

Attention算法中,内循环(主循环)内的操作具有序列依赖性,这阻碍了单次迭代内的并行化,例如局部 softmax 依赖第一个 GEMM 的输出 SjiS^i_{j}

上述依赖问题可以通过多分配寄存器,来实现1D和2D的pipeline并行,需要用到WGMMA异步计算指令,在执行当前GEMM时,进行下一次softmax计算。

2.4 异步事务屏障

NVIDIA Hopper 的新功能是 waiting 线程能够在所有其他线程到达之前休眠。在以前的芯片上,等待线程会在共享内存中的屏障对象上旋转

尽管异步屏障仍然是NVIDIA 的漏斗编程模型的一部分,但它增加了一种新的屏障形式,称为异步交易屏障。异步事务屏障类似于异步屏障(图 17 ,右图)。它也是一个分割屏障,但它不仅计算线程到达,还计算事务。

NVIDIA Hopper 包含一个用于写入共享内存的新命令,该命令传递要写入的数据和事务计数。事务计数本质上是字节计数。异步事务屏障在Wait命令处阻塞线程,直到所有生产者线程执行Arrive,并且所有事务计数之和达到预期值。

异步事务屏障是用于异步 mem 拷贝或数据交换的强大新原语。如前所述,集群可以通过隐含的同步进行线程块到线程块的数据交换,集群的能力建立在异步事务壁垒之上

image.png

异步屏障最初引入NVIDIA 安培架构,考虑一组线程产生数据,它们在一个屏障之后都消耗的数据。

1.5 低精度FP8

H100 GPU 增加了 FP8 张量核,以加速人工智能训练和推理,FP8 的 2D 算力高达 4P,支持两种 FP8 的输入类型:

  • E4M3 具有 4 个指数位、 3 个尾数位和 1 个符号位的 ,支持需要更少动态范围和更高精度的计算
  • E5M2 具有 5 个指数位、 2 个尾数位和 1 个符号位, 提供更宽的动态范围和更低的精度

与 FP16 或 BF16 相比, FP8 将数据存储需求减半,吞吐量翻倍。

image.png

2.Flash Attention V3

FlashAttention V3 是针对 H100(Hoper架构) 硬件进行定制化设计的,因此对其他型号 GPU,并不具有普适性,这些优化对于国内的 AI 加速器厂商现阶段不具备借鉴意义。对 TPU DSA 拥有一部 DMA 的架构,Flash Attention实现天然就是 V3 甚至超越 V3的实现,这里我们基于 Flash Attention v3 的视角分析其三大方面的重大升级。

2.1 warp组间Producer-Consumer异步

生产者-消费者异步性:我们定义了一种基于异步执行数据搬运和TensorCore的warp专门化软件流水线方案,通过将数据的生产者和消费者分别划分到不同的warp中,以此扩展算法 隐藏内存和指令发射延迟的能力。下图当中两个 warp 组的 pingpong 调度以重叠 softmax 和 GEMMs:当另一个 warp 组的 GEMMs 运行时,一个 warp 组的 softmax 应该被调度。相同的颜色表示相同的迭代。 image.png

GPU 是吞吐量处理器,依赖并发性和异步性来隐藏内存和执行延迟。对于 GMEM 和 SMEM 之间的异步内存复制,Hopper 具有张量内存加速器(TMA)作为专用硬件单元。此外,与先前的架构如 Ampere 不同,Hopper 的 TensorCore,通过全局工作组范围的 WGMMA 指令暴露出来,也是异步的,并可以直接从共享内存获取输入,具体原理可以阅读第2部分。

2.2 warp组内softmax与GEMM的pipeline重叠

Flash Attention V3 将 softmax 中相对低吞吐量的非 GEMM 操作,如浮点数乘加和指数运算,与异步 WGMMA 指令进行重叠,理解可以参考 pipline, 同一 warp组内利用了WGMMA异步 api来实现GEMM和softmax的时间重叠,如下所示:

image.png

2.3 Hardware-accelerated low-precision GEMM

Flash Attention V3 将前向传递算法调整为针对 FP8 张量核心进行 GEMM 操作,TFLOPs/s 提高了两倍。

2.4 核心算法细节

算法 1:前向传递不包含内部消费者重叠: 这个算法主要通过分阶段加载和并行计算来优化注意力机制的前向计算。具体优化手段包括:

  • 共享存储器缓冲区的使用来减少内存瓶颈。
  • 分配寄存器资源以提高计算性能。
  • 使用逐步累积的方法来避免数值溢出问题。

这个流程主要适用于需要在 GPU 上进行高效注意力计算的深度学习模型,通过流水线和并行机制来减少计算延迟和内存带宽需求。

image.png 算法2:消费者 warp 组前向传递 其核心 warp 内的时间重叠,可以参考上面涉及的内容,这里不重复说明

image.png

2.5 v3与v2性能对比

Flash Attention V3 利用H100 新特性 针对 FlashAttention 的性能进一步提升, 为了更好的了解 Flash Attention V3,建议大家观看参考资料当中英伟达关于 H100 架构的介绍说明,下面是 V3 与 v2 版本的性能对比

image.png

3. 总结

LLM inference(或称为decoding)是一个迭代的过程(自回归,多次前向很多序列是重复的):预测的tokens是逐个生成的。如果生成的句子有N个单词,那么模型需要进行N次forward。一个常用的优化技巧是KV Cache,该方法缓存了之前forward的一些中间结果,节约了大部分运算(如MatMul),但是attention操作是个例外。随着输出tokens长度增加,attention操作的复杂度也急剧上升,Flash Attention V1 V2 分别针对标准 Attention 进行性能优化。

FlashAttention V3的贡献总结如下:

  1. 利用TensorCore和 TMA 的异步性,通过专门的 Warp来实现数据搬运和计算的并行
  2. 生产者-消费者异步,定义了一种基于异步执行数据搬运和TensorCore的 warp 级别软件流水线方案
  3. 引入 FP8 低精度量化(需要 H100 支持),利用硬件对 FP8 低精度的支持进行块量化和无相关处理

值得强调,FlashAttention是针对H100(Hopper架构)制定的特例优化方案,非Hopper架构并不完全适用。

Flash Attention V2 的整体分析请参考: 大模型加速-核心网络算子-Flash Attention V2

参考资料:

developer.nvidia.com/zh-cn/blog/…

zhuanlan.zhihu.com/p/688616037

arxiv.org/pdf/2407.08…