MiniMax M3 稀疏注意力架构:如何用 1/20 的算力跑赢全注意力 Transformer

7 阅读7分钟

引言

2026 年 6 月 1 日,MiniMax 发布新一代旗舰模型 M3,搭载自研的稀疏注意力架构 MSA(MiniMax Sparse Attention)。这不仅仅是一次模型迭代——MSA 将每个 Token 的计算量压缩到上一代 M2 的 1/20,在 100 万 Token 上下文中实现预填充 9 倍加速、解码 15 倍加速,同时能力基本不丢。这不是"挤牙膏"式的优化,而是一次对 Transformer 注意力范式的结构性重构。本文将深入拆解 MSA 的技术原理,理解它为什么能做到又快又好。

问题背景:全注意力的算力困境

标准 Transformer 的自注意力机制有一个被反复诟病的先天缺陷:计算复杂度 O(N2)O(N^2) 随序列长度平方级增长。当上下文从 4K 扩展到 1M(百万)Token 时,注意力矩阵的计算量膨胀到原来的 62500 倍。这意味着:

  • 推理成本爆炸:长上下文场景下的 KV Cache 占用和计算开销让大多数模型望而却步
  • 延迟不可接受:百万 Token 的预填充阶段可能耗时数十秒甚至分钟级
  • 显存墙:即使是 H100 这种级别的 GPU,也难以在单卡上放下百万 Token 的完整 KV Cache

业界对此已有多种稀疏注意力方案,如 Mixture of Attention(MoA)、Flash-Sparse-Attention、FlashMoBA 等。但它们的共同问题是:稀疏覆盖不够精确,导致有效上下文丢失;或者访存模式对 GPU 内存层次不友好,实际加速比远低于理论值

MiniMax 在上一代 M2 中曾回归全注意力机制,原因正是稀疏注意力的基础设施成熟度不足。M3 重新推进稀疏注意力,说明他们找到了一个"既稀疏又不丢能力"的解法——MSA。

技术原理:MSA 的双重架构设计

Index Branch + Sparse Branch 双路径机制

MSA 的核心设计思想是**"先索引,再计算"**——将注意力计算拆分为两个阶段:

  1. Index Branch(索引分支):轻量级网络快速扫描所有 KV 对,判断哪些 KV 块与当前 Query 相关。这一步的计算量极小,因为只做粗粒度的相关性判断,不执行完整的注意力计算。

  2. Sparse Branch(稀疏分支):根据 Index Branch 的筛选结果,仅对命中的 KV 块执行精确的注意力计算。

这类似于数据库中的"索引扫描 + 数据回表":先用 B+ 树索引快速定位目标行,再读取实际数据,避免全表扫描。MSA 的 Index Branch 就是注意力的"B+ 树索引"。

精确 KV 分块策略

MSA 的第二个关键创新是精确 KV 分块。与 DSA、MoBA 等现有方案相比,MSA 能够更精确地对 KV 进行分块,实现更高的有效上下文覆盖。具体来说:

  • 传统稀疏注意力方案往往采用固定大小的 KV 块,导致块内包含大量无关 Token,浪费计算
  • MSA 的分块策略考虑了语义边界,使得每个 KV 块内的信息更加紧凑相关

"KV Outer Gather Q"访存优化

这是 MSA 在工程层面的杀手锏。传统思路是"Query 去找 Key-Value"(Q outer gather KV),每个 Query 需要反复读取不同的 KV 块,导致大量随机访存。MSA 反其道而行:

  • 以 KV 块为外层循环,聚合命中该块的 Query
  • 每个 KV 块只读取一次,命中它的所有 Query 共享这一次读取
  • 访存模式从随机变为连续顺序读取,完美匹配 GPU 的内存层次结构

这种策略让 MSA 在当前 head 配比下的计算访存比显著优于主流方法,比开源的 Flash-Sparse-Attention 和 FlashMoBA 快 4 倍以上

伪代码示例

以下伪代码展示 MSA 的核心计算流程:

def msa_attention(Q, K, V, index_net):
    """
    MSA Sparse Attention 核心流程
    Q: [batch, num_heads, seq_len, head_dim]
    K, V: [batch, num_heads, seq_len, head_dim]
    """
    # Step 1: Index Branch - 粗粒度相关性判断
    # 将 KV 划分为固定大小的块
    kv_blocks = partition_kv(K, V, block_size=BLOCK_SIZE)  # [num_blocks, ...]

    # Index Net 快速评分,输出每个 Q 需要关注哪些 KV 块
    block_scores = index_net(Q, kv_blocks)  # [batch, heads, seq_len, num_blocks]
    topk_indices = torch.topk(block_scores, k=TOPK_BLOCKS).indices

    # Step 2: KV Outer Gather Q - 核心访存优化
    # 反转索引:从 "Q->KV" 变为 "KV->Q"
    kv_to_q_map = invert_index(topk_indices)
    # kv_to_q_map[block_i] = 命中 block_i 的所有 Q 的位置列表

    output = torch.zeros_like(Q)

    for block_i in range(num_kv_blocks):
        # 每个 KV 块只读取一次,连续访存
        K_block = kv_blocks[block_i].K  # [block_size, head_dim]
        V_block = kv_blocks[block_i].V

        # 找到所有命中当前 KV 块的 Query
        q_indices = kv_to_q_map[block_i]
        Q_hit = Q[:, :, q_indices, :]  # 仅加载相关 Query

        # 精确注意力计算(仅对命中的 Q-KV 对)
        attn_weights = torch.matmul(Q_hit, K_block.transpose(-2, -1))
        attn_weights = attn_weights / math.sqrt(head_dim)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, V_block)

        output[:, :, q_indices, :] += attn_output

    return output

关键点在于 invert_index 这一步——将"每个 Q 需要哪些 KV 块"的映射反转为"每个 KV 块被哪些 Q 命中",从而实现 KV 的一次读取、多次复用。

性能实测:数据说话

指标标准 TransformerMiniMax M3 (MSA)提升幅度
每 Token 计算量基准1/2020× 压缩
1M 上下文预填充速度基准9× 加速数量级提升
1M 上下文解码速度基准15× 加速数量级提升
SWE-Bench Pro-59.0%(超 GPT-5.5)开源 SOTA
CUDA 内核硬件峰值利用率7.6%71.3%9.4× 优化

值得注意的是,MSA 的 CUDA 内核优化过程本身就是一个 AI 工程奇迹——M3 连续自主工作 24 小时,完成 147 次 benchmark 提交和 1959 次工具调用,将 Hopper FP8 硬件峰值利用率从 7.6% 拉升到 71.3%。这表明 MSA 不仅仅是算法层面的创新,还在算子层面做了深度的针对性优化。

个人观点

MSA 的意义远超一次模型升级,它代表了行业对 Transformer 架构本身的反思:

第一,稀疏注意力正在从"妥协方案"变为"最优解"。 过去,稀疏注意力被视为全注意力的降级替代——速度快了,但能力丢了。MSA 证明,只要分块策略足够精确、访存优化足够到位,稀疏注意力可以在几乎不损失能力的前提下实现数量级的效率提升。这改变了我们对"稀疏"的预期。

第二,"KV Outer Gather Q"的逆向思维值得借鉴。 大多数稀疏注意力方案都在"Q 找 KV"的框架下优化,而 MSA 逆转了这个方向。这种工程视角的范式转换,可能是未来算子优化的一个重要方向——与其优化"找"的过程,不如改变"谁去找谁"。

第三,AI 优化 AI 的闭环正在形成。 M3 用 24 小时自主优化 CUDA 内核达到 9.4 倍加速,这不是巧合,而是一个趋势的缩影。当 AI 能比人类工程师更高效地做算子优化时,整个 AI Infra 的开发范式都会被重塑。

不过,MSA 仍然面临挑战:Index Branch 的质量直接决定了整体效果,如果索引判断出现系统性偏差,错误会在后续精确计算中被放大。此外,MSA 在超短上下文(<4K)场景下的优势不明显,因为全注意力的绝对计算量本身就很小。因此,MSA 更适合作为长上下文场景的专用优化,而非"一刀切"的全场景方案。

总结

MiniMax M3 的 MSA 架构通过"Index Branch + Sparse Branch 双路径"、"精确 KV 分块"和"KV Outer Gather Q 访存优化"三大核心设计,在几乎不损失模型能力的前提下,将百万 Token 上下文的推理效率提升了一个数量级。这不仅是一次工程优化,更是对 Transformer 注意力范式的结构性创新。随着 M3 模型权重的开源,MSA 的技术细节将进一步被社区验证和迭代,有望成为长上下文场景下的新标准。