DeepSeek开源周-Day01之DeepSeek FlashMLA 深度解析

25 阅读12分钟

DeepSeek 开源周 Day01 FlashMLA 深度解析

DeepSeek的开源周可谓是给AI界点燃了一根“加速火箭”,一开场便引爆了全球开发者的热情!在北京时间周一上午九点,DeepSeek豪气冲天地推出了本次开源周的“压轴好戏”——FlashMLA,这个项目一出,GitHub上的Star数就一路飙升,早已突破5000大关

image.png

在大语言模型推理中,序列解码的速度决定了你的AI是不是“急先锋”。DeepSeek早在V2版本中就引入了MLA,这一注意力机制的变种利用低秩KV压缩技术,有效缓解了KV Cache内存吃饱了撑的尴尬;到了V3版本,则又搭配上了DeepSeekMoE,简直是双剑合璧,既提速又省钱

FlashMLA借鉴了Flash Attention 2/3和英伟达cutlass库的绝妙设计思路,专门针对H800 SXM5等Hopper架构的GPU做了“特别定制”。在内存带宽方面,它能达到3000 GB/s,几乎贴近H800 SXM5理论极限的3350 GB/s;计算性能则高达580 TFLOPS,冲击了理论峰值的87%!这就像是给GPU打了一针强心剂,接下来我们由浅入深的分析一下 FlashMLA

什么是MLA

DeepSeek-V3 采用了多头潜变量注意力(Multi-Head Latent Attention, MLA)架构。 MLA 机制的核心思想是对注意力的键(Key)和值(Value)进行低秩联合压缩(low-rank joint compression) ,以减少推理时的键-值(KV)缓存,从而降低计算成本,同时保持接近标准多头注意力(Multi-Head Attention, MHA)的性能,所以这里的核心是低秩联合压缩,由于 FlashMLA 借鉴了 Flash Attention V2 和 V3 版本的优化技术,所以下面我们简单介绍一下历史背景,毕竟吃水不忘挖井人。

image.png

MHA (多头注意力 活好钱贵)

注意力机制是 Transform 当中很重要的网络结构组成 image.png

其核心公式:

image.png **总结**

MHA通过多个头的方式,可以增强自注意力机制聚合上下文信息的能力,以关注上下文的不同侧面,作用类似于CNN的多个卷积核, 你可以理解多个注意力的组合,每个注意力计算类似一个 CNN 的卷积核

image.png

Flash Attention

标准 Attention

Attention(Q,K,V)=softmax(QKdk)V\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_k}} \right) V

对于标准Attention 其核心计算分为4 步

企业微信截图_1f6e3b79-39e8-48da-a085-ccad86c58ef4.png

其核心链路涉及

  • 第一步计算 中间结果 S, 即 S=QKTS=QK^T 写回到 HBM
  • 第二部计算 中间结果 P, 即 P=softmax(S)P=softmax(S) 写回到 HBM
  • 第三步计算 最终输出结果O, 即 O=PVO=PV 写回到 HBM
  • 第四步 返回结算结果

标准Attention的中间结果S,P 通常需要通过高带宽内存(HBM)进行存取,两者所需内存空间复杂度为O(N2)O(N^2)

FlashAttention

Flash Attention计算 有效减少了对全局内存(HBM)的访问需求。通过优化数据传输和利用片上高速缓存, 并改进算法降低了内存带宽的需求并提高了计算效率,下面我们简单粗粒度基于 FlashAttentionV2 和 V3 简单说明一下针对 MHA 的极致优化,详细内容请访问具体链接

FlashAttention V1

V1 应用了两种已确立的技术(titling(分块),Recomputation(重计算降低内存存储))来克服在次二次 HBM 访问中计算精确注意力的技术挑战,流程示意图如下  企业微信截图_c44e313f-f789-4d87-8978-a6e3dcbcc8f7.png  图一:逐行分块计算注意力

FlashAttention V2

