理解多头注意力机制:像乐队合奏一样处理信息

35 阅读1分钟

一、为什么需要多个"注意力头"?

想象你正在参加一场交响乐演出,每个乐手都专注乐谱的不同部分——小提琴组负责主旋律,打击乐把控节奏,铜管组强调高潮段落。这种分工协作的方式,正是多头注意力机制的核心思想。

传统单头注意力就像只有一位听众在欣赏音乐,只能从一个角度理解整个演奏。而多头注意力让多个"虚拟听众"(头)同时工作,每个头都能:

  1. 捕捉不同距离的关联(如主歌与副歌的关系)
  2. 关注不同类型的特征(旋律、节奏、和声)
  3. 组合多种理解方式形成全面认知

multi-head-attention.svg

图1 多头注意力:多个头连结然后线性变换

二、多头注意力如何运作?

2.1 核心计算步骤

假设我们要处理一句歌词:"雨下整夜,我的爱溢出就像雨水"。每个词的表示向量都要与其它词产生关联,具体分为三步:

步骤1:创建多重视角

为每个头创建独立视角


  • 头1_查询 = 线性变换(原始查询) Wq(1)Q\boxed{W_q^{(1)}Q}
  • 头1_键 = 线性变换(原始键) Wk(1)K\boxed{W_k^{(1)}K}
  • 头1_值 = 线性变换(原始值) Wv(1)V\boxed{W_v^{(1)}V}

  • 头2_查询 = 线性变换(原始查询) Wq(2)Q\boxed{W_q^{(2)}Q}
  • ...(共h个头)

数学表达式(每个头i的计算):

headi=Attention(Wq(i)Q,Wk(i)K,Wv(i)V)\text{head}_i = \text{Attention}(W_q^{(i)}Q, W_k^{(i)}K, W_v^{(i)}V)

步骤2:并行注意力计算

每个头独立进行注意力计算(以缩放点积注意力为例):

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q,K,V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V

步骤3:合并所有结果

将各头的输出拼接后做最终变换:

MultiHead=Wo[head1;head2;...;headh]\text{MultiHead} = W_o[\text{head}_1; \text{head}_2; ...; \text{head}_h]

2.2 维度变换图解

假设原始维度d=64,使用8个头:

  1. 每个头的维度变为64/8=8
  2. 各头计算结果拼接后恢复64维
  3. 最终线性变换保持维度一致

多头注意力图解.png

三、亲手搭建迷你多头注意力

3.1 简化版实现(使用PyTorch)

def transpose_qkv(X, num_heads):
    """为了多注意力头的并行计算而变换形状"""

    # 输入X的形状: (batch_size, 查询或者“键-值”对的个数, num_hiddens)
    # 输出X的形状: (batch_size, 查询或者“键-值”对的个数, num_heads,num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    # 输出X的形状: (batch_size, num_heads, 查询或者“键-值”对的个数, num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)
    # 最终输出的形状: (batch_size*num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    """逆转transpose_qkv函数的操作"""

    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)


class MultiHeadAttention(nn.Module):
    """多头注意力"""

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # queries,keys,values的形状:
        # (batch_size,查询或者“键-值”对的个数,num_hiddens)

        # valid_lens 的形状:
        # (batch_size,)或(batch_size, 查询的个数)

        # 经过变换后,输出的queries, keys, values 的形状:
        # (batch_size*num_heads, 查询或者“键-值”对的个数, num_hiddens/num_heads)

        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # 在轴0,将第一项(标量或者矢量)复制num_heads次,
            # 然后如此复制第二项,然后诸如此类。
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # output的形状: (batch_size*num_heads, 查询的个数, num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # output_concat的形状:(batch_size, 查询的个数, num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

下面使用键和值相同的小例子来测试我们编写的MultiHeadAttention类。多头注意力输出的形状是(batch_size, num_queries, num_hiddens)

import torch

import d2l

num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(key_size=num_hiddens, query_size=num_hiddens, value_size=num_hiddens,
                                   num_hiddens=num_hiddens, num_heads=num_heads, dropout=0.5)
attention.eval()

print(attention)
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries = 2, 4  # num_queries: 查询的个数
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))  # queries
print(X.shape)
# torch.Size([2, 4, 100]) (batch_size, num_queries, num_hiddens)

Y = torch.ones((batch_size, num_kvpairs, num_hiddens))  # keys, values
print(Y.shape)
# torch.Size([2, 6, 100]) (batch_size, num_kvpairs, num_hiddens)

print(attention(X, Y, Y, valid_lens).shape)
# torch.Size([2, 4, 100]) (batch_size, num_queries, num_hiddens)

3.2 关键技巧解析

  1. 维度拆分:将64维拆分为8个8维头
# 伪代码
q = q.view(batch_size, seq_len, 8, 8).transpose(1,2)
  1. 并行计算:利用矩阵运算同时处理所有头
  2. 结果融合:拼接后通过线性层整合信息

四、实际应用示例:歌词情感分析

假设分析周杰伦《七里香》歌词的情感:

歌词 = ["窗外的麻雀", "在电线杆上多嘴", 
       "你说这一句", "很有夏天的感觉"]

# 创建词向量(假设已编码)
词向量 = torch.randn(4, 64)  # 4个词,每个64维

# 使用迷你多头注意力
注意力输出 = MiniMultiHead()(词向量, 词向量, 词向量)

print("每个词的新表示维度:", 注意力输出.shape)
# 输出: torch.Size([4, 64])

此时每个词的表示都融合了:

  • "麻雀"与"电线杆"的位置关系(空间头)
  • "多嘴"与"感觉"的情感关联(语义头)
  • "夏天"与整句的意境联系(语境头)

五、技术要点总结

关键概念类比解释数学表达
线性投影给每个头配不同颜色的眼镜Wq(i)QW_q^{(i)}Q
头拼接乐队各声部录音的合并[head1;...;headh][head_1;...;head_h]
缩放点积计算词语间的匹配分数QKTdk\frac{QK^T}{\sqrt{d_k}}
最终投影指挥家统一协调各声部WoW_o

多头注意力的三大优势

  1. 并行处理:多个头同时计算,效率提升
  2. 多样化关注:捕获词语间的不同类型关系
  3. 强大表征:通过线性变换组合复杂特征
多头注意力=多个视角+并行计算+智能融合\text{多头注意力} = \text{多个视角} + \text{并行计算} + \text{智能融合}

理解多头注意力机制,就像学会用多种角度欣赏音乐。当每个"头"专注不同声部,最终合奏出的,便是深度学习最动人的智能交响。