标准 Attention 当输入序列(sequence length)较长时,Transformer的计算过程缓慢且耗费内存,这是因为self-attention的 time 和 memory 复杂度会随着sequence length的增加成二次增长
Flash Attention(Fast and Memory Efficient Exact Attention with IO-Awareness) 是针对标准 Attention 计算过程缓慢的问题进行了部分改善,从而提升整体的性能
综上,FlashAttention V1目的不是节约FLOPs,而是通过重复利用高速存储,降低对低速存储HBM的依赖而是减少对HBM的访问次数。重点是FlashAttention在训练和预测过程中的结果和标准Attention一样,对用户是无感的,而其他加速方法做不到这点. Flash Attention另外的一个创新点是基于逆向思维思考,解决了下图红色安全版本softmax局部 fusion的问题
由于本文档是面向新手普及知识的下面我们会基于 以下几方面进行细粒度说明
- 标准 Attention 及其空间时间复杂度
- Flash Attention 优化技术方案, 简单说明其核心创新思想的来源和具体的优化方案(备注:关于算法流程中比较难理解的 分块利用局部最大值参与计算 需要针对 历史
逐行累计 Sum和输出进行补偿 机制,后面会细粒度说明)
1.标准 Attention
注意力机制是 Transform 当中很重要的网络结构组成,下面我们基于标准 Attention的公式 和计算工程阐述其空间和时间复杂度
1.1 标准 Attention 的计算
标准注意力机制的核心公式如下
其中 , , 如下所示:
对于标准Attention 其核心计算分为4 步
其核心链路涉及
- 第一步计算 中间结果 S, 即 写回到 HBM
- 第二部计算 中间结果 P, 即 写回到 HBM
- 第三步计算 最终输出结果O, 即 写回到 HBM
- 第四步 返回结算结果
标准Attention的中间结果S,P 通常需要通过高带宽内存(HBM)进行存取,两者所需内存空间复杂度为。
1.1.1 时间复杂度说明
时间复杂度为 说明,在自注意力机制中,模型需要计算输入序列中每个词与其他所有词的相关性。这意味着:
- 每个词(位置 i) 需要与 所有词(位置 j) 计算注意力得分。
- 输入为 Token 长度为 N 的话,总共需要计算 次,即时间复杂度为 ,下面为一个输入序列长度为4 的例子。
输入序列长度 N = 4
[词1] ←→ [词2] ←→ [词3] ←→ [词4]
↑ ↑ ↑ ↑
↓ ↓ ↓ ↓
[词1] ←→ [词2] ←→ [词3] ←→ [词4]
(每个词与所有词相互计算注意力)
1.1.2 空间复杂度说明
空间复杂度为 说明,例如长度为 N 的序列, 每一个 Token 生成的注意力矩阵输出为 N,整体 N 个Token 需要的注意力权重空间如下
-
计算过程中生成的注意力得分矩阵尺寸为 ,需要存储在内存中。
-
当 N 较大时,矩阵尺寸 会急剧增大,即空间复杂度为
2.Flash Attention 优化技术方案
Flash Attention V1 针对标准 Attention 进行了降低访问 HBM 次数的优化,计算逻辑没有修改,因此计算结果与标准的 Attention 保持一致,性能提升了很多
GPU中存储单元主要有HBM和SRAM: HBM容量大但是访问速度慢,SRAM容量小却有着较高的访问速度。例如:A100 GPU有40-80GB的HBM,带宽为1.5-2.0TB/s;每108个流式多核处理器各有192KB的片上SRAM,带宽估计约为19TB/s
Flash Attention V1 核心通过在局部(片上 SRAM)计算块状分割的数据,有效减少了对全局内存(HBM)的访问需求。通过优化数据传输和利用片上高速缓存,算法降低了内存带宽的需求并提高了计算效率
2.1 核心算法介绍
新的注意力算法 FlashAttention,它能够在远少于内存访问次数的情况下计算精确的注意力。主要目标是避免读取和写入注意力矩阵到HBM。这需要解决以下两个挑战:
- (i) 在无法访问整个输入的情况下计算 softmax 减少;
- (ii) 不存储大型中间注意力矩阵以供反向传播使用。
我们应用了两个已建立的技术来解决这些挑战。
- (i) 重新结构注意力计算,将输入分为块,并对输入块进行多次遍历,从而逐步执行 softmax 规约(也称为 titling)。
- (ii) 在前向传播中存储 softmax 归一化因子,以便在反向传播中快速在芯片上重新计算注意力,这比标准方法从 HBM 读取中间注意力矩阵更快。
基于 CUDA 实现 FlashAttention,以实现对内存访问的精细控制,并将所有注意力操作融合到一个 GPU 内核中。即使由于重新计算导致的 FLOPs 增加,Flash Attention V1算法仍然运行得更快(高达 7 倍)。
核心算法如下:
V1 应用了两种已确立的技术(titling(分块),Recomputation(重计算))来克服在次二次 HBM 访问中计算精确注意力的技术挑战。我们通过算法 1 描述了这一点。主要思想是,我们将输入 𝐐,𝐊,𝐕 分为块,从慢速 HBM 加载到快速 SRAM,然后根据这些块计算注意力输出。通过在将每个块的输出与正确的归一化因子相乘之前将它们相加,最终得到正确的结果
Require: 表示输入序列的查询(Q)、键(K)、值(V)矩阵,N是序列长度,d是嵌入维度。HBM指的是高带宽内存,SRAM指的是片上静态随机存储器,注意这里的 Q,K,V 是由 ,, 线性变化得到的,因此大小与输入序列 N 相关
-
- 设置 Set block sizes:
-
Bc = ⌈M / 4d⌉:设置列维度 块大小Bc, 行维度大小Br,Br = min(⌊M / 4d⌋, d)表示K和V矩阵分块的大, 使其适应硬件内存M。 -
Initialize O: 初始化输出矩阵
O和对应的中间变量ℓ,m这里依赖 nv 关于 softmax 并行化的一个探索,请参考附件-
m是用于防止数值溢出的中间变量,通常在计算exp函数时引入,以确保计算的数值稳定性。 -
在自注意力计算中,
m保存了当前计算中每个块的最大值,用于在计算 时减去最大值,防止出现数值过大的问题,从而避免指数运算导致的溢出。 -
在算法中,
m的更新公式为:,这里 是 的行最大值
-
-
- 划分 Q,K,V 矩阵
- 将 Q 划分为 个块,每块大小为
- 将 K 和 V划分为 个块,每块大小为
-
- 划分O矩阵:
- 将输出矩阵
O划分为 个块,每块大小为 。 - 同样地,将中间变量
ℓ和m划分为 个块。
-
4.外层循环(第5行到第14行):
- 遍历
K和V的块 ( 个块)。 - 从高带宽内存(HBM)中加载当前的 和 到片上存储器(SRAM)
- 遍历
-
5.内层循环(第7行到第13行):
- 遍历
Q的块 ( 个块)。 - 从
HBM中加载当前的 , , , 到SRAM。
- 遍历
-
6.计算注意力得分(第9行):
- 在片上计算 ,生成大小为 的得分矩阵。
- 计算 以防止数值溢出,并对每行进行归一化。
-
7.更新中间变量(第10行到第11行):
- 计算每行的最大值 以及行和 。
- 更新 和 ,将它们写回到
HBM。这里每次只更新具体值注意: 该步骤关于分块局部求和L的补偿 和 输出 O 的归一化操作,后面会重点说明
-
8.输出结果:
- 返回最终计算得到的输出矩阵
O。
- 返回最终计算得到的输出矩阵
示例说明:
假设有一个输入序列,长度 N = 8,嵌入维度 d = 4,硬件内存大小 M 使得 B_c = 2, B_r = 2。
图一:逐行分块计算注意力
-
划分Q、K、V:
Q被划分为T_r = ⌈8 / 2⌉ = 4块,每块大小2 × 4。K和V被划分为T_c = ⌈8 / 2⌉ = 4块,每块大小2 × 4。
-
外层循环:总共有
T_c = 4个块,表示需要遍历K和V的 4 个块。- 第一次迭代加载
K_1和V_1到SRAM。
- 第一次迭代加载
-
内层循环:总共有
T_r = 4个块,表示需要遍历Q的 4 个块。- 第一次迭代加载
Q_1,O_1,L_1,m_1到SRAM。 - 在片上计算
S_11 = Q_1 K_1^T,得到一个2 × 2的矩阵。( 备注:内层循环 4 次得到 2*8) - 计算 以防止数值溢出,并对每行进行归一化
- 然后,进行归一化、更新
O_1和L_1。
- 第一次迭代加载
-
更新:
- 重复上述步骤直到遍历完所有的
K和V块,逐步完成O的计算。
- 重复上述步骤直到遍历完所有的
2.2 分块计算max最大值补偿
分块计算,导致局部不能或得每行的最大值 ,因此在算法设计上需要针对局部能够或得 整行的最大值进行计算补偿,主要分为两方面
-
- 历史局部逐行累计和 L 的补偿
-
- 历史 的补偿
2.2.1 历史局部逐行累计和 L 的补偿
首先我们看softmax
L 本质是逐行分块sum, 的具体公式如下
i=0 j=0 时 , 历史上述公式可以简化为(备注:为了不歧义,求和的 j 替换成 r)
i=0 j=1 时, 为了让历史的 (即上一次的)计算准确,我们需要针对
上一次计算进行数学处理乘以特定系数 W,以保证最新的 计算是无误的
这里的即历史的最大 , 实时最新的为
基于指数运算的本质 转换成 ,因此只需要
即 系数为 依此类推,因此,每次只要每行新块计算时乘以 该系数,则不会影响最新块的逐行sum结果,即
O_i 的每次合并也是同样需要针对局部最大值补偿
2.2.2 历史 的补偿
我先看 的计算公式
为了易于理解我们先介绍一下的功能作用,以方便后续去除简化公式 这里表示对角线全部是 的矩阵,核心功能是对右侧的矩阵进行缩放
下面我们通过一个例子理解基于对角矩阵的缩放
所以 核心 为对结果进行归一化,即结果浮点数控制在 1 以内,这里为了方便理解我们移除掉,原有公式简化为:
参考上文当中的图一:逐行分块计算注意力 我们可以看到 分块 计算的 都是 softmax(P_ij)*Vj 局部求和的一部分,下面是示意图(仅仅示意非全部操作)
i=0,j=0 时,历史m_0值为 和 上文公式可以简化为
对于 i=0,j=1 的时候
即 由上一个块计算完的 补偿 和归一化 后 和 最新块计算的结果求和得到最新的
对于前半部分 , 带入之前的 等价于
即对上一次的 进行补偿 局部 max 值计算导致的误差问题, 前半部分 统一乘以 是为了还原之前的计算,最后重新基于 重新归一化
2.2.3 补偿的一些小问题探讨
上述所有难理解的公式全部是由于补偿局部 max 和 重新归一化 norm 进行的数学操作,对于下图当中的操作由于外层循环为 所以每次的 每行最新的 m 需要 L 都需要读写 HBM (依据算法实现流程)
如果合理分配任务且外层循环 Q,内层 key 则可以避免频繁的 m 和 L 的存储读取操作, 对于缩放操作由于Q 修改成外循环,则上述的缩放操作可以最后进行,节省 1D 算力,以上这些都是可以继续优化的
3.总结
FlashAttention 从 128 到 2K 的常见序列长度范围内,与标准注意力实现相比,快至 30 倍。并且可以扩展到 64K。在序列长度为 512 的情况下,FlashAttention 不仅在速度上更快,而且在内存效率上也优于任何现有的注意力方法。而对于超过 1K 的序列长度,一些近似注意力方法(例如 Linformer)开始变得更快。另一方面,块稀疏 FlashAttention 比我们所知的所有现有近似注意力方法都要快
综上整体文档归纳总结,FlashAttention V1目的不是节约FLOPs,而是减少对HBM的访问,将复杂度由 降为了次, 重点是FlashAttention在训练和预测过程中的结果和标准Attention一样,对用户是无感的,而其他加速方法做不到这点
参考文献: