Multi-Head Attention:为什么要分多个头

22 阅读2分钟

读到这里,大多数读者应该已经能把 Scaled Dot-Product Attention 的基本流程复述出来:QQKK 做内积、除以 dk\sqrt{d_k}、过 softmax,再用得到的权重聚合 VV。问题在于,这台机器每次只给你一组注意力分布。对位置 ii 来说,最终只有一份权重决定它往哪里看、看多少。

这在玩具例子里没有问题,在真实语言里就很快碰到上限。同一个 token 往往同时需要处理句法关系、指代关系、局部邻近关系、语义主题关系。如果把这些判断全部压进同一个 softmax,模型只能在多种关系之间做妥协,而不是并行建模。

Multi-Head Attention 做的事情其实非常直接:把 dmodeld_{model} 维表示投到 hh 个独立子空间里,让每个子空间各自形成一份注意力分布,最后再把这些结果拼回去。它看起来像一次简单的切分,但恰恰是这一步,让 Transformer 从单一的相似度度量,变成了同一步内并行处理多种关系的架构。

读完这一篇,你应该能回答:

  • 为什么单头 attention 迟早会卡住。
  • 为什么多头几乎不增加参数量。
  • 不同头到底学到了什么,以及为什么不能把可视化当成因果解释。
  • 为什么现代大模型训练时保留多头,推理时却大量使用 GQA 或 MQA。
  • 生产代码里为什么一定是大矩阵乘法加 reshape,而不是 for 循环跑 hh 次。

原文链接

一、为什么一定要多头

1. 单头 attention 的上限在哪里

先把单头形式写清楚。给定输入序列 XRn×dX \in \mathbb{R}^{n \times d},标准 attention 做的是:

Q=XWQ,K=XWK,V=XWVA=softmax(QKTdk),Z=AV\begin{aligned} Q &= XW^Q,\quad K = XW^K,\quad V = XW^V \\ A &= \operatorname{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right),\quad Z = AV \end{aligned}

这里真正决定模型怎么看世界的是 AA。对位置 ii 来说,它只有一行 softmax 概率分布,只能在所有候选位置里分出一套权重。这意味着单头 attention 有两个硬约束。

第一,它一次只能表达一种关系。如果当前位置既要看主语和动词之间的句法链路,又要看代词和先行词之间的指代链路,那么这一切都要挤在同一份分布里完成。结果往往不是两种关系都学好,而是两边都被摊薄。

第二,它只能依赖一套相似度度量。QKTQK^T 之所以能得到权重,是因为模型假设 QQKK 在同一空间里的点积足以衡量相似度。但句法相似度、位置相似度、主题相似度并不天然属于同一种空间。要求一组 WQW^QWKW^K 同时支撑这些判断,本质上是在逼同一把尺子量多种不同性质的东西。

这就是单头的核心瓶颈:它不缺一次聚合的能力,缺的是同一步里并行处理多种关系的能力。

2. 为什么不靠堆更深的层解决

一个自然反问是:单层只算一种关系也没关系,多堆几层不就行了?

问题在于,深度解决的是逐层组合,宽度解决的是同一步并行。Transformer 每一层的输出都会进入残差流,再交给下一层继续处理。第一层如果已经把一部分信息按某种关系混合了,第二层看到的就是混合后的表示,而不是原始 token 表示。你当然可以让下一层再学另一种关系,但这已经不是同一步并行完成,而是先后改写。

从建模目标上说,多头更像同一层里的多组滤波器,而不是更多层的重复堆叠。CNN 不会指望一个卷积核学完所有局部模式,再靠更深的层把它们拆开;同理,attention 也不应该只有一套相似度度量,然后把所有关系都往后推。

所以多头解决的不是层数不够,而是单层表达过窄的问题。


二、多头到底是怎么工作的

1. 标准定义

Multi-Head Attention 的标准定义是:

MultiHead(Q,K,V)=Concat(head1,,headh)WOheadi=Attention(QWiQ,KWiK,VWiV)\begin{aligned} \operatorname{MultiHead}(Q, K, V) &= \operatorname{Concat}(\operatorname{head}_1, \ldots, \operatorname{head}_h)W^O \\ \operatorname{head}_i &= \operatorname{Attention}(QW_i^Q, KW_i^K, VW_i^V) \end{aligned}

关键不在 concat,而在每个头都有自己独立的 WiQW_i^QWiKW_i^KWiVW_i^V。同一份输入会被投到 hh 个不同的子空间里,每个子空间各自形成一份 softmax 分布。头和头之间参数不共享,所以模型有机会把不同头训练成不同的关系探测器。

最后的 WOW^O 也不是可有可无的装饰。它的作用是把各个头的输出重新混回统一的残差流,让下一层能够在一个共享空间里继续处理,而不是面对 hh 个互不沟通的孤岛。

Multi-Head Attention 并行结构转存失败,建议直接上传图片文件

2. 参数量为什么几乎不变

很多教程会说多头不怎么增加参数量,但这句话如果不算账,很容易被误解。

把所有头的投影矩阵沿最后一维拼起来,可以得到:

WfullQRdmodel×hdkWfullKRdmodel×hdkWfullVRdmodel×hdv\begin{aligned} W_{\mathrm{full}}^Q &\in \mathbb{R}^{d_{model} \times h d_k} \\ W_{\mathrm{full}}^K &\in \mathbb{R}^{d_{model} \times h d_k} \\ W_{\mathrm{full}}^V &\in \mathbb{R}^{d_{model} \times h d_v} \end{aligned}

在最常见的设置里,dk=dv=dmodel/hd_k = d_v = d_{model} / h,于是 hdk=dmodelh d_k = d_{model}。这意味着 WfullQW_{\mathrm{full}}^QWfullKW_{\mathrm{full}}^KWfullVW_{\mathrm{full}}^V 都退回成 dmodel×dmodeld_{model} \times d_{model} 的方阵,再加上一个同样大小的 WOW^O,多头 attention 整体上仍然只是 4 个 dmodel×dmodeld_{model} \times d_{model} 矩阵。

用 Transformer-base 的常见配置举例:

  • 单头大版本:dmodel=512d_{model} = 512,那么 WQW^QWKW^KWVW^V 各是 512×512512 \times 512
  • 8 头版本:每个头的矩阵是 512×64512 \times 64,一共 8 头,拼起来还是 512×512512 \times 512

也就是说,多头买的不是更多参数,多头买的是更多独立 softmax 的并行度。

单头大维度 vs 多头小维度的参数等价转存失败,建议直接上传图片文件

这一点非常关键。一个更大的单头只能给你一份更精细的分布,但还是只有一份分布;多头给你的是 hh 份独立分布,它们可以同时盯住不同关系。

3. 一个最小数值例子

为了把直觉落地,考虑一个极小的例子:dmodel=4d_{model} = 4h=2h = 2,因此 dk=dv=2d_k = d_v = 2,序列长度 n=3n = 3。输入设为:

