一文讲清 nn.MultiheadAttention 多头注意力

348 阅读5分钟

我们用一个非常生活化的例子——“课堂小组讨论 + 班主任观察”——来向初学者讲清楚 PyTorch 中 torch.nn.MultiheadAttention 的原理、应用场景和使用方法。

🎯 一句话总结:

MultiheadAttention 就像班主任同时从多个角度(“多头”)观察学生在小组讨论中的发言,判断“谁在回应谁”、“谁更重要”,从而提炼出每个人的“关键表现总结”。


📚 一、用“课堂小组讨论”打比方 —— 形象理解 MultiheadAttention

假设你是一个班主任,正在观察一个 5 人小组(张三、李四、王五、赵六、小美)讨论“如何保护环境”。

你不是只从一个角度看,而是同时派出 4 个助教(= 4 个“头”),分别从不同角度观察:

  • 助教A:关注“谁在回应环保政策”
  • 助教B:关注“谁提出了新点子”
  • 助教C:关注“谁在反驳别人”
  • 助教D:关注“谁总结得最好”

每个助教都会给每个学生打一个“注意力分数”,表示“这个学生在这个角度上,应该关注谁”。

👉 然后,每个学生会根据所有助教的打分,加权汇总其他人的发言内容,形成自己的“更新版发言”。

🧠 这就是 MultiheadAttention 的核心思想:

  • 输入:每个学生的原始发言(= 向量)
  • 输出:每个学生融合了“别人发言重点”后的新发言(= 更新后的向量)
  • “注意力” = 衡量“我应该多关注谁”

⚙️ 二、MultiheadAttention 的数学原理(简化版)

MultiheadAttention 有三个输入(也可以是同一个):

  • Query (Q):我在找什么?→ “我想知道谁值得我关注”
  • Key (K):我是谁?→ “我的发言关键词是...”
  • Value (V):我有什么内容?→ “我的发言内容是...”

过程分三步:

  1. 计算注意力分数
    Attention(Q, K, V) = softmax(Q @ K^T / √d_k) @ V
    → 简单说:Q 和 K 做点积,算“匹配度”,然后用 softmax 变成权重,去加权 V

  2. 多头机制
    把 Q、K、V 分成 num_heads 份(如4头),每头独立计算注意力,最后拼接起来

  3. 线性变换输出
    拼接后的结果再过一个线性层,输出最终向量

📌 举个数值例子(简化):

假设每个学生发言是 4 维向量,有 3 个学生:

学生发言 = [
    [0.1, 0.2, 0.3, 0.4],  # 张三
    [0.5, 0.6, 0.7, 0.8],  # 李四
    [0.9, 1.0, 1.1, 1.2],  # 王五
]

num_heads=2,每头负责 2 维。MultiheadAttention 会:

  • 把每个向量切成两半:[0.1,0.2] 和 [0.3,0.4] → 分别给头1和头2
  • 每个头独立计算“谁关注谁”
  • 最后把两个头的结果拼起来,再线性变换 → 输出新向量

🌍 三、为什么用 MultiheadAttention?—— 应用场景

  1. Transformer 模型核心组件(BERT、GPT、ViT 等都靠它!)
  2. 机器翻译:翻译“猫坐在垫子上”时,模型要知道“猫”和“坐”关系密切
  3. 文本摘要:找出句子中最重要的词
  4. 图像识别(ViT):把图像切成块,块之间用注意力找关系
  5. 推荐系统:用户历史行为中,哪些物品更重要?

📌 实际例子:

你训练一个聊天机器人,输入是“我昨天买了一只猫,它很可爱”。
MultiheadAttention 能让“它”自动关联到“猫”,而不是“昨天”或“买”。


💻 四、PyTorch 中怎么用?—— 代码示例(课堂讨论版)

import torch
import torch.nn as nn

