一、为什么需要多个"注意力头"?
想象你正在参加一场交响乐演出,每个乐手都专注乐谱的不同部分——小提琴组负责主旋律,打击乐把控节奏,铜管组强调高潮段落。这种分工协作的方式,正是多头注意力机制的核心思想。
传统单头注意力就像只有一位听众在欣赏音乐,只能从一个角度理解整个演奏。而多头注意力让多个"虚拟听众"(头)同时工作,每个头都能:
- 捕捉不同距离的关联(如主歌与副歌的关系)
- 关注不同类型的特征(旋律、节奏、和声)
- 组合多种理解方式形成全面认知
二、多头注意力如何运作?
2.1 核心计算步骤
假设我们要处理一句歌词:"雨下整夜,我的爱溢出就像雨水"。每个词的表示向量都要与其它词产生关联,具体分为三步:
步骤1:创建多重视角
为每个头创建独立视角
- 头1_查询 = 线性变换(原始查询)
- 头1_键 = 线性变换(原始键)
- 头1_值 = 线性变换(原始值)
- 头2_查询 = 线性变换(原始查询)
- ...(共h个头)
数学表达式(每个头i的计算):
步骤2:并行注意力计算
每个头独立进行注意力计算(以缩放点积注意力为例):
步骤3:合并所有结果
将各头的输出拼接后做最终变换:
2.2 维度变换图解
假设原始维度d=64,使用8个头:
- 每个头的维度变为64/8=8
- 各头计算结果拼接后恢复64维
- 最终线性变换保持维度一致
三、亲手搭建迷你多头注意力
3.1 简化版实现(使用PyTorch)
def transpose_qkv(X, num_heads):
"""为了多注意力头的并行计算而变换形状"""
# 输入X的形状: (batch_size, 查询或者“键-值”对的个数, num_hiddens)
# 输出X的形状: (batch_size, 查询或者“键-值”对的个数, num_heads,num_hiddens/num_heads)
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
# 输出X的形状: (batch_size, num_heads, 查询或者“键-值”对的个数, num_hiddens/num_heads)
X = X.permute(0, 2, 1, 3)
# 最终输出的形状: (batch_size*num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)
return X.reshape(-1, X.shape[2], X.shape[3])
def transpose_output(X, num_heads):
"""逆转transpose_qkv函数的操作"""
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.permute(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)
class MultiHeadAttention(nn.Module):
"""多头注意力"""
def __init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False, **kwargs):
super().__init__(**kwargs)
self.num_heads = num_heads
self.attention = DotProductAttention(dropout)
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
def forward(self, queries, keys, values, valid_lens):
# queries,keys,values的形状:
# (batch_size,查询或者“键-值”对的个数,num_hiddens)
# valid_lens 的形状:
# (batch_size,)或(batch_size, 查询的个数)
# 经过变换后,输出的queries, keys, values 的形状:
# (batch_size*num_heads, 查询或者“键-值”对的个数, num_hiddens/num_heads)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
# 在轴0,将第一项(标量或者矢量)复制num_heads次,
# 然后如此复制第二项,然后诸如此类。
valid_lens = torch.repeat_interleave(
valid_lens, repeats=self.num_heads, dim=0)
# output的形状: (batch_size*num_heads, 查询的个数, num_hiddens/num_heads)
output = self.attention(queries, keys, values, valid_lens)
# output_concat的形状:(batch_size, 查询的个数, num_hiddens)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)
下面使用键和值相同的小例子来测试我们编写的MultiHeadAttention
类。多头注意力输出的形状是(batch_size, num_queries, num_hiddens)
。
import torch
import d2l
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(key_size=num_hiddens, query_size=num_hiddens, value_size=num_hiddens,
num_hiddens=num_hiddens, num_heads=num_heads, dropout=0.5)
attention.eval()
print(attention)
MultiHeadAttention(
(attention): DotProductAttention(
(dropout): Dropout(p=0.5, inplace=False)
)
(W_q): Linear(in_features=100, out_features=100, bias=False)
(W_k): Linear(in_features=100, out_features=100, bias=False)
(W_v): Linear(in_features=100, out_features=100, bias=False)
(W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries = 2, 4 # num_queries: 查询的个数
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens)) # queries
print(X.shape)
# torch.Size([2, 4, 100]) (batch_size, num_queries, num_hiddens)
Y = torch.ones((batch_size, num_kvpairs, num_hiddens)) # keys, values
print(Y.shape)
# torch.Size([2, 6, 100]) (batch_size, num_kvpairs, num_hiddens)
print(attention(X, Y, Y, valid_lens).shape)
# torch.Size([2, 4, 100]) (batch_size, num_queries, num_hiddens)
3.2 关键技巧解析
- 维度拆分:将64维拆分为8个8维头
# 伪代码
q = q.view(batch_size, seq_len, 8, 8).transpose(1,2)
- 并行计算:利用矩阵运算同时处理所有头
- 结果融合:拼接后通过线性层整合信息
四、实际应用示例:歌词情感分析
假设分析周杰伦《七里香》歌词的情感:
歌词 = ["窗外的麻雀", "在电线杆上多嘴",
"你说这一句", "很有夏天的感觉"]
# 创建词向量(假设已编码)
词向量 = torch.randn(4, 64) # 4个词,每个64维
# 使用迷你多头注意力
注意力输出 = MiniMultiHead()(词向量, 词向量, 词向量)
print("每个词的新表示维度:", 注意力输出.shape)
# 输出: torch.Size([4, 64])
此时每个词的表示都融合了:
- "麻雀"与"电线杆"的位置关系(空间头)
- "多嘴"与"感觉"的情感关联(语义头)
- "夏天"与整句的意境联系(语境头)
五、技术要点总结
关键概念 | 类比解释 | 数学表达 |
---|---|---|
线性投影 | 给每个头配不同颜色的眼镜 | |
头拼接 | 乐队各声部录音的合并 | |
缩放点积 | 计算词语间的匹配分数 | |
最终投影 | 指挥家统一协调各声部 |
多头注意力的三大优势:
- 并行处理:多个头同时计算,效率提升
- 多样化关注:捕获词语间的不同类型关系
- 强大表征:通过线性变换组合复杂特征
理解多头注意力机制,就像学会用多种角度欣赏音乐。当每个"头"专注不同声部,最终合奏出的,便是深度学习最动人的智能交响。