X=(101001011100)X = \begin{pmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 1 & 1 & 0 & 0 \end{pmatrix}

再取最简单的投影:WQ=WK=WV=IW^Q = W^K = W^V = I

这时第 1 个头只看前两维,第 2 个头只看后两维。对同一个 token 来说,这两个头看到的是不同的几何结构。第 1 个头里,第三个 token 同时和前两个 token 有相似性;第 2 个头里,第三个 token 恰好变成零向量,对谁都不特别相似。

如果把第 1 个头的打分写出来,有:

scores1=Q1K1T2=(1/201/201/21/21/21/22/2)\operatorname{scores}_1 = \frac{Q_1K_1^T}{\sqrt{2}} = \begin{pmatrix} 1/\sqrt{2} & 0 & 1/\sqrt{2} \\ 0 & 1/\sqrt{2} & 1/\sqrt{2} \\ 1/\sqrt{2} & 1/\sqrt{2} & 2/\sqrt{2} \end{pmatrix}

而第 2 个头里,第三行会全部变成 0,softmax 之后就是均匀分布。于是同一个 query 在两个头里得到的注意力模式完全不同:一个头把权重集中到和自己最相近的位置,另一个头因为分辨不出差异,只能平均分配。

这就是多头和单头最本质的差别:多头不是把一个大空间切碎而已,而是给每个子空间独立保留一份 softmax 表达能力。


三、不同的头到底学到了什么

1. BERT 里最常见的四类头

BERT 火起来之后,研究者第一次系统地把多头逐个可视化。Clark 等人的分析里,最常见的头大致可以分成四类。

第一类是位置型。它们几乎只看相邻 token,或者只看自己,像是在做局部 n-gram 聚合。

第二类是锚点型。它们把大量权重给 [CLS]、[SEP]、句号,或者序列开头的若干位置。后来这类模式在长上下文推理里演化成了 attention sink 的重要现象。

第三类是句法型。某些头会稳定地把注意力放到主语对应的动词、介词对应的宾语、修饰语对应的中心词上。模型从来没被显式教过依存语法,但它会自发学出这类结构。

第四类是指代型。它们更稀有,通常出现在中后层,用来追踪 pronoun 和先行词之间的关系。

不同头学到的注意力模式转存失败,建议直接上传图片文件

这些结果至少说明一件事:多头并不是训练出很多完全一样的副本。它们确实会分工,而且分工经常与我们关心的语言结构对应。

2. 可视化很有用,但不是因果解释

看到这里很容易走到另一个极端:把某个好看的注意力图直接当成模型解释。

这一步需要非常克制。Jain 与 Wallace 的结论非常明确:注意力分布可以和某种解释相一致,但不能直接等同于模型的因果机制。因为最终输出不仅取决于注意力权重,还取决于被加权的 VV 本身,以及更早层已经写进残差流里的信息。

所以更稳妥的理解是:

  • 可视化适合生成假设。它能告诉你某个头看起来像句法头、像 sink 头、像位置头。
  • 消融和干预才更接近验证。把头置零、替换输出、观察性能下降,才更能说明这个头是不是在承担关键功能。

换句话说,注意力图能帮你看见模式,但不能替你完成归因。

3. 跨层分工与头剪枝

如果把视角从单层拉到多层,现象会更有意思。Tenney 等人的 probing 结果显示,BERT 的浅层更接近词法和局部邻近特征,中层更偏句法,深层更偏语义和篇章。这意味着多头不只是横向并行,也在纵向上形成了层级分工。

另一方面,Michel 和 Voita 的剪头实验也说明:并不是每个头都同等重要。很多头可以被单独剪掉而几乎不掉点,但也有少数头一旦剪掉,性能会明显下滑。这说明多头内部既有专责头,也有冗余头。

这对工程的启发非常直接:训练阶段保留较多头,有利于模型探索不同关系;部署阶段则可以把部分冗余结构压缩掉,于是才有了后来的 GQA、MQA 和各种头剪枝方案。


四、从 MHA 到 GQA:工程上的现实约束

1. 头数怎么选

原始 Transformer 的经验其实已经给出了很强的约束:头数不是越多越好,而是要和每头维度一起看。

hhdkd_k典型结论
1512表达力不足,单一分布太受限
4128明显改善
864经典甜点区间
1632开始变窄
3216每头维度过小,效果回落

后来大模型的配置大体沿着这个经验走:

模型dmodeld_{model}hhdkd_k
Transformer-base512864
BERT-base7681264
BERT-large10241664
GPT-3 175B1228896128
LLaMA-2 7B409632128
LLaMA-2 70B819264128

最稳定的经验不是头数本身,而是每头维度通常锁在 64 或 128。头太少,关系不够并行;头太多,单头维度又太瘦,连基本的相似度判断都做不扎实。

2. 为什么推理端开始大量砍头

训练时多头是优势,推理时多头却很快变成负担,问题集中在 KV cache。

标准 MHA 里,每个头都有自己的一份 KKVV。当上下文很长时,这部分缓存会迅速吃光显存。于是工程上出现了两条典型路线:

  • MQA:所有 query 头共享同一份 KKVV,KV cache 最省,但表达力损失更明显。
  • GQA:把 query 头分组,每组共享一份 KKVV,在质量和速度之间取折中。
变体QQ 头数K/VK/V 头数KV cache常见取舍
MHAhhhh最大训练最好,推理最慢
GQAhhgg中等质量接近 MHA,推理显著更快
MQAhh1最小最省显存,但更容易掉点

这也是为什么现代大模型常常呈现一个看上去矛盾的趋势:训练时保留较多 query 头,推理时尽量共享 K/VK/V

3. 训练稳定性的几个注意点

多头本身不神秘,但大模型里它会和训练稳定性强耦合,最常见的注意点有三个。

第一,pre-LN 比 post-LN 更稳。深层模型中,attention 输出会不断写回残差流,post-LN 更容易让梯度方差沿层数积累,pre-LN 在 GPT、LLaMA 这类大模型里已经几乎成为默认选择。

第二,训练前期不同头往往都很像。softmax 输入接近零时,各头的分布都接近均匀,分工是在训练中后期逐渐拉开的。不要拿训练早期的注意力图去解释模型行为。

第三,WOW^O 的初始化值得认真对待。GPT-2 之后很常见的做法,是按层数缩小 WOW^O 的初始方差,减少 attention 输出反复写回残差流时的方差放大。这不是多头独有的数学性质,但它直接影响多头模块在深层网络里的稳定性。

4. 自注意力和交叉注意力有什么不同

多头机制不仅用于 self-attention,也同样用于 cross-attention。形式上两者完全一样,差别只在来源:

  • self-attention 里,QQKKVV 都来自同一份输入。
  • cross-attention 里,QQ 来自 decoder 当前状态,KKVV 来自 encoder 输出。

从多头的角度看,变化不在公式,而在任务含义上。self-attention 更像序列内部关系建模,cross-attention 更像目标序列对源序列做可寻址检索。很多翻译和多模态模型里的对齐能力,靠的正是 cross-attention 中不同头的分工。


五、工程实现:一次大矩阵乘法加 reshape

1. 为什么生产代码不是 for 循环

概念上,多头好像就是把 attention 跑 hh 次,然后把结果拼起来。但生产代码从不会真的写一个 for 循环。

原因很简单:GPU 喜欢一次大矩阵乘法,不喜欢很多次小矩阵乘法。真正高效的实现会先用一到三次大 GEMM 一次性算出全部头的 QQKKVV,然后 reshape 成 (B,h,N,dk)(B, h, N, d_k) 的形状,再把头维度当成 batched matmul 的一个批次维来统一处理。

也就是说,工程实现的骨架其实是:

# X: (B, N, D)
qkv = X @ W_qkv
q, k, v = split_and_reshape(qkv)   # (B, h, N, d_k)
scores = (q @ k.transpose(-2, -1)) / sqrt(d_k)
attn = softmax(scores, dim=-1)
out = attn @ v
out = merge_heads(out) @ W_o

核心思想不是把一个头复制很多次,而是把所有头并进同一套张量运算里。

2. 一份完整的 PyTorch 实现

下面是一份简洁但已经接近生产习惯的 PyTorch 写法:

import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, 'd_model 必须能被 num_heads 整除'
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        qkv = self.W_qkv(x)
        qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        scores = torch.matmul(q, k.transpose(-2, -1)) / sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        out = self.W_o(out)
        return out

这段实现里有几个值得特别注意的点。

  • WQW^QWKW^KWVW^V 被合成了一个 WqkvW_{qkv},这是为了减少 GEMM 次数。
  • viewpermute 的顺序不能错,错了通常不会立刻报错,但模型会学不起来。
  • contiguous() 不是多余的,它是在张量转置之后为后续 reshape 和 matmul 保底。
  • 真正上 GPU 跑大模型时,通常会直接调用 F.scaled_dot_product_attention 或者底层 fused kernel,而不是自己手写 softmax。

3. 最容易踩的几个坑

多头实现里最常见的坑,基本都不是理论错误,而是张量细节错误。

第一,dmodeld_{model} 和头数不整除。这是最简单也最常见的 bug。

第二,reshape 顺序错。把 (B,N,h,dk)(B, N, h, d_k) 写成 (B,h,N,dk)(B, h, N, d_k),代码可能照样能跑,但 token 维和 head 维已经被弄乱。

第三,mask 形状或 dtype 不对。实践里最好显式把 mask 写成带 head 维的 bool tensor,不要依赖隐式 broadcast。

第四,不要轻易删掉 WOW^O。concat 之后虽然已经回到 dmodeld_{model} 维,但没有 WOW^O,各头之间就失去了重新混合和重新写入残差流的机会。


六、把答案收回到核心问题

如果把整篇内容压缩成一句话,那么 Multi-Head Attention 的作用就是:把一次 attention 从单一 softmax 升级成多组并行 softmax,让模型在同一步里同时建模多种关系。

它真正厉害的地方不在于公式有多复杂,而在于设计非常节制:参数量基本不变,计算结构仍然适合大矩阵乘法,表达力却从一组关系扩展成了一组子空间里的并行关系。后续从 BERT、GPT 到 LLaMA,再到 GQA、MQA 和 FlashAttention,本质上都仍然在围绕这个设计继续打磨。


关键概念回顾

  • 多头的本质不是把维度切碎,而是给不同子空间各自保留一份独立的 softmax 分布。
  • dk=dmodel/hd_k = d_{model} / h 的标准设置下,多头几乎不比单头多参数;它换来的主要是并行建模不同关系的能力。
  • 不同头确实会分工,但注意力图只能作为线索,不能直接当成因果解释。
  • 训练喜欢保留较多独立头,推理则更关心 KV cache,所以现代大模型才会大量使用 GQA 和 MQA。
  • 真正高效的实现一定是大矩阵乘法加 reshape,而不是 for 循环跑 hh 次 attention。

常见误解

  • 误解一:头越多越好。错。头数必须和每头维度一起看,单头太瘦会直接掉表达力。
  • 误解二:多头比单头多很多参数。错。标准配置下参数量几乎等价。
  • 误解三:一个漂亮的注意力图就等于模型学会了句法。错。可视化只能给出相关性线索,不是因果证明。
  • 误解四:把同一个 attention 跑 hh 次再平均就是多头。错。多头的关键是每个头有自己独立的投影矩阵。
  • 误解五:推理时继续保留完整 MHA 一定最好。错。部署场景里,GQA 和 MQA 往往是更合理的工程折中。

下一步

参考文献

  • Vaswani A., Shazeer N., Parmar N., Uszkoreit J., Jones L., Gomez A. N., Kaiser L., Polosukhin I. Attention Is All You Need. NeurIPS 2017.
  • Clark K., Khandelwal U., Levy O., Manning C. D. What Does BERT Look At? An Analysis of BERT's Attention. EMNLP 2019.
  • Voita E., Talbot D., Moiseev F., Sennrich R., Titov I. Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned. ACL 2019.
  • Michel P., Levy O., Neubig G. Are Sixteen Heads Really Better than One? NeurIPS 2019.
  • Jain S., Wallace B. C. Attention is not Explanation. NAACL 2019.
  • Tenney I., Das D., Pavlick E. BERT Rediscovers the Classical NLP Pipeline. ACL 2019.
  • Shazeer N. Fast Transformer Decoding: One Write-Head is All You Need. 2019.
  • Ainslie J., Lee-Thorp J., de Jong M., Zemlyanskiy Y., Lebrón F., Sanghai S. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023.
  • Xiao G., Tian Y., Chen B., Han S., Lewis M. Efficient Streaming Language Models with Attention Sinks. ICLR 2024.
  • Xiong R., Yang Y., He D., Zheng K., Zheng S., Xing C., Zhang H., Lan Y., Wang L., Liu T. On Layer Normalization in the Transformer Architecture. ICML 2020.

← 上一篇:15. Scaled Dot-Product | 下一篇:17. Causal Mask