读到这里,大多数读者应该已经能把 Scaled Dot-Product Attention 的基本流程复述出来: 和 做内积、除以 、过 softmax,再用得到的权重聚合 。问题在于,这台机器每次只给你一组注意力分布。对位置 来说,最终只有一份权重决定它往哪里看、看多少。
这在玩具例子里没有问题,在真实语言里就很快碰到上限。同一个 token 往往同时需要处理句法关系、指代关系、局部邻近关系、语义主题关系。如果把这些判断全部压进同一个 softmax,模型只能在多种关系之间做妥协,而不是并行建模。
Multi-Head Attention 做的事情其实非常直接:把 维表示投到 个独立子空间里,让每个子空间各自形成一份注意力分布,最后再把这些结果拼回去。它看起来像一次简单的切分,但恰恰是这一步,让 Transformer 从单一的相似度度量,变成了同一步内并行处理多种关系的架构。
读完这一篇,你应该能回答:
- 为什么单头 attention 迟早会卡住。
- 为什么多头几乎不增加参数量。
- 不同头到底学到了什么,以及为什么不能把可视化当成因果解释。
- 为什么现代大模型训练时保留多头,推理时却大量使用 GQA 或 MQA。
- 生产代码里为什么一定是大矩阵乘法加 reshape,而不是 for 循环跑 次。
一、为什么一定要多头
1. 单头 attention 的上限在哪里
先把单头形式写清楚。给定输入序列 ,标准 attention 做的是:
这里真正决定模型怎么看世界的是 。对位置 来说,它只有一行 softmax 概率分布,只能在所有候选位置里分出一套权重。这意味着单头 attention 有两个硬约束。
第一,它一次只能表达一种关系。如果当前位置既要看主语和动词之间的句法链路,又要看代词和先行词之间的指代链路,那么这一切都要挤在同一份分布里完成。结果往往不是两种关系都学好,而是两边都被摊薄。
第二,它只能依赖一套相似度度量。 之所以能得到权重,是因为模型假设 和 在同一空间里的点积足以衡量相似度。但句法相似度、位置相似度、主题相似度并不天然属于同一种空间。要求一组 和 同时支撑这些判断,本质上是在逼同一把尺子量多种不同性质的东西。
这就是单头的核心瓶颈:它不缺一次聚合的能力,缺的是同一步里并行处理多种关系的能力。
2. 为什么不靠堆更深的层解决
一个自然反问是:单层只算一种关系也没关系,多堆几层不就行了?
问题在于,深度解决的是逐层组合,宽度解决的是同一步并行。Transformer 每一层的输出都会进入残差流,再交给下一层继续处理。第一层如果已经把一部分信息按某种关系混合了,第二层看到的就是混合后的表示,而不是原始 token 表示。你当然可以让下一层再学另一种关系,但这已经不是同一步并行完成,而是先后改写。
从建模目标上说,多头更像同一层里的多组滤波器,而不是更多层的重复堆叠。CNN 不会指望一个卷积核学完所有局部模式,再靠更深的层把它们拆开;同理,attention 也不应该只有一套相似度度量,然后把所有关系都往后推。
所以多头解决的不是层数不够,而是单层表达过窄的问题。
二、多头到底是怎么工作的
1. 标准定义
Multi-Head Attention 的标准定义是:
关键不在 concat,而在每个头都有自己独立的 、、。同一份输入会被投到 个不同的子空间里,每个子空间各自形成一份 softmax 分布。头和头之间参数不共享,所以模型有机会把不同头训练成不同的关系探测器。
最后的 也不是可有可无的装饰。它的作用是把各个头的输出重新混回统一的残差流,让下一层能够在一个共享空间里继续处理,而不是面对 个互不沟通的孤岛。
2. 参数量为什么几乎不变
很多教程会说多头不怎么增加参数量,但这句话如果不算账,很容易被误解。
把所有头的投影矩阵沿最后一维拼起来,可以得到:
在最常见的设置里,,于是 。这意味着 、、 都退回成 的方阵,再加上一个同样大小的 ,多头 attention 整体上仍然只是 4 个 矩阵。
用 Transformer-base 的常见配置举例:
- 单头大版本:,那么 、、 各是 。
- 8 头版本:每个头的矩阵是 ,一共 8 头,拼起来还是 。
也就是说,多头买的不是更多参数,多头买的是更多独立 softmax 的并行度。
这一点非常关键。一个更大的单头只能给你一份更精细的分布,但还是只有一份分布;多头给你的是 份独立分布,它们可以同时盯住不同关系。
3. 一个最小数值例子
为了把直觉落地,考虑一个极小的例子:,,因此 ,序列长度 。输入设为:
再取最简单的投影:。
这时第 1 个头只看前两维,第 2 个头只看后两维。对同一个 token 来说,这两个头看到的是不同的几何结构。第 1 个头里,第三个 token 同时和前两个 token 有相似性;第 2 个头里,第三个 token 恰好变成零向量,对谁都不特别相似。
如果把第 1 个头的打分写出来,有:
而第 2 个头里,第三行会全部变成 0,softmax 之后就是均匀分布。于是同一个 query 在两个头里得到的注意力模式完全不同:一个头把权重集中到和自己最相近的位置,另一个头因为分辨不出差异,只能平均分配。
这就是多头和单头最本质的差别:多头不是把一个大空间切碎而已,而是给每个子空间独立保留一份 softmax 表达能力。
三、不同的头到底学到了什么
1. BERT 里最常见的四类头
BERT 火起来之后,研究者第一次系统地把多头逐个可视化。Clark 等人的分析里,最常见的头大致可以分成四类。
第一类是位置型。它们几乎只看相邻 token,或者只看自己,像是在做局部 n-gram 聚合。
第二类是锚点型。它们把大量权重给 [CLS]、[SEP]、句号,或者序列开头的若干位置。后来这类模式在长上下文推理里演化成了 attention sink 的重要现象。
第三类是句法型。某些头会稳定地把注意力放到主语对应的动词、介词对应的宾语、修饰语对应的中心词上。模型从来没被显式教过依存语法,但它会自发学出这类结构。
第四类是指代型。它们更稀有,通常出现在中后层,用来追踪 pronoun 和先行词之间的关系。
这些结果至少说明一件事:多头并不是训练出很多完全一样的副本。它们确实会分工,而且分工经常与我们关心的语言结构对应。
2. 可视化很有用,但不是因果解释
看到这里很容易走到另一个极端:把某个好看的注意力图直接当成模型解释。
这一步需要非常克制。Jain 与 Wallace 的结论非常明确:注意力分布可以和某种解释相一致,但不能直接等同于模型的因果机制。因为最终输出不仅取决于注意力权重,还取决于被加权的 本身,以及更早层已经写进残差流里的信息。
所以更稳妥的理解是:
- 可视化适合生成假设。它能告诉你某个头看起来像句法头、像 sink 头、像位置头。
- 消融和干预才更接近验证。把头置零、替换输出、观察性能下降,才更能说明这个头是不是在承担关键功能。
换句话说,注意力图能帮你看见模式,但不能替你完成归因。
3. 跨层分工与头剪枝
如果把视角从单层拉到多层,现象会更有意思。Tenney 等人的 probing 结果显示,BERT 的浅层更接近词法和局部邻近特征,中层更偏句法,深层更偏语义和篇章。这意味着多头不只是横向并行,也在纵向上形成了层级分工。
另一方面,Michel 和 Voita 的剪头实验也说明:并不是每个头都同等重要。很多头可以被单独剪掉而几乎不掉点,但也有少数头一旦剪掉,性能会明显下滑。这说明多头内部既有专责头,也有冗余头。
这对工程的启发非常直接:训练阶段保留较多头,有利于模型探索不同关系;部署阶段则可以把部分冗余结构压缩掉,于是才有了后来的 GQA、MQA 和各种头剪枝方案。
四、从 MHA 到 GQA:工程上的现实约束
1. 头数怎么选
原始 Transformer 的经验其实已经给出了很强的约束:头数不是越多越好,而是要和每头维度一起看。
| 典型结论 | ||
|---|---|---|
| 1 | 512 | 表达力不足,单一分布太受限 |
| 4 | 128 | 明显改善 |
| 8 | 64 | 经典甜点区间 |
| 16 | 32 | 开始变窄 |
| 32 | 16 | 每头维度过小,效果回落 |
后来大模型的配置大体沿着这个经验走:
| 模型 | |||
|---|---|---|---|
| Transformer-base | 512 | 8 | 64 |
| BERT-base | 768 | 12 | 64 |
| BERT-large | 1024 | 16 | 64 |
| GPT-3 175B | 12288 | 96 | 128 |
| LLaMA-2 7B | 4096 | 32 | 128 |
| LLaMA-2 70B | 8192 | 64 | 128 |
最稳定的经验不是头数本身,而是每头维度通常锁在 64 或 128。头太少,关系不够并行;头太多,单头维度又太瘦,连基本的相似度判断都做不扎实。
2. 为什么推理端开始大量砍头
训练时多头是优势,推理时多头却很快变成负担,问题集中在 KV cache。
标准 MHA 里,每个头都有自己的一份 和 。当上下文很长时,这部分缓存会迅速吃光显存。于是工程上出现了两条典型路线:
- MQA:所有 query 头共享同一份 和 ,KV cache 最省,但表达力损失更明显。
- GQA:把 query 头分组,每组共享一份 和 ,在质量和速度之间取折中。
| 变体 | 头数 | 头数 | KV cache | 常见取舍 |
|---|---|---|---|---|
| MHA | 最大 | 训练最好,推理最慢 | ||
| GQA | 中等 | 质量接近 MHA,推理显著更快 | ||
| MQA | 1 | 最小 | 最省显存,但更容易掉点 |
这也是为什么现代大模型常常呈现一个看上去矛盾的趋势:训练时保留较多 query 头,推理时尽量共享 。
3. 训练稳定性的几个注意点
多头本身不神秘,但大模型里它会和训练稳定性强耦合,最常见的注意点有三个。
第一,pre-LN 比 post-LN 更稳。深层模型中,attention 输出会不断写回残差流,post-LN 更容易让梯度方差沿层数积累,pre-LN 在 GPT、LLaMA 这类大模型里已经几乎成为默认选择。
第二,训练前期不同头往往都很像。softmax 输入接近零时,各头的分布都接近均匀,分工是在训练中后期逐渐拉开的。不要拿训练早期的注意力图去解释模型行为。
第三, 的初始化值得认真对待。GPT-2 之后很常见的做法,是按层数缩小 的初始方差,减少 attention 输出反复写回残差流时的方差放大。这不是多头独有的数学性质,但它直接影响多头模块在深层网络里的稳定性。
4. 自注意力和交叉注意力有什么不同
多头机制不仅用于 self-attention,也同样用于 cross-attention。形式上两者完全一样,差别只在来源:
- self-attention 里,、、 都来自同一份输入。
- cross-attention 里, 来自 decoder 当前状态, 和 来自 encoder 输出。
从多头的角度看,变化不在公式,而在任务含义上。self-attention 更像序列内部关系建模,cross-attention 更像目标序列对源序列做可寻址检索。很多翻译和多模态模型里的对齐能力,靠的正是 cross-attention 中不同头的分工。
五、工程实现:一次大矩阵乘法加 reshape
1. 为什么生产代码不是 for 循环
概念上,多头好像就是把 attention 跑 次,然后把结果拼起来。但生产代码从不会真的写一个 for 循环。
原因很简单:GPU 喜欢一次大矩阵乘法,不喜欢很多次小矩阵乘法。真正高效的实现会先用一到三次大 GEMM 一次性算出全部头的 、、,然后 reshape 成 的形状,再把头维度当成 batched matmul 的一个批次维来统一处理。
也就是说,工程实现的骨架其实是:
# X: (B, N, D)
qkv = X @ W_qkv
q, k, v = split_and_reshape(qkv) # (B, h, N, d_k)
scores = (q @ k.transpose(-2, -1)) / sqrt(d_k)
attn = softmax(scores, dim=-1)
out = attn @ v
out = merge_heads(out) @ W_o
核心思想不是把一个头复制很多次,而是把所有头并进同一套张量运算里。
2. 一份完整的 PyTorch 实现
下面是一份简洁但已经接近生产习惯的 PyTorch 写法:
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt
class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
super().__init__()
assert d_model % num_heads == 0, 'd_model 必须能被 num_heads 整除'
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.shape
qkv = self.W_qkv(x)
qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.d_k)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
scores = torch.matmul(q, k.transpose(-2, -1)) / sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
out = self.W_o(out)
return out
这段实现里有几个值得特别注意的点。
- 、、 被合成了一个 ,这是为了减少 GEMM 次数。
view和permute的顺序不能错,错了通常不会立刻报错,但模型会学不起来。contiguous()不是多余的,它是在张量转置之后为后续 reshape 和 matmul 保底。- 真正上 GPU 跑大模型时,通常会直接调用
F.scaled_dot_product_attention或者底层 fused kernel,而不是自己手写 softmax。
3. 最容易踩的几个坑
多头实现里最常见的坑,基本都不是理论错误,而是张量细节错误。
第一, 和头数不整除。这是最简单也最常见的 bug。
第二,reshape 顺序错。把 写成 ,代码可能照样能跑,但 token 维和 head 维已经被弄乱。
第三,mask 形状或 dtype 不对。实践里最好显式把 mask 写成带 head 维的 bool tensor,不要依赖隐式 broadcast。
第四,不要轻易删掉 。concat 之后虽然已经回到 维,但没有 ,各头之间就失去了重新混合和重新写入残差流的机会。
六、把答案收回到核心问题
如果把整篇内容压缩成一句话,那么 Multi-Head Attention 的作用就是:把一次 attention 从单一 softmax 升级成多组并行 softmax,让模型在同一步里同时建模多种关系。
它真正厉害的地方不在于公式有多复杂,而在于设计非常节制:参数量基本不变,计算结构仍然适合大矩阵乘法,表达力却从一组关系扩展成了一组子空间里的并行关系。后续从 BERT、GPT 到 LLaMA,再到 GQA、MQA 和 FlashAttention,本质上都仍然在围绕这个设计继续打磨。
关键概念回顾
- 多头的本质不是把维度切碎,而是给不同子空间各自保留一份独立的 softmax 分布。
- 在 的标准设置下,多头几乎不比单头多参数;它换来的主要是并行建模不同关系的能力。
- 不同头确实会分工,但注意力图只能作为线索,不能直接当成因果解释。
- 训练喜欢保留较多独立头,推理则更关心 KV cache,所以现代大模型才会大量使用 GQA 和 MQA。
- 真正高效的实现一定是大矩阵乘法加 reshape,而不是 for 循环跑 次 attention。
常见误解
- 误解一:头越多越好。错。头数必须和每头维度一起看,单头太瘦会直接掉表达力。
- 误解二:多头比单头多很多参数。错。标准配置下参数量几乎等价。
- 误解三:一个漂亮的注意力图就等于模型学会了句法。错。可视化只能给出相关性线索,不是因果证明。
- 误解四:把同一个 attention 跑 次再平均就是多头。错。多头的关键是每个头有自己独立的投影矩阵。
- 误解五:推理时继续保留完整 MHA 一定最好。错。部署场景里,GQA 和 MQA 往往是更合理的工程折中。
下一步
- 想理解 decoder 为什么不能看未来:去看 17. Causal Mask。
- 想接着看 attention 的复杂度和长上下文瓶颈:18. 注意力的复杂度问题。
- 想回到总图看 encoder、decoder 和 FFN 怎么拼起来:20. Transformer 整体架构。
参考文献
- Vaswani A., Shazeer N., Parmar N., Uszkoreit J., Jones L., Gomez A. N., Kaiser L., Polosukhin I. Attention Is All You Need. NeurIPS 2017.
- Clark K., Khandelwal U., Levy O., Manning C. D. What Does BERT Look At? An Analysis of BERT's Attention. EMNLP 2019.
- Voita E., Talbot D., Moiseev F., Sennrich R., Titov I. Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned. ACL 2019.
- Michel P., Levy O., Neubig G. Are Sixteen Heads Really Better than One? NeurIPS 2019.
- Jain S., Wallace B. C. Attention is not Explanation. NAACL 2019.
- Tenney I., Das D., Pavlick E. BERT Rediscovers the Classical NLP Pipeline. ACL 2019.
- Shazeer N. Fast Transformer Decoding: One Write-Head is All You Need. 2019.
- Ainslie J., Lee-Thorp J., de Jong M., Zemlyanskiy Y., Lebrón F., Sanghai S. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023.
- Xiao G., Tian Y., Chen B., Han S., Lewis M. Efficient Streaming Language Models with Attention Sinks. ICLR 2024.
- Xiong R., Yang Y., He D., Zheng K., Zheng S., Xing C., Zhang H., Lan Y., Wang L., Liu T. On Layer Normalization in the Transformer Architecture. ICML 2020.
← 上一篇:15. Scaled Dot-Product | 下一篇:17. Causal Mask →