大模型加速-核心网络算子-Flash Attention V1

563 阅读3分钟

标准 Attention 当输入序列(sequence length)较长时,Transformer的计算过程缓慢且耗费内存,这是因为self-attention的 timememory 复杂度会随着sequence length的增加成二次增长

Flash Attention(Fast and Memory Efficient Exact Attention with IO-Awareness) 是针对标准 Attention 计算过程缓慢的问题进行了部分改善,从而提升整体的性能

综上,FlashAttention V1目的不是节约FLOPs,而是通过重复利用高速存储,降低对低速存储HBM的依赖而是减少对HBM的访问次数。重点是FlashAttention在训练和预测过程中的结果和标准Attention一样,对用户是无感的,而其他加速方法做不到这点. Flash Attention另外的一个创新点是基于逆向思维思考,解决了下图红色安全版本softmax局部 fusion的问题 企业微信截图_af047ace-d2c3-4b09-934e-d845039d9279.png 由于本文档是面向新手普及知识的下面我们会基于 以下几方面进行细粒度说明

  • 标准 Attention 及其空间时间复杂度
  • Flash Attention 优化技术方案, 简单说明其核心创新思想的来源和具体的优化方案(备注:关于算法流程中比较难理解的 分块利用局部最大值参与计算 需要针对 历史 逐行累计 Sum输出进行补偿 机制,后面会细粒度说明)

1.标准 Attention

注意力机制是 Transform 当中很重要的网络结构组成,下面我们基于标准 Attention的公式 和计算工程阐述其空间和时间复杂度

image.png

1.1 标准 Attention 的计算

标准注意力机制的核心公式如下

Attention(Q,K,V)=softmax(QKdk)V\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_k}} \right) V

其中 QRmdkQ \in R^{m*d_k} , KRmdkK \in R^{m*d_k} , VRmdvV \in R^{m*d_v} 如下所示:

image.png

对于标准Attention 其核心计算分为4 步

企业微信截图_1f6e3b79-39e8-48da-a085-ccad86c58ef4.png

Attention(Q,K,V)=softmax(QKdk)V\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_k}} \right) V

其核心链路涉及

  • 第一步计算 中间结果 S, 即 S=QKTS=QK^T 写回到 HBM
  • 第二部计算 中间结果 P, 即 P=softmax(S)P=softmax(S) 写回到 HBM
  • 第三步计算 最终输出结果O, 即 O=PVO=PV 写回到 HBM
  • 第四步 返回结算结果

标准Attention的中间结果S,P 通常需要通过高带宽内存(HBM)进行存取,两者所需内存空间复杂度为O(N2)O(N^2)

1.1.1 时间复杂度说明

时间复杂度为 O(N2)O(N^2) 说明,在自注意力机制中,模型需要计算输入序列中每个词与其他所有词的相关性。这意味着:

  • 每个词(位置 i) 需要与 所有词(位置 j) 计算注意力得分。
  • 输入为 Token 长度为 N 的话,总共需要计算 N×N=N2N×N=N^2 次,即时间复杂度为 O(N2)O(N^2) ,下面为一个输入序列长度为4 的例子。
输入序列长度 N = 4

[词1] ←→ [词2] ←→ [词3] ←→ [词4]
   ↑        ↑        ↑        ↑
   ↓        ↓        ↓        ↓
[词1] ←→ [词2] ←→ [词3] ←→ [词4]

(每个词与所有词相互计算注意力)

1.1.2 空间复杂度说明

空间复杂度为 O(N2)O(N^2)说明,例如长度为 N 的序列, 每一个 Token 生成的注意力矩阵输出为 N,整体 N 个Token 需要的注意力权重空间如下

  • 计算过程中生成的注意力得分矩阵尺寸为 N×NN×N,需要存储在内存中。

  • 当 N 较大时,矩阵尺寸NNN*N 会急剧增大,即空间复杂度为 O(N2)O(N^2)

2.Flash Attention 优化技术方案

Flash Attention V1 针对标准 Attention 进行了降低访问 HBM 次数的优化,计算逻辑没有修改,因此计算结果与标准的 Attention 保持一致,性能提升了很多

GPU中存储单元主要有HBM和SRAM: HBM容量大但是访问速度慢,SRAM容量小却有着较高的访问速度。例如:A100 GPU有40-80GB的HBM,带宽为1.5-2.0TB/s;每108个流式多核处理器各有192KB的片上SRAM,带宽估计约为19TB/s

image.png Flash Attention V1 核心通过在局部(片上 SRAM)计算块状分割的数据,有效减少了对全局内存(HBM)的访问需求。通过优化数据传输和利用片上高速缓存,算法降低了内存带宽的需求并提高了计算效率

2.1 核心算法介绍

