第04章:Multi-Head Attention——八个头,八个视角,还是八份低秩分解?
论文链接:Attention Is All You Need (Vaswani et al., NIPS 2017) 本章对应:Section 3.2.2, Section 3.2.3, Table 3 row (A)
由于掘金平台限制,完整版(含架构图和数学公式)请访问其他平台同名账号或: 📖 GitHub: github.com/Yunzenn/blo…
核心困惑
为什么要用多个head?一个head不够吗?
第03章讲了Scaled Dot-Product Attention的数学原理。但原论文不是直接用一个Attention,而是用了8个并行的Attention,叫Multi-Head Attention。
这不是简单的”多跑几次”。8个head有各自的投影矩阵WQ,WK,WV,最后再拼接起来。这个设计背后的动机是什么?
完整的Multi-Head Attention公式:
MultiHead(Q,K,V)=Concat(head1,...,headh)WO
where headi=Attention(QWiQ,KWiK,VWiV)
为什么是8个head而不是1个或16个?这个选择有数学依据吗?
前置知识补给站
1. 矩阵的秩与低秩分解
矩阵的秩:矩阵的秩是其线性独立的行(或列)的最大数量。
低秩分解:将一个大矩阵分解为两个小矩阵的乘积。
Am×n=Bm×r⋅Cr×n
其中r<min(m,n),r叫做秩。
为什么要低秩分解:
- 减少参数量:mn个参数变成mr+rn个
- 当r≪min(m,n)时,参数量大幅减少
2. 表示子空间
在d维空间中,一个k维子空间是所有可以用k个基向量线性组合表示的向量的集合。
直观理解:
- 1维子空间:一条直线
- 2维子空间:一个平面
- 维子空间:一个维”超平面”
为什么需要多个子空间:不同的子空间可以捕捉不同的特征。
3. 并行计算的优势
在GPU上,多个小矩阵乘法可以并行计算,总时间接近单个大矩阵乘法。
关键:个head的计算可以同时进行,不需要等待。
论文精读:Multi-Head Attention的设计
原论文的动机
Section 3.2.2:
“Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.”
翻译成人话:
- 单个head会”平均”所有信息,丢失细节
- 多个head可以关注不同的”表示子空间”
- 每个head学习不同的模式
但原论文没有详细解释”表示子空间”是什么意思。我们来深入分析。
Multi-Head Attention的完整公式
参数矩阵:
- :Query投影矩阵
- :Key投影矩阵
- :Value投影矩阵
- :输出投影矩阵
维度设置(原论文):
- :模型维度
- :head数量
- :每个head的维度
为什么每个head的维度是?
关键洞察:这样设计使得Multi-Head Attention的总计算量与Single-Head Attention相同。
Single-Head Attention的计算量:
在Single-Head Attention中,因为没有降维投影,Key和Value的维度就是。因此:
- :
- Softmax:
- 乘以:
- 总计:
Multi-Head Attention的计算量(每个head):
- :
- :
- Softmax:
- 乘以:
- 单个head:
个head的总计算量:
如果:
结论:Multi-Head的总计算量与Single-Head相同!
这是一个精妙的设计:用相同的计算量,换取了多个”视角”。
第一性原理推导:Multi-Head作为低秩分解
视角1:表示子空间
每个head学习一个维的子空间。
Single-Head Attention:
- 在维空间中直接计算attention
- 所有信息混在一起
Multi-Head Attention:
- 每个head在维子空间中计算attention
- 个子空间可以捕捉不同的模式
数学表达:
定义了第个子空间。
视角2:低秩分解
Multi-Head Attention可以理解为一种低秩分解。
Single-Head Attention:
其中是一个的attention矩阵,秩的上限是。
Multi-Head Attention的低秩分解视角:
将分解为个小矩阵的组合:
每个的秩上限是(因为是矩阵)。
个这样的矩阵组合后,总有效秩上限为。
关键洞察:
- 在计算量相同的前提下,Multi-Head通过低秩分解让模型能够并行地从多个低维子空间捕捉信息
- 每个head专注于一个维子空间,个head组合起来覆盖整个维空间
视角3:集成学习
Multi-Head Attention类似于集成学习中的”多个弱学习器”。
- 每个head是一个”弱学习器”(只看维)
- 个head组合起来形成”强学习器”
类比:
- Random Forest:多棵决策树投票
- Multi-Head Attention:多个attention head拼接
消融实验解读:Table 3 row (A)
原论文Table 3 row (A):
| h | d_k | d_v | PPL (dev) | BLEU (dev) | 参数量 |
|---|---|---|---|---|---|
| 1 | 512 | 512 | 5.29 | 24.9 | 65M |
| 4 | 128 | 128 | 5.00 | 25.5 | 65M |
| 8 (base) | 64 | 64 | 4.92 | 25.8 | 65M |
| 16 | 32 | 32 | 4.91 | 25.8 | 65M |
| 32 | 16 | 16 | 5.01 | 25.4 | 65M |
关键观察:
- 最差:PPL 5.29,BLEU 24.9
-
单个head无法捕捉多样化的模式
-
验证了Multi-Head的必要性
- 和效果相当:PPL 4.92 vs 4.91,BLEU 25.8 vs 25.8
-
8个head已经足够
-
继续增加head收益递减
- 效果下降:PPL 5.01,BLEU 25.4
-
每个head只有16维,表示能力太弱
-
过度分割反而有害
- 参数量相同:所有配置都是65M参数
-
因为保持不变
-
这是一个公平的对比
结论:
- 是一个”甜蜜点”:既有足够的多样性,又不会过度分割
- 是一个合理的子空间维度
三种Attention的Multi-Head实现对比
在第02章我们讲了三种Attention的Q/K/V来源不同。现在我们看看它们在Multi-Head中的实现。
对比表格
| Attention类型 | Q投影 | K投影 | V投影 | 每个head的操作 |
|---|---|---|---|---|
| Encoder Self-Attention | X W_i^Q | X W_i^K | X W_i^V | \text{Attention}(X W_i^Q, X W_i^K, X W_i^V) |
| Decoder Masked Self-Attention | Y W_i^Q | Y W_i^K | Y W_i^V | \text{Attention}(Y W_i^Q, Y W_i^K, Y W_i^V, \text{mask}) |
| Decoder Cross-Attention | Y W_i^Q | Z W_i^K | Z W_i^V | \text{Attention}(Y W_i^Q, Z W_i^K, Z W_i^V) |
关键区别:
- Self-Attention:Q/K/V都投影自同一个输入(或)
- Cross-Attention:Q投影自Decoder(),K/V投影自Encoder()
统一性:
- 三种Attention都用同一个Multi-Head框架
- 只是输入来源不同
Multi-Head Attention的完整计算流程
添加图片注释,不超过 140 字(可选)
关键步骤:
- 并行投影:每个head独立计算
- 并行Attention:每个head独立计算Scaled Dot-Product Attention
- 拼接:将个head的输出拼接成一个大向量
- 输出投影:用将拼接后的向量投影回维
Head剪枝:哪些head可以被安全移除?
后续研究(Michel et al., “Are Sixteen Heads Really Better than One?”, NeurIPS 2019)发现:很多head是冗余的。
Table 3 row (A)已经暗示了冗余:和效果几乎相同(BLEU都是25.8),说明多出来的8个head没有学到新东西。这为后续的head剪枝研究提供了实验依据。
实验发现
- 大部分head可以被移除:
-
在某些任务上,移除50%的head,性能几乎不变
-
甚至有些层只需要1个head
- 不同head学到了不同的模式:
-
有些head关注局部信息(相邻位置)
-
有些head关注长距离依赖
-
有些head关注特定的语法结构(如主谓关系)
- head的重要性因层而异:
-
浅层:head更冗余
-
深层:head更专业化
为什么会有冗余?
可能的原因:
- 过参数化:模型参数量远大于任务需要
- 训练动态:某些head在训练早期有用,后期被其他head替代
- 集成效应:多个head提供了”保险”,即使某些head失效,其他head仍能工作
工程启示:
- 可以用head剪枝来压缩模型
- 但原论文的是一个保守的选择,确保了鲁棒性
2026年的批判性视角
1. 原论文没有深入分析head学到了什么
原论文的局限:
- 只给了消融实验(Table 3 row A)
- 没有可视化不同head的attention pattern
- 没有解释为什么是最优的
后续研究的发现(Voita et al., “Analyzing Multi-Head Self-Attention”, ACL 2019):
- 不同head确实学到了不同的模式
- 但很多head是冗余的
- 可以用正则化来鼓励head的多样性
2. 是经验选择还是理论最优?
原论文的选择:
可能的理论依据:
- 足够大,可以表示复杂的模式
- 足够多,可以捕捉多样性
- 但没有严格的数学证明
后续模型的选择:
- GPT-2:()
- GPT-3:()
- 趋势:更大的模型用更多的head
3. Multi-Head vs Multi-Query Attention (MQA)
Multi-Head Attention的问题:
- 每个head都有独立的K和V
- 推理时需要缓存份KV Cache
- 内存占用大
Multi-Query Attention (MQA)(Shazeer, 2019):
- 所有head共享同一个K和V
- 只有Q是多头的
- KV Cache减少到原来的
Grouped-Query Attention (GQA)(Ainslie et al., 2023):
- MQA和Multi-Head的折中
- 将个head分成组,每组共享K和V
- LLaMA 2、Mistral等模型使用
4. Cross-Attention的K/V都来自Encoder
原论文的设计:Cross-Attention的K和V都来自Encoder输出。
问题:为什么不能K来自Encoder,V来自Decoder?
答案:
- Attention的机制是:用Q查询K,得到”应该关注Encoder的哪个位置”,然后用这个权重去V中提取信息
- K和V必须来自同一个地方,因为:
-
Q-K的点积计算出”位置应该关注Encoder的哪个位置“
-
这个权重用于加权求和V,提取”位置的信息”
-
如果V来自Decoder而K来自Encoder,那么Q-K计算出的是”应该关注Encoder的位置“,但V提取的却是”Decoder的位置的信息”——两者对不上
- 如果V来自Decoder,Cross-Attention就变成了另一种形式的Self-Attention,失去了”关注输入”的能力
这是第02章面试题第4题的答案。
面试追问清单
⭐ 基础必会
- 为什么要用Multi-Head Attention而不是Single-Head?
- 提示:表示子空间、多样性
- Multi-Head Attention的计算复杂度是Single-Head的几倍?
- 提示:相同(因为)
- 原论文为什么选择?
- 提示:Table 3 row (A)的消融实验
⭐⭐ 进阶思考
- 证明:Multi-Head Attention的总计算量与Single-Head Attention相同。
- 提示:展开矩阵乘法的复杂度,利用
- 如果但(而不是512),效果会怎样?
- 提示:参数量减少,表示能力下降
- 为什么Table 3显示效果反而下降?
- 提示:每个head只有16维,表示能力太弱
⭐⭐⭐ 专家领域
- Multi-Head Attention可以看作是什么数学结构的低秩分解?
- 提示:全维度attention矩阵的低秩近似
- 如何设计一个实验来验证”不同head学到了不同的模式”?
- 提示:可视化attention pattern、head剪枝实验、正则化鼓励多样性
- Multi-Query Attention (MQA)和Multi-Head Attention有什么区别?为什么MQA可以减少KV Cache?
- 提示:MQA的所有head共享K和V,只有Q是多头的
下一章预告:第05章将深入拆解残差连接与Layer Normalization,回答”Pre-LN和Post-LN有什么区别?为什么GPT用Pre-LN?”
论文原文传送门:
- Transformer原论文:proceedings.neurips.cc/paper_files…
- 官方代码:github.com/tensorflow/…