Flash Attention
一、产生背景
出自论文《Fast and Memory-Efficient Exact Attention with IO-Awareness》
核心是优化传统计算attention时显存内部的IO次数,从而增加整体计算速度。
二、传统计算过程
计算Attention的时候,有多次io操作,以下是计算流程。
- 开始Q K V 是在HBM中存储的
- 计算时候先把Q和K加载到SRAM中计算出S=QK^T
- 再把S写入HBM中,边计算边写出(因为S存不下,S的形状是[N,H,T,T],显存占用和token的个数的平方有关)
- 把S加载到SRAM,
- 计算P=softmax(S),边加载S边计算P,因为softmax是每T个值进行计算的。可以放的下
- 将P写出到HBM,边计算边写出
- 从HBM加载P和V到SRAM
- 计算O=PV
- 把O写出到HBM,边计算边写出
- 返回O。
三、Flash Attention计算过程
-
以下计算Attention是忽略softmax后的,也就是的计算。
- 开始Q K V 是在HBM中存储的,把Q K V进行理论上的分块,每次计算时,只加载对应的块。
- 计算时候先把一部分Q K V加载到SRAM中计算S=QK^T,此时不需要把S写出到HBM中,因为这里装的下。
- 然后计算SV,注意这里S只是Q和部分K的相关性, 直接计算SV得到的结果是,这部分V的加权合并。然后把结果写出到HBM
- 此时加载下一部分Q到SRAM中,和上部分的K计算,然后和上部分的V计算,把计算过写入到HBM中。不断循环Q而K V不动即可完成第一轮的计算
- 然后Q开始加载第一部分,K和V加载第二部分,进行计算得到结果,此时把HBM中结果的第一部分加载进行相加,然后写出HBM即可
- 此时Q进行加载下一部分,K V不动。 不断训练即可。
- 重复以上过程。核心就是两层for循环,一层是k和v的,下边一层是q的。
- 以上是不加入softmax的情况,如果有softmax计算,则上述不成立,因为softmax是对上述的S进行计算的,而上述的S每一行是不完整的也就不能顺利的进行计算。
-
softmax优化为safe-softmax
-
正常的softmax是,会有个数值溢出问题,fp16最大值是65536,为162754。
-
改进后的softmax是分子和分母都除以,此最后结果不变,公式化简为
-
而m的值取所有x中的最大值,此时每个减去他的最大值后,取值范围是小于等于0的,即的最大值为1。此时数值溢出问题就解决了,而同时公式再次进行推导后,可以解决flash attention中出现的问题
-
在flash attention中,每一小块的S,也就是q和k计算后的结果,是不完整的,此时我们可以先算每个部分的safe-softmax公式中的分母,然后保存起来,同时把各个小块的最大值也保存起来。
-
最后每个小块的S计算完毕后,进行再次处理,就可以计算完整的safe-softmax了。
-
-
safe-softmax分块拆解公式推导
核心是对每一个小块S计算一个值进行保留。等各个小块S算完后,再进行合并即可。
-
把x分为两块,
-
分别计算和的最大值为和
-
分别计算和对应safe-softmax的分母部分和
-
以上、、、在计算完毕后保存起来。每一行的x只需要保存4个值即可,极大减少了参数量
-
开始进行各块的safe-softmax值。
- 计算最终safe-softmax的分母,目的是乘以各系数,还原成原始的safe-softmax,也就是每个x减去全局的最大值。此时
- 分别计算最终safe-softmax的分子
- 此时就可以算出来各个部分的最终safe-softmax值了。
-
-
具体计算过程如下