Flash Attention V1 如果外循环修改为基于 QjQ_j, 则每个warp 可以连续针对 OjO_j 连续处理,可以有效避免中间变量的 HBM 保存和加载,同时针对 O 最后进行归一化缩放,避免局部对 OjO_j 的每次 1D 缩放操作,具体示意参考下图:

企业微信截图_f174fa81-70c8-4f8a-809b-4b4a27364d5a.png

新的算法可以将 每次分块内循环OjO_{j} 的归一化补偿统一最后一次进行,V2 调整了 FlashAttention 算法,以减少非矩阵乘运算的 FLOPs 数量。这是因为现代 GPU(例如 Nvidia GPU 上的张量核心)专门用于加速矩阵乘运算

FlashAttention V3

FlashAttention V1 和 FlashAttention V2 提出了一种通过减少内存读写来加速 GPU 上Attention计算的方法。然而,它尚未充分利用最近硬件提供的新能力,比如 FlashAttention-2 在 H100 GPU 上的利用率仅为 35%。Flash Attention V3它贡献了三个新的想法,以进一步提高在最新 GPU 架构上的性能:

  1. 利用TensorCore和 TMA 的异步性,通过专门的 WarpGroup来实现数据搬运和计算的并行(pingpong),以及矩阵块乘法与 softmax 操作在 K,V维度上的进一步并行和数据拆分, 来达到整体计算和数据移动二者的时间重叠 image.png

  2. 生产者-消费者异步性:定义了一种基于异步执行数据搬运和TensorCore的 warp 级别软件流水线方案,通过将数据的生产者和消费者分别划分到不同的WarpGroup中,以此扩展算法隐藏内存和指令发射延迟的能力

  3. 引入 FP8 低精度量化(需要 H100 支持),利用硬件对 FP8 低精度的支持进行块量化和无相关处理

FlashAttention-3 基于 H100 硬件新特性设计的优化方案,在 H100 GPU 上FP16精度实现 1.5-2.0 倍的加速,达到 740 TFLOPs/s(利用率 75%),FP8精度 达到接近 1.2 PFLOPs/s。通过实验验证,FP8 FlashAttention-3 的数值误差比基线 FP8 注意力低 2.6倍

FlashMLA是专门为英伟达Hopper GPU量身定做的高效MLA(Multi-Head Latent Attention)解码内核,其目标就是让推理速度嗖嗖上升,尤其是在H100等Hopper GPU上,简直像给模型装上了火箭推进器

MLA 机制原理

DeepSeek-V3 仍然基于 Transformer 框架(Vaswani et al., 2017),采用了 MLA 和 DeepSeekMoE 以提升推理和训练效率。

image.png

MLA 针对 KV 的低秩压缩核心算法 如下:

注意下图公式中: D 是 Down 下投影 U 是 Up 上投影的 意思 image.png

  • dd:嵌入维度, dhd_h:单头维度

  • htRdh_t \in \mathbb{R}^d:第 tt 个 token 在该层的输入, nhn_h:注意力头数

  • ctKVRdc c^{KV}_t \in \mathbb{R}^{d_c} 为键值的压缩潜变量表示,用于针对输入进行转换

  • WDKVRdc×dW^{DKV} \in \mathbb{R}^{d_c \times d} 为降维变换矩阵

  • dcdcdhnhd_c( d_c ≪d_h n_h) 为 KV 压缩维度

其中: MLA 通过引入潜在表示 ctKV=WDKVht​​c^{KV}_t =W^{DKV} h_t​​ 来压缩 键和值:

接着通过上投影(up-projection)恢复键和值 其中 WUK,WUVRdh,nh×dc W^{UK}, W^{UV} \in \mathbb{R}^{d_h,n_h \times d_c}分别为键和值的上投影矩阵: 如上图中公式 (2)和 (5),其中:

此外,MLA 还引入了一种独立的键 ktRk_t^R,用于携带旋转位置编码 RoPE 如上图中公式 (3)

优化点:在推理过程中,仅需缓存 蓝框标出的向量(即ctKV c^{KV}_tktRk_t^R),从而显著减少 KV 缓存需求,同时在性能上可与标准多头注意力(MHA)相媲美

低秩压缩的查询(Query)计算

同样地,对于查询(Query),MLA 采用低秩压缩,以减少训练过程中**激活(activation)**的内存占用

image.png

其中:注意 D 是 Down 下投影 U 是 Up 上投影的 意思

  • 上图中公式(6) ctQRdcc^Q_t \in \mathbb{R}^{ d'_c} 为查询的低秩压缩表示
  • dc(dh,nh)d_c^′≪(d_h,n_h) 为查询压缩维度
  • WDQRdc×d,WUQRdhnh×dcW^{DQ} \in \mathbb{R}^{d'_c \times d}, W^{UQ} \in \mathbb{R}^{d_h n_h \times d'_c} 为查询的降维和上投影矩阵
  • WQRRdhRnh×dcW_{Q_R} \in \mathbb{R}^{d^R_h n_h \times d'_c} 为生成带 RoPE 的查询向量的投影矩阵

所以

  • 公式 6 为低秩压缩
  • 公式 7 为 上投影还原 query
  • 公式 8 基于压缩后的 ctQc^Q_t 生成带 RoPE 的上投影矩阵
  • 公式 9 联合位置信息

MLA 计算最终注意力输出

image.png

最终,MLA 的注意力计算由 查询:qt,iq_{t,i}、键:kj,i和值vi,jCk_{j,i} 和值v^C_{i,j}组成:

MLA 的优势

  • 降低 KV 缓存成本:相比标准 MHA,MLA 只缓存 ctKVc^{KV}_tktRk_t^R,减少存储占用。
  • 提升推理效率:MLA 在保持与 MHA 相当的性能下,大幅优化了计算和存储效率,特别适用于 长序列推理任务

FlashMLA 的横空出世

DeepSeek选择开源,仿佛在向全世界宣布:不写高大上的论文,而是直接把能跑能用的代码送给大家。FlashMLA的开源不仅展示了DeepSeek的技术实力,更体现了他们推动整个AI生态圈开放协作的豪迈胸襟

image.png

核心逻辑 1: 产生 MLA_metadata

核心 kernel 功能逻辑

  • 1.计算每个输入序列的分块数(4096 ÷ 64 = 64)并加上固定开销(5),得到每个样本需要调度 69 个块,总和 8832。

  • 2.根据 GPU 的 SM 分区数(144)计算每个分区应承担的负载(67)。

  • 3.采用内核中嵌套循环的方式,将 128 个样本(可能是完整或部分)按负载均匀分配到 144 个 SM 分区上,同时记录每个分区的起止序列及偏移信息,以及每个样本的累计分割计数。

  • 4.最后,生成两个输出张量:一个存放调度元数据(tile_scheduler_metadata,形状大致为 [144, TileSchedulerMetaDataSize]),另一个存放分割信息(num_splits,形状为 [129]) 从 python 视角整体流程示意图如下:

image.png 以下是部分 SM 分区(tile)的调度推导示例:

  • SM 分区 0 (i = 0)

    • 初始状态:now_idx = 0, now_block = 0, remain_payload = 67。

    • 对序列 0:需要分配 64 块,但 67 < 64+5,因此无法完成,分配部分:分配了 (67–5)=62 个块,更新 now_block = 62,now_n_split_idx = 1。

    • 写入调度元数据:

      • start_idx = 0
      • start_offset = 0×64 = 0
      • end_idx = 0 (因为 now_block > 0,表示当前 tile 仍在序列 0 内)
      • end_offset = 62×64 = 3968
      • split_count = 1
    • SM0 的元数据行为:[0, 0, 0, 3968, 1].

  • SM 分区 1 (i = 1)

    • 状态传递:now_idx = 0, now_block = 62, now_n_split_idx = 1。

    • 初始写入:start_idx = 0, start_offset = 62×64 = 3968, split_count = 1, remain_payload 重置为 67。

    • 处理序列 0:剩余块数 = 64–62 = 2,此时 67 ≥ (2+5)=7,可完成序列 0。

      • 更新:cum_num_splits 增加 (1+1)=2,remain_payload 变为 67–7=60,now_idx 更新为 1,重置 now_block=0, now_n_split_idx=0。
    • 继续 while 循环处理序列 1:对于序列 1,需 64 块,但 60 < 64+5,部分分配 (60–5)=55 块,更新 now_block=55, now_n_split_idx=1。

    • 写入当前 tile 数据:

      • start_idx = 0
      • start_offset = 3968
      • end_idx = 1 (因为当前仍在序列 1 部分分配)
      • end_offset = 55×64 = 3520
      • split_count = 1
    • SM1 的元数据行为为:[0, 3968, 1, 3520, 1].

  • SM 分区 2 (i = 2)

    • 初始状态:now_idx = 1, now_block = 55, now_n_split_idx = 1。

    • 写入:start_idx = 1, start_offset = 55×64 = 3520, split_count = 1, remain_payload = 67。

    • 处理序列 1:剩余块数 = 64–55 = 9,67 ≥ 9+5=14,完成序列 1。

      • 更新:cum_num_splits 加 (1+1)=2(累计变为 4),remain_payload 变为 67–14=53,now_idx 变为 2,重置 now_block=0, now_n_split_idx=0。
    • 继续处理序列 2:对序列 2,64 块全量需求,53 < 64+5,部分分配 (53–5)=48 块,更新 now_block=48, now_n_split_idx=1。

    • 写入:end_idx = 2, end_offset = 48×64 = 3072。

    • SM2 的元数据行为为:[1, 3520, 2, 3072, 1].

  • SM 分区 3 (i = 3)

    • 初始状态:now_idx = 2, now_block = 48, now_n_split_idx = 1。

    • 写入:start_idx = 2, start_offset = 48×64 = 3072, split_count = 1, remain_payload = 67。

    • 对序列 2:剩余块数 = 64–48 = 16,67 ≥ 16+5=21,完成序列 2。

      • 更新:cum_num_splits 加 (1+1)=2(累计变为 6),remain_payload 变为 67–21=46,now_idx 变为 3,重置 now_block=0, now_n_split_idx=0。
    • 继续处理序列 3:对序列 3,64 块需求,46 < 64+5,部分分配 (46–5)=41 块,更新 now_block=41, now_n_split_idx=1。

    • 写入:end_idx = 3, end_offset = 41×64 = 2624。

    • SM3 的元数据行为:[2, 3072, 3, 2624, 1].

  • .....

  • SM 分区 143

    • 最后一个分区在 tile_scheduler_metadata_ptr 内对应的行数据为:[128, 0, 127, 4096, 0]

核心逻辑 2:MLA 计算

核心逻辑是基于 SM 的任务切分,切分好了并行计算MLA注意力 核心思想可以参考 Flash Attention V2 V3, 这里不在累述细节

企业微信截图_b4b5d3eb-3d3b-45fa-a9b8-8553826d01a5.png

总结

FlashMLA借鉴了Flash Attention 2/3和英伟达cutlass库的绝妙设计思路,专门针对H800 SXM5等Hopper架构的GPU做了“特别定制”。在内存带宽方面,它能达到3000 GB/s,几乎贴近H800 SXM5理论极限的3350 GB/s;计算性能则高达580 TFLOPS,冲击了理论峰值的87%!MLA通过分块KV计算、共享内存优化和高效线程协作,显著降低显存需求并提升长序列处理速度。其核心创新点在于:

  1. 分块Softmax的渐进式归约,避免全局显存依赖。
  2. 双缓冲与Warp级并行,最大化GPU利用率。
  3. 动态负载均衡:通过tile_scheduler_metadata动态分配块到SM,适应变长序列。

大模型加速-核心网络算子-Flash Attention V3

大模型加速-核心网络算子-Flash Attention V2

同系列