新的注意力算法 FlashAttention,它能够在远少于内存访问次数的情况下计算精确的注意力。主要目标是避免读取和写入注意力矩阵到HBM。这需要解决以下两个挑战:

  • (i) 在无法访问整个输入的情况下计算 softmax 减少;
  • (ii) 不存储大型中间注意力矩阵以供反向传播使用。

我们应用了两个已建立的技术来解决这些挑战。

  • (i) 重新结构注意力计算,将输入分为块,并对输入块进行多次遍历,从而逐步执行 softmax 规约(也称为 titling)。
  • (ii) 在前向传播中存储 softmax 归一化因子,以便在反向传播中快速在芯片上重新计算注意力,这比标准方法从 HBM 读取中间注意力矩阵更快。

基于 CUDA 实现 FlashAttention,以实现对内存访问的精细控制,并将所有注意力操作融合到一个 GPU 内核中。即使由于重新计算导致的 FLOPs 增加,Flash Attention V1算法仍然运行得更快(高达 7 倍)。

核心算法如下:

image.png

V1 应用了两种已确立的技术(titling(分块),Recomputation(重计算))来克服在次二次 HBM 访问中计算精确注意力的技术挑战。我们通过算法 1 描述了这一点。主要思想是,我们将输入 𝐐,𝐊,𝐕 分为块,从慢速 HBM 加载到快速 SRAM,然后根据这些块计算注意力输出。通过在将每个块的输出与正确的归一化因子相乘之前将它们相加,最终得到正确的结果

Require: 表示输入序列的查询(Q)、键(K)、值(V)矩阵,N是序列长度,d是嵌入维度。HBM指的是高带宽内存,SRAM指的是片上静态随机存储器,注意这里的 Q,K,V 是由 WQW_Q,WKW_K,WVW_V 线性变化得到的,因此大小与输入序列 N 相关

    1. 设置 Set block sizes:
    • Bc = ⌈M / 4d⌉:设置列维度 块大小Bc, 行维度大小 BrBr = min(⌊M / 4d⌋, d) 表示K和V矩阵分块的大, 使其适应硬件内存 M

    • Initialize O: 初始化输出矩阵 O 和对应的中间变量 , m 这里依赖 nv 关于 softmax 并行化的一个探索,请参考附件

      • m 是用于防止数值溢出的中间变量,通常在计算 exp 函数时引入,以确保计算的数值稳定性。

      • 在自注意力计算中,m 保存了当前计算中每个块的最大值,用于在计算 exp(Sijmij) exp(S_{ij} - m_{ij}) 时减去最大值,防止出现数值过大的问题,从而避免指数运算导致的溢出。

      • 在算法中,m 的更新公式为:mnew=max(mi,m~ij)m_{new} = max(m_i, m̃_{ij}),这里 m~ijm̃_{ij}SijS_{ij}的行最大值

    1. 划分 Q,K,V 矩阵
    • 将 Q 划分为 Tr=N/BrT_r = ⌈N / B_r⌉ 个块,每块大小为 Br×dB_r × d
    • 将 K 和 V划分为 Tc=N/BcT_c = ⌈N / B_c⌉ 个块,每块大小为 Bc×dB_c × d
    1. 划分O矩阵
    • 将输出矩阵 O 划分为 TrT_r 个块,每块大小为 Br×dB_r × d
    • 同样地,将中间变量 m 划分为 TrT_r 个块。
  • 4.外层循环(第5行到第14行):

    • 遍历 KV 的块 (TcT_c 个块)。
    • 从高带宽内存(HBM)中加载当前的 KjK_jVjV_j 到片上存储器(SRAM)
  • 5.内层循环(第7行到第13行):

    • 遍历 Q 的块 (TrT_r 个块)。
    • HBM 中加载当前的 QiQ_i, OiO_i, iℓ_i, mim_iSRAM
  • 6.计算注意力得分(第9行):

    • 在片上计算 Sij=QiKjTS_{ij} = Q_i K_j^T,生成大小为 Br×BcB_r × B_c 的得分矩阵。
    • 计算 Pij=exp(Sijmij)P_ij = exp(S_ij - m_ij) 以防止数值溢出,并对每行进行归一化。
  • 7.更新中间变量(第10行到第11行):

    • 计算每行的最大值 m~ijm̃_ij 以及行和 iℓ_i
    • 更新 OiO_iiℓ_i,将它们写回到 HBM。这里每次只更新具体值 注意: 该步骤关于分块局部求和L的补偿 和 输出 O 的归一化操作,后面会重点说明
  • 8.输出结果

    • 返回最终计算得到的输出矩阵 O

示例说明: 假设有一个输入序列,长度 N = 8,嵌入维度 d = 4,硬件内存大小 M 使得 B_c = 2B_r = 2

企业微信截图_c44e313f-f789-4d87-8978-a6e3dcbcc8f7.png 图一:逐行分块计算注意力

  • 划分Q、K、V

    • Q 被划分为 T_r = ⌈8 / 2⌉ = 4 块,每块大小 2 × 4
    • KV 被划分为 T_c = ⌈8 / 2⌉ = 4 块,每块大小 2 × 4
  • 外层循环:总共有 T_c = 4 个块,表示需要遍历 KV 的 4 个块。

    • 第一次迭代加载 K_1V_1SRAM
  • 内层循环:总共有 T_r = 4 个块,表示需要遍历 Q 的 4 个块。

    • 第一次迭代加载 Q_1, O_1, L_1, m_1SRAM
    • 在片上计算 S_11 = Q_1 K_1^T,得到一个 2 × 2 的矩阵。( 备注:内层循环 4 次得到 2*8)
    • 计算 P11=exp(S11m11)P_{11} = exp(S_{11} - m_{11}) 以防止数值溢出,并对每行进行归一化
    • 然后,进行归一化、更新 O_1L_1
  • 更新

    • 重复上述步骤直到遍历完所有的 KV 块,逐步完成 O 的计算。

2.2 分块计算max最大值补偿

分块计算,导致局部不能或得每行的最大值 mim_i ,因此在算法设计上需要针对局部能够或得 整行的最大值进行计算补偿,主要分为两方面

    1. 历史局部逐行累计和 L 的补偿
    1. 历史OiO_i 的补偿

2.2.1 历史局部逐行累计和 L 的补偿

首先我们看softmax

softmax(x)=exjmij=0Nexjmisoftmax(x) = \frac{e^{x_j-m_i}}{\sum_{j=0}^N {e^{x_j - m_i}}}

L 本质是逐行分块sum, 的具体公式如下

Linew=emijminewLi +em~ijminewL~ijL^{new}_{i}=e^{m_{ij}-m_i^{new}} *L_i  + e^{\tilde{m}_{ij} - m_i^{new}} * \tilde{L}_{ij}

i=0 j=0m~00=m0new\tilde{m}_{00} = m_0^{new}, 历史L0=0L_0=0上述公式可以简化为(备注:为了不歧义,求和的 j 替换成 r)

L0new=em~00m0newL~00=L00~=r=0N0exrmi L^{new}_{0}= e^{\tilde{m}_{00} - m_0^{new}} * \tilde{L}_{00} = \tilde{L_{00}} = \sum_{r=0}^{N_0} {e^{x_r - m_i}}

i=0 j=1 时, L0new=L0+L01 L^{new}_{0} = L_0+L_{01} 为了让历史的 L0L_0 (即上一次的L0newL_0^{new})计算准确,我们需要针对 上一次计算进行数学处理乘以特定系数 W,以保证最新的 L0newL_0^{new} 计算是无误的

L0new=WL0+L01=r=0N0exrm00 L^{new}_{0} = W*L_0+L_{01} = \sum_{r=0}^{N_0} {e^{x_r - m_{00}}} 这里的m00m_{00}即历史的最大 m0m_0, 实时最新的为 m0newm_0^{new}

基于指数运算的本质 exrm00=exrm0e^{x_r - m_{00}}=e^{x_r - m_{0}} 转换成 exrm0new{e^{x_r - m_{0}^{new}}},因此只需要 e(xrm0)+(m0m0new)=em0m0newexrm0e^{(x_r - m_{0})+(m_{0}-m_{0}^{new})} = e^{m_{0}-m_{0}^{new}} * e^{x_r - m_{0}}

即 系数为 W=em0m0new W = e^{m_{0}-m_{0}^{new}} 依此类推,因此,每次只要每行新块计算时乘以 该系数,则不会影响最新块的逐行sum结果,即

Linew=emijminewLi +em~ijminewL~ijL^{new}_{i}=e^{m_{ij}-m_i^{new}} *L_i  + e^{\tilde{m}_{ij} - m_i^{new}} * \tilde{L}_{ij}

O_i 的每次合并也是同样需要针对局部最大值补偿

2.2.2 历史OiO_i 的补偿

我先看 OiO_i的计算公式

Oi=diag(Linew)1 (diag(Li)emiminewOi+ e m~ijminew P~ijVj  )O_{i} = diag(L_i ^{new})^{-1}  * (diag(L_i)*e^{m_{i} - m_i^{new}}*O_{i}+   e^{ \tilde{ m}_{ij} - m_i^{new}} *  \tilde{P}_{ij}*V_j   )

为了易于理解我们先介绍一下diag(Linew)1diag(L_i ^{new})^{-1}的功能作用,以方便后续去除简化公式 diag(Linew)1diag(L_i ^{new})^{-1} 这里表示对角线全部是 1Linew\frac{1}{L_i ^{new}} 的矩阵,核心功能是对右侧的矩阵进行缩放

