1、原始Seif-Attention有什么问题?
Self-Attention的时间和空间复杂度都是二次方 -->序列过长,算法速度会变慢,消耗很高的内存
2、Flash-Attention
动机:加速注意力计算并减少内存开销
原理:核心思想是通过精细的I/O感知设计,特别是切块(Tiling)技术,来显著减少GPU HBM(高带宽内存接口,容量大速度慢)与高速SRAM(静态随机访问存储器,容量小速度快)之间的数据传输,其目标是将计算密集的部分尽可能限制在SRAM内完成,避免HBM成为瓶颈。将矩阵乘法、缩放、softmax 等操作合并成一个 GPU 内核,减少开销。
切块与SRAM内计算原理(重计算):由于完整的QKVO及中间的N×N矩阵无法全部放入SRAM,FlashAttention将输入QKV矩阵沿序列长度分割为更小的“块”,但会导致softmax操作无法一次性获取整行数据。采用“在线安全softmax”,在处理每个块时,会迭代更新已处理部分的最大值m和softmax分母的累积项l。当计算一个新的块时,会得到该块的局部最大值并与之前迭代的m比较,更新全局的m。同时,由于m可能发生变化,之前计算的累积分母l需要乘以一个补偿因子,再加上当前块基于新m计算出的指数和。通过这种方式,即使是分块处理,最终得到的softmax结果与用完整行应用安全softmax完全一致,且避免了存储整个分数矩阵来进行softmax。
3、KV-cache原理
自回归的特点,模型在每个时刻的输入都会拼接之前时间步的输出。Self-Attention主要计算在qkv部分,由于每一个时间步的输入都对之前的输出做了拼接,产生了大量的重复计算。KV-cache使用空间换时间的方式,每个时间步计算时将后面还需要用到的、计算结果缓存起来,减少重复计算。
4、推理的两阶段
预填充:发生在计算输出第一个token过程中,此时Cache为空。
生成:后续token的生成过程。
如何更新KV-cache:
计算next token:
5、为什么没有Q-cache?
每个token生成都只依赖前一个q和之前所有的KV。
6、KV-cache显存占用计算
batch_size序列长度头数多头特征维度layers22(float16每个占2字节)
7、MQA、GQA
MQA(图右):所有Q都共用一个KV,极大地减缓了需要存储的KV数量,但是效果损失严重。
GQA(图中):将Q进行分组,同组共享一个KV,进行了折中,典型代表是QWen。
8、为什么降低KV Cache的大小如此重要?
众所周知,一般情况下LLM的推理都是在GPU上进行,单张GPU的显存是有限的,一部分我们要用来存放模型的参数和前向计算的激活值,这部分依赖于模型的体量,选定模型后它就是个常数;另外一部分我们要用来存放模型的KV Cache,这部分不仅依赖于模型的体量,还依赖于模型的输入长度,也就是在推理过程中是动态增长的,当Context长度足够长时,它的大小就会占主导地位,可能超出一张卡甚至一台机(8张卡)的总显存量。
在GPU上部署模型的原则是:能一张卡部署的,就不要跨多张卡;能一台机部署的,就不要跨多台机。这是因为“卡内通信带宽 > 卡间通信带宽 > 机间通信带宽”,由于“木桶效应”,模型部署时跨的设备越多,受设备间通信带宽的的“拖累”就越大,事实上即便是单卡H100内SRAM与HBM的带宽已经达到了3TB/s,但对于Short Context来说这个速度依然还是推理的瓶颈,更不用说更慢的卡间、机间通信了。
所以,减少KV Cache的目的就是要实现在更少的设备上推理更长的Context,或者在相同的Context长度下让推理的batch size更大,从而实现更快的推理速度或者更大的吞吐总量。当然,最终目的都是为了实现更低的推理成本。