多头注意力详解(通俗易懂版)

105 阅读3分钟

废话不多说直接上代码再来讲解

注意看注释,关键代码来看到forward

import torch
from torch import nn

class MultiHeadAttention(nn.Module):
    """多头注意力"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
        num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        # self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
  
    def forward(self, queries, keys, values, valid_lens):
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:
            # 在轴0,将第一项(标量或者矢量)复制num_heads次,
            # 然后如此复制第二项,然后诸如此类。
            valid_lens = torch.repeat_interleave(
            valid_lens, repeats=self.num_heads, dim=0)
            # tensor([3, 3, 3, 3, 3, 2, 2, 2, 2, 2])
            # output的形状:(batch_size*num_heads,查询的个数,
            # num_hiddens/num_heads)
            output = self.attention(queries, keys, values, valid_lens)
            # output_concat的形状:(batch_size,查询的个数,num_hiddens)
            output_concat = transpose_output(output, self.num_heads)#还原形状
            return self.W_o(output_concat)

def transpose_qkv(X, num_heads):
    """为了多注意力头的并行计算而变换形状"""

    # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
    # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,
    # num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    # torch.Size([2, 5, 20,4])
    # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,
    # num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)#TODO:交换张量,1,2维的位置

    # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,
    # num_hiddens/num_heads)
    # torch.Size([2, 5, 4, 20])
    #TODO:可以理解为 句1,头1;句1,头2...
    return X.reshape(-1, X.shape[2], X.shape[3])

# @save
def transpose_output(X, num_heads):
    """逆转transpose_qkv函数的操作"""

    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
  • 我们先来讲输入吧,(batch_size, num_kvpairs, num_hiddens)按照(2,4,10,num_heads = 2
  • 这里的2 可以理解为 2 个句子,4指的是每个句子有四个词,10指的是每个词有十个特征值
  • 首先我们讲到多头,多头是哪里多头呢?这里我们设置了两个头,就是需要把十维的特征向量 分割成两头,就是 [2,4,2,5],每头有五维的特征
  • 就是transpose_qkv函数里面的用法, X = X.permute(0, 2, 1, 3),这段代码就是交换位置1和位置2 的位置,就变成了[句子数,头数,词数,特征数],我们在这里可以把头数理解为,一个词的不同特征维度,比如说,动词,名词,形容词之类的。
  • 然后transpose_qkv函数的返回值,注意了一个变换,X.reshape(-1, X.shape[2], X.shape[3]),把句子数和头数合并了,所以我们可以这样理解,他现在成了这样一个形状:
  • 句子1,头1,句子1,头2 ;句子2,头1,句子2,头2
  • 所有Q,K,V都是进行这样一个多头的转换操作
  • valid_lens = torch.repeat_interleave( valid_lens, repeats=self.num_heads, dim=0) # tensor([3, 3, 2, 2])
  • 有人会问这段代码是什么意思,其实valid_lens就是限制了每个句子的词数,然后通过后续的遮掩操作进行处理,即句1头1 的词数限制在3,句1头2 的词数限制在3,往后同理。
  • 然后就是DotProductAttention这个操作了,这个是做什么的?
  • [参考上篇文章](缩放点击注意力推导点击缩放注意力推导 点击缩放评分公式为:a(q,k) / d ** -1 这里提出一个问题,为什么要除 - 掘金),直接计算。
  • 最后我们是一个将多头注意力,恢复成原来的形式的操作。

好啦今天就讲到这里,具体的例子的话,可以将代码复制给ai,他会给你举一个简单的例子,或者是自己根据代码推导一遍哦,(最好先看完上一篇的缩放点击注意力的内容)