[1Linew001Linew]\begin{bmatrix} &\frac{1}{L_i ^{new}} &0 \\ &0 &\frac{1}{L_i ^{new}} \\ \end{bmatrix}

下面我们通过一个例子理解基于对角矩阵的缩放

[N00M][abcd]=[NaNbMcMd]\begin{bmatrix} &N &0 \\ &0 &M \\ \end{bmatrix} * \begin{bmatrix} &a &b \\ &c &d \\ \end{bmatrix} = \begin{bmatrix} & Na &Nb \\ &Mc &Md \\ \end{bmatrix}

所以 diag(Linew)1diag(L_i ^{new})^{-1} 核心 为对结果进行归一化,即结果浮点数控制在 1 以内,这里为了方便理解我们移除掉,原有公式简化为:

Oi=diag(Li)emiminewOi+ e m~ijminew P~ijVj O_{i} = diag(L_i)*e^{m_{i} - m_i^{new}}*O_{i}+   e^{ \tilde{ m}_{ij} - m_i^{new}} *  \tilde{P}_{ij}*V_j 

参考上文当中的图一:逐行分块计算注意力 我们可以看到 分块 计算的OiO_i 都是 softmax(P_ij)*Vj 局部求和的一部分,下面是示意图(仅仅示意非全部操作)

企业微信截图_d5b0a510-5257-4e02-a77f-b8904a71a1fa.png

i=0,j=0 时,历史m_0值为 m0=m0new=m00m_{0} = m_0^{new} = m_{00}O0=0O_0=0上文公式可以简化为

O0= e m~00m0new P~00V0= P~00V0=eS00m~00V0O_{0} =   e^{ \tilde{ m}_{00} - m_0^{new}} *  \tilde{P}_{00}*V_0 = \tilde{P}_{00}*V_0 = e^{S_{00}-\tilde{m}_{00}} * V_{0}

对于 i=0,j=1 的时候

O0=diag(L0)em0m0newO0+ e m~01m0new P~01V0 =B0补归一+B1O_{0} = diag(L_0)*e^{m_{0} - m_0^{new}}*O_{0}+   e^{ \tilde{ m}_{01} - m_0^{new}} *  \tilde{P}_{01}*V_0  =B_{0补归一}+B_1

即 由上一个块计算完的 B0B_{0} 补偿 和归一化 后 和 最新块计算的结果求和得到最新的 O0O_0

对于前半部分 diag(L0) em~00m0newO0 diag(L_0) *  e^{\tilde{ m}_{00}-m_0^{new'}}* O_0 , 带入之前的O0=eS00m~00V0O_{0} =e^{S_{00}-\tilde{m}_{00}} * V_{0} 等价于

diag(L0)em0m0neweS00m~00V0=diag(L0)eS00m0newdiag(L_0)*e^{m_{0} - m_0^{new}}* e^{S_{00}-\tilde{m}_{00}} * V_{0} =diag(L_0)*e^{S_{00}-m_0^{new}}

即对上一次的 P00P_{00} 进行补偿 局部 max 值计算导致的误差问题, 前半部分 统一乘以 diag(L0)diag(L_0) 是为了还原之前的计算,最后重新基于 (diag(L0)new)1(diag(L_0)^new)^{-1} 重新归一化

2.2.3 补偿的一些小问题探讨

上述所有难理解的公式全部是由于补偿局部 max 和 重新归一化 norm 进行的数学操作,对于下图当中的操作由于外层循环为 KiK_i 所以每次的 每行最新的 m 需要 L 都需要读写 HBM (依据算法实现流程) 企业微信截图_c44e313f-f789-4d87-8978-a6e3dcbcc8f7.png

如果合理分配任务且外层循环 Q,内层 key 则可以避免频繁的 m 和 L 的存储读取操作, 对于缩放操作由于Q 修改成外循环,则上述的缩放操作可以最后进行,节省 1D 算力,以上这些都是可以继续优化的

3.总结

FlashAttention 从 128 到 2K 的常见序列长度范围内,与标准注意力实现相比,快至 30 倍。并且可以扩展到 64K。在序列长度为 512 的情况下,FlashAttention 不仅在速度上更快,而且在内存效率上也优于任何现有的注意力方法。而对于超过 1K 的序列长度,一些近似注意力方法(例如 Linformer)开始变得更快。另一方面,块稀疏 FlashAttention 比我们所知的所有现有近似注意力方法都要快

综上整体文档归纳总结,FlashAttention V1目的不是节约FLOPs,而是减少对HBM的访问,将复杂度由 n2 n^2 降为了n2n^2, 重点是FlashAttention在训练和预测过程中的结果和标准Attention一样,对用户是无感的,而其他加速方法做不到这点

参考文献: