如果你最近刷到过“FlashAttention”,那你一定见过那句经典介绍:“它让传统 O(N²) 的 Attention,显存占用变成 O(N)。”
很多人平时也都用FlashAttention,但是很少有人能够讲清楚其中的原理。
今天我们就拆开讲清楚:
- 为什么普通 Attention 显存爆炸;
- FlashAttention 究竟改了什么;
- 为什么它能在保持 O(N²) 计算量的同时,让显存线性化。
一、普通 Attention 的计算与内存瓶颈
标准的自注意力(Self-Attention)计算如下:
假设输入序列长度为 N,特征维度为 d。
那么计算步骤:
- 计算相似度矩阵
S = QKᵀ → [N, N]
- 归一化
A = softmax(S)
- 加权求和
O = A * V
显存问题出在哪?
关键在于那一步 S = QKᵀ。
它是一个 N×N 的矩阵,会直接占据 O(N²) 的显存。
举个例子:
假设 N=4096,单精度浮点数 4 字节:
4096² × 4B ≈ 64 MB
而在多头 attention、batch 堆叠后,这个数会直接上百 MB。
再加上中间 softmax 的缓存与梯度,整个过程几乎炸显存。
二、FlashAttention 的核心思想
FlashAttention 的核心不是改公式,而是改计算顺序。论文题目里那句关键话非常准确:“An IO-aware exact attention algorithm.”
也就是说:
- 数学上结果一模一样;
- 但计算顺序被重排,
- 以最小化显存访问和缓存中间矩阵为目标。
普通实现流程:
QKᵀ → Softmax → Dropout → (Softmax * V)
问题是:
- 每一步都需要完整的 [N, N] 矩阵;
- 每层都要读写显存(global memory);
- Softmax 的数值稳定性还要额外缓存
max与sum。
这些中间值不是算力瓶颈,而是IO 瓶颈。
GPU 大部分时间都在“搬运数据”,而不是“算”。
三、FlashAttention 的关键优化
FlashAttention 的思路非常巧妙:把 Attention 计算拆成小块(tiles),每次只在显存中保留局部块,并在块级别完成 softmax 的归一化与累加。
分块计算 QKᵀ
把 Q 和 K 按块划分:
Q = [Q₁, Q₂, ..., Q_M]
K = [K₁, K₂, ..., K_M]
对于每个 query 块 Qᵢ:
- 依次读取每个 key 块 Kⱼ;
- 计算局部相似度矩阵 Sᵢⱼ = QᵢKⱼᵀ;
- 同时在寄存器中保留该块的最大值与和。
这样只需要存储一个 tile 的中间矩阵(比如 64×64),不会生成完整的 [N, N] 矩阵。
块内 Softmax 的数值稳定处理
为了保持数值精度,FlashAttention 在块内维护:
- 当前最大值
mᵢ; - 累积和
lᵢ。
公式如下:
这样,在不保存全局 S 的情况下,也能正确计算 softmax 归一化。
同步加权求和
每计算完一个块:
所有块处理完之后,就得到了完整的输出 Oᵢ。
整个过程是 流式的(streaming):
- 一边计算,一边归一化;
- 中间结果立刻被消费;
- 不需要缓存完整 attention 矩阵。
四、显存线性化的本质
普通 Attention:
- 必须保存 O(N²) 的相似度矩阵;
- 所以显存复杂度是 O(N²)。
FlashAttention:
- 只保存 O(N) 的输入输出(Q, K, V, O);
- 中间矩阵被分块并立即释放;
- 显存复杂度降为 O(N)。
计算量仍然是 O(N²),但显存访问和缓存规模线性化了。
简而言之,FlashAttention 不是降低计算复杂度,而是降低内存访问复杂度。
五、梯度计算也能高效吗?
梯度计算中,FlashAttention 也优化了反向传播。
它同样采用流式重计算(recompute):
- 前向不保存完整中间激活;
- 反向时重新计算需要的局部块;
- 减少显存峰值,但增加少量算力消耗。
这种设计非常适合训练大模型,因为 GPU 的主要瓶颈往往是显存,而不是算力。
FlashAttention v2采用了更高并行度 + kernel 调度来提升吞吐率,v3支持FP8、序列并行、多 query 批融合,进一步提速并适配大模型推理。如果想详细了解FlashAttentionV2 V3的详细算法和思想,文章末尾有专门分析它们的文章。
FlashAttention的精妙之处不在数学,而在工程调度。
它通过分块(tiling)计算、流式(streaming)softmax和kernel 融合(fusion),让原本需要 O(N²) 显存的注意力计算,在保持 O(N²) 计算量的同时实现了 显存 O(N) 的线性化。
📚推荐阅读
FlashAttention2:更快的注意力机制,更好的并行效率