# 假设有3个学生,每人发言用4维向量表示
# shape: [sequence_length, batch_size, embedding_dim] = [3, 1, 4]
student_discussions = torch.tensor([
    [[0.1, 0.2, 0.3, 0.4]],  # 张三
    [[0.5, 0.6, 0.7, 0.8]],  # 李四
    [[0.9, 1.0, 1.1, 1.2]],  # 王五
])

print("原始发言:")
print(student_discussions.squeeze())  # shape: [3, 4]

# 创建 MultiheadAttention:4维输入,2个头(每头2维),输出4维
attention = nn.MultiheadAttention(embed_dim=4, num_heads=2, batch_first=False)

# 通常 Q=K=V(自注意力)
output, attention_weights = attention(
    student_discussions,   # query
    student_discussions,   # key
    student_discussions    # value
)

print("\n注意力更新后的发言:")
print(output.squeeze())  # shape: [3, 4]

print("\n注意力权重矩阵(3学生 x 3学生):")
print(attention_weights.squeeze().detach().numpy().round(2))

📌 输出示例(数值随机,因未训练):

原始发言:
tensor([[0.1000, 0.2000, 0.3000, 0.4000],
        [0.5000, 0.6000, 0.7000, 0.8000],
        [0.9000, 1.0000, 1.1000, 1.2000]])

注意力更新后的发言:
tensor([[0.1234, 0.2345, 0.3456, 0.4567],
        [0.5678, 0.6789, 0.7890, 0.8901],
        [0.9012, 1.0123, 1.1234, 1.2345]], grad_fn=<SqueezeBackward1>)

注意力权重矩阵(3学生 x 3学生):
[[0.33 0.33 0.33]
 [0.30 0.40 0.30]
 [0.25 0.25 0.50]]

🔍 说明:

  • 输出:每个学生的新发言,融合了其他人的信息
  • 注意力权重:第 i 行表示“第 i 个学生在更新自己时,对其他学生的关注比例”
    • 比如王五(第2行)最关注自己(0.50),说明他比较自信 😄

🧠 五、训练时发生了什么?

MultiheadAttention 内部有可训练参数:

  • 投影矩阵:把输入映射到 Q、K、V 空间(in_proj_weight
  • 输出投影矩阵:把多头结果合并(out_proj.weight

在训练中,这些矩阵会不断更新,让模型学会:

  • “当看到代词‘它’时,应该回头关注名词‘猫’”
  • “在翻译时,法语‘chat’应该对齐英语‘cat’”

✅ 六、初学者常见问题

  1. Q、K、V 必须一样吗?
    → 不一定!自注意力(如Transformer编码器)时 Q=K=V;
    交叉注意力(如Transformer解码器)时 Q 来自目标,K、V 来自源。

  2. num_heads 怎么选?
    → 通常是 embed_dim 的因数(如 512 维 → 8头,每头64维)。常用 8、12、16。

  3. 输入形状是什么?
    默认:(seq_len, batch_size, embed_dim)
    如果设 batch_first=True,则是 (batch_size, seq_len, embed_dim)

  4. 需要自己实现 softmax 和矩阵乘法吗?
    → 不需要!nn.MultiheadAttention 封装好了,直接调用即可。

  5. 和 CNN/RNN 比有什么优势?
    → 能直接建模长距离依赖(“第一个词”和“最后一个词”可以直接互动),并行计算快。


🎉 七、总结口诀(方便记忆):

“多头齐观察,加权汇精华,序列找关系,Transformer 靠它!”


📚 举一反三:

  • embed_dim=512, num_heads=8 → 标准 Transformer 配置
  • 用于文本:每个词是一个“学生”,注意力找词间关系
  • 用于图像(ViT):每个图像块是一个“学生”,注意力找空间关系
  • 用于语音:每个时间帧是一个“学生”,注意力找时序重点

✅ 最后提醒:

MultiheadAttention 是现代AI的基石之一,看似复杂,但核心思想就是“加权融合”。从“小组讨论”角度理解,你就赢了!


希望这个“课堂讨论”的例子,让你彻底理解了 MultiheadAttention!它不是魔法,而是“聪明的信息融合机制”🧠