DeepSeek 开源周 Day01 FlashMLA 深度解析
DeepSeek的开源周可谓是给AI界点燃了一根“加速火箭”,一开场便引爆了全球开发者的热情!在北京时间周一上午九点,DeepSeek豪气冲天地推出了本次开源周的“压轴好戏”——FlashMLA,这个项目一出,GitHub上的Star数就一路飙升,早已突破5000大关
在大语言模型推理中,序列解码的速度决定了你的AI是不是“急先锋”。DeepSeek早在V2版本中就引入了MLA,这一注意力机制的变种利用低秩KV压缩技术,有效缓解了KV Cache内存吃饱了撑的尴尬;到了V3版本,则又搭配上了DeepSeekMoE,简直是双剑合璧,既提速又省钱
FlashMLA借鉴了Flash Attention 2/3和英伟达cutlass库的绝妙设计思路,专门针对H800 SXM5等Hopper架构的GPU做了“特别定制”。在内存带宽方面,它能达到3000 GB/s,几乎贴近H800 SXM5理论极限的3350 GB/s;计算性能则高达580 TFLOPS,冲击了理论峰值的87%!这就像是给GPU打了一针强心剂,接下来我们由浅入深的分析一下 FlashMLA
什么是MLA
DeepSeek-V3 采用了多头潜变量注意力(Multi-Head Latent Attention, MLA)架构。 MLA 机制的核心思想是对注意力的键(Key)和值(Value)进行低秩联合压缩(low-rank joint compression) ,以减少推理时的键-值(KV)缓存,从而降低计算成本,同时保持接近标准多头注意力(Multi-Head Attention, MHA)的性能,所以这里的核心是低秩联合压缩,由于 FlashMLA 借鉴了 Flash Attention V2 和 V3 版本的优化技术,所以下面我们简单介绍一下历史背景,毕竟吃水不忘挖井人。
MHA (多头注意力 活好钱贵)
注意力机制是 Transform 当中很重要的网络结构组成
其核心公式:
MHA通过多个头的方式,可以增强自注意力机制聚合上下文信息的能力,以关注上下文的不同侧面,作用类似于CNN的多个卷积核, 你可以理解多个注意力的组合,每个注意力计算类似一个 CNN 的卷积核
Flash Attention
标准 Attention
对于标准Attention 其核心计算分为4 步
其核心链路涉及
- 第一步计算 中间结果 S, 即 写回到 HBM
- 第二部计算 中间结果 P, 即 写回到 HBM
- 第三步计算 最终输出结果O, 即 写回到 HBM
- 第四步 返回结算结果
标准Attention的中间结果S,P 通常需要通过高带宽内存(HBM)进行存取,两者所需内存空间复杂度为。
FlashAttention
Flash Attention计算 有效减少了对全局内存(HBM)的访问需求。通过优化数据传输和利用片上高速缓存, 并改进算法降低了内存带宽的需求并提高了计算效率,下面我们简单粗粒度基于 FlashAttentionV2 和 V3 简单说明一下针对 MHA 的极致优化,详细内容请访问具体链接
FlashAttention V1
V1 应用了两种已确立的技术(titling(分块),Recomputation(重计算降低内存存储))来克服在次二次 HBM 访问中计算精确注意力的技术挑战,流程示意图如下
图一:逐行分块计算注意力
FlashAttention V2
Flash Attention V1 如果外循环修改为基于 , 则每个warp 可以连续针对 连续处理,可以有效避免中间变量的 HBM 保存和加载,同时针对 O 最后进行归一化缩放,避免局部对 的每次 1D 缩放操作,具体示意参考下图:
新的算法可以将 每次分块内循环 的归一化补偿统一最后一次进行,V2 调整了 FlashAttention 算法,以减少非矩阵乘运算的 FLOPs 数量。这是因为现代 GPU(例如 Nvidia GPU 上的张量核心)专门用于加速矩阵乘运算
FlashAttention V3
FlashAttention V1 和 FlashAttention V2 提出了一种通过减少内存读写来加速 GPU 上Attention计算的方法。然而,它尚未充分利用最近硬件提供的新能力,比如 FlashAttention-2 在 H100 GPU 上的利用率仅为 35%。Flash Attention V3它贡献了三个新的想法,以进一步提高在最新 GPU 架构上的性能:
-
利用TensorCore和 TMA 的异步性,通过专门的 WarpGroup来实现数据搬运和计算的并行(pingpong),以及矩阵块乘法与 softmax 操作在 K,V维度上的进一步并行和数据拆分, 来达到整体计算和数据移动二者的时间重叠
-
生产者-消费者异步性:定义了一种基于异步执行数据搬运和TensorCore的 warp 级别软件流水线方案,通过将数据的生产者和消费者分别划分到不同的WarpGroup中,以此扩展算法隐藏内存和指令发射延迟的能力
-
引入 FP8 低精度量化(需要 H100 支持),利用硬件对 FP8 低精度的支持进行块量化和无相关处理
FlashAttention-3 基于 H100 硬件新特性设计的优化方案,在 H100 GPU 上FP16精度实现 1.5-2.0 倍的加速,达到 740 TFLOPs/s(利用率 75%),FP8精度 达到接近 1.2 PFLOPs/s。通过实验验证,FP8 FlashAttention-3 的数值误差比基线 FP8 注意力低 2.6倍
FlashMLA是专门为英伟达Hopper GPU量身定做的高效MLA(Multi-Head Latent Attention)解码内核,其目标就是让推理速度嗖嗖上升,尤其是在H100等Hopper GPU上,简直像给模型装上了火箭推进器
MLA 机制原理
DeepSeek-V3 仍然基于 Transformer 框架(Vaswani et al., 2017),采用了 MLA 和 DeepSeekMoE 以提升推理和训练效率。
MLA 针对 KV 的低秩压缩核心算法 如下:
注意下图公式中: D 是 Down 下投影 U 是 Up 上投影的 意思
设
-
:嵌入维度, :单头维度
-
:第 个 token 在该层的输入, :注意力头数
-
为键值的压缩潜变量表示,用于针对输入进行转换
-
为降维变换矩阵
-
为 KV 压缩维度
其中: MLA 通过引入潜在表示 来压缩 键和值:
接着通过上投影(up-projection)恢复键和值 其中 分别为键和值的上投影矩阵: 如上图中公式 (2)和 (5),其中:
此外,MLA 还引入了一种独立的键 ,用于携带旋转位置编码 RoPE 如上图中公式 (3)
优化点:在推理过程中,仅需缓存 蓝框标出的向量(即 和 ),从而显著减少 KV 缓存需求,同时在性能上可与标准多头注意力(MHA)相媲美
低秩压缩的查询(Query)计算
同样地,对于查询(Query),MLA 采用低秩压缩,以减少训练过程中**激活(activation)**的内存占用
其中:注意 D 是 Down 下投影 U 是 Up 上投影的 意思
- 上图中公式(6) 为查询的低秩压缩表示
- 为查询压缩维度
- 为查询的降维和上投影矩阵
- 为生成带 RoPE 的查询向量的投影矩阵
所以
- 公式 6 为低秩压缩
- 公式 7 为 上投影还原 query
- 公式 8 基于压缩后的 生成带 RoPE 的上投影矩阵
- 公式 9 联合位置信息
MLA 计算最终注意力输出
最终,MLA 的注意力计算由 查询:、键:组成:
MLA 的优势
- 降低 KV 缓存成本:相比标准 MHA,MLA 只缓存 和 ,减少存储占用。
- 提升推理效率:MLA 在保持与 MHA 相当的性能下,大幅优化了计算和存储效率,特别适用于 长序列推理任务
FlashMLA 的横空出世
DeepSeek选择开源,仿佛在向全世界宣布:不写高大上的论文,而是直接把能跑能用的代码送给大家。FlashMLA的开源不仅展示了DeepSeek的技术实力,更体现了他们推动整个AI生态圈开放协作的豪迈胸襟
核心逻辑 1: 产生 MLA_metadata
核心 kernel 功能逻辑
-
1.计算每个输入序列的分块数(4096 ÷ 64 = 64)并加上固定开销(5),得到每个样本需要调度 69 个块,总和 8832。
-
2.根据 GPU 的 SM 分区数(144)计算每个分区应承担的负载(67)。
-
3.采用内核中嵌套循环的方式,将 128 个样本(可能是完整或部分)按负载均匀分配到 144 个 SM 分区上,同时记录每个分区的起止序列及偏移信息,以及每个样本的累计分割计数。
-
4.最后,生成两个输出张量:一个存放调度元数据(tile_scheduler_metadata,形状大致为 [144, TileSchedulerMetaDataSize]),另一个存放分割信息(num_splits,形状为 [129]) 从 python 视角整体流程示意图如下:
以下是部分 SM 分区(tile)的调度推导示例:
-
SM 分区 0 (i = 0)
-
初始状态:now_idx = 0, now_block = 0, remain_payload = 67。
-
对序列 0:需要分配 64 块,但 67 < 64+5,因此无法完成,分配部分:分配了 (67–5)=62 个块,更新 now_block = 62,now_n_split_idx = 1。
-
写入调度元数据:
- start_idx = 0
- start_offset = 0×64 = 0
- end_idx = 0 (因为 now_block > 0,表示当前 tile 仍在序列 0 内)
- end_offset = 62×64 = 3968
- split_count = 1
-
SM0 的元数据行为:[0, 0, 0, 3968, 1].
-
-
SM 分区 1 (i = 1)
-
状态传递:now_idx = 0, now_block = 62, now_n_split_idx = 1。
-
初始写入:start_idx = 0, start_offset = 62×64 = 3968, split_count = 1, remain_payload 重置为 67。
-
处理序列 0:剩余块数 = 64–62 = 2,此时 67 ≥ (2+5)=7,可完成序列 0。
- 更新:cum_num_splits 增加 (1+1)=2,remain_payload 变为 67–7=60,now_idx 更新为 1,重置 now_block=0, now_n_split_idx=0。
-
继续 while 循环处理序列 1:对于序列 1,需 64 块,但 60 < 64+5,部分分配 (60–5)=55 块,更新 now_block=55, now_n_split_idx=1。
-
写入当前 tile 数据:
- start_idx = 0
- start_offset = 3968
- end_idx = 1 (因为当前仍在序列 1 部分分配)
- end_offset = 55×64 = 3520
- split_count = 1
-
SM1 的元数据行为为:[0, 3968, 1, 3520, 1].
-
-
SM 分区 2 (i = 2)
-
初始状态:now_idx = 1, now_block = 55, now_n_split_idx = 1。
-
写入:start_idx = 1, start_offset = 55×64 = 3520, split_count = 1, remain_payload = 67。
-
处理序列 1:剩余块数 = 64–55 = 9,67 ≥ 9+5=14,完成序列 1。
- 更新:cum_num_splits 加 (1+1)=2(累计变为 4),remain_payload 变为 67–14=53,now_idx 变为 2,重置 now_block=0, now_n_split_idx=0。
-
继续处理序列 2:对序列 2,64 块全量需求,53 < 64+5,部分分配 (53–5)=48 块,更新 now_block=48, now_n_split_idx=1。
-
写入:end_idx = 2, end_offset = 48×64 = 3072。
-
SM2 的元数据行为为:[1, 3520, 2, 3072, 1].
-
-
SM 分区 3 (i = 3)
-
初始状态:now_idx = 2, now_block = 48, now_n_split_idx = 1。
-
写入:start_idx = 2, start_offset = 48×64 = 3072, split_count = 1, remain_payload = 67。
-
对序列 2:剩余块数 = 64–48 = 16,67 ≥ 16+5=21,完成序列 2。
- 更新:cum_num_splits 加 (1+1)=2(累计变为 6),remain_payload 变为 67–21=46,now_idx 变为 3,重置 now_block=0, now_n_split_idx=0。
-
继续处理序列 3:对序列 3,64 块需求,46 < 64+5,部分分配 (46–5)=41 块,更新 now_block=41, now_n_split_idx=1。
-
写入:end_idx = 3, end_offset = 41×64 = 2624。
-
SM3 的元数据行为:[2, 3072, 3, 2624, 1].
-
-
.....
-
SM 分区 143
- 最后一个分区在 tile_scheduler_metadata_ptr 内对应的行数据为:[128, 0, 127, 4096, 0]
核心逻辑 2:MLA 计算
核心逻辑是基于 SM 的任务切分,切分好了并行计算MLA注意力 核心思想可以参考 Flash Attention V2 V3, 这里不在累述细节
总结
FlashMLA借鉴了Flash Attention 2/3和英伟达cutlass库的绝妙设计思路,专门针对H800 SXM5等Hopper架构的GPU做了“特别定制”。在内存带宽方面,它能达到3000 GB/s,几乎贴近H800 SXM5理论极限的3350 GB/s;计算性能则高达580 TFLOPS,冲击了理论峰值的87%!MLA通过分块KV计算、共享内存优化和高效线程协作,显著降低显存需求并提升长序列处理速度。其核心创新点在于:
- 分块Softmax的渐进式归约,避免全局显存依赖。
- 双缓冲与Warp级并行,最大化GPU利用率。
- 动态负载均衡:通过
tile_scheduler_metadata
动态分配块到SM,适应变长序列。
大模型加速-核心网络算子-Flash Attention V3
大模型加速-核心网络算子-Flash Attention V2
同系列