废话不多说直接上代码再来讲解
注意看注释,关键代码来看到forward
import torch
from torch import nn
class MultiHeadAttention(nn.Module):
"""多头注意力"""
def __init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False, **kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self.num_heads = num_heads
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
def forward(self, queries, keys, values, valid_lens):
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
valid_lens = torch.repeat_interleave(
valid_lens, repeats=self.num_heads, dim=0)
output = self.attention(queries, keys, values, valid_lens)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)
def transpose_qkv(X, num_heads):
"""为了多注意力头的并行计算而变换形状"""
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
X = X.permute(0, 2, 1, 3)
return X.reshape(-1, X.shape[2], X.shape[3])
def transpose_output(X, num_heads):
"""逆转transpose_qkv函数的操作"""
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.permute(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)
- 我们先来讲输入吧,(batch_size, num_kvpairs, num_hiddens)按照(2,4,10,num_heads = 2
- 这里的2 可以理解为 2 个句子,4指的是每个句子有四个词,10指的是每个词有十个特征值
- 首先我们讲到多头,多头是哪里多头呢?这里我们设置了两个头,就是需要把十维的特征向量 分割成两头,就是 [2,4,2,5],每头有五维的特征
- 就是transpose_qkv函数里面的用法, X = X.permute(0, 2, 1, 3),这段代码就是交换位置1和位置2 的位置,就变成了[句子数,头数,词数,特征数],我们在这里可以把头数理解为,一个词的不同特征维度,比如说,动词,名词,形容词之类的。
- 然后transpose_qkv函数的返回值,注意了一个变换,X.reshape(-1, X.shape[2], X.shape[3]),把句子数和头数合并了,所以我们可以这样理解,他现在成了这样一个形状:
- 句子1,头1,句子1,头2 ;句子2,头1,句子2,头2
- 所有Q,K,V都是进行这样一个多头的转换操作
- valid_lens = torch.repeat_interleave(
valid_lens, repeats=self.num_heads, dim=0)
# tensor([3, 3, 2, 2])
- 有人会问这段代码是什么意思,其实valid_lens就是限制了每个句子的词数,然后通过后续的遮掩操作进行处理,即句1头1 的词数限制在3,句1头2 的词数限制在3,往后同理。
- 然后就是DotProductAttention这个操作了,这个是做什么的?
- [参考上篇文章](缩放点击注意力推导点击缩放注意力推导 点击缩放评分公式为:a(q,k) / d ** -1 这里提出一个问题,为什么要除 - 掘金),直接计算。
- 最后我们是一个将多头注意力,恢复成原来的形式的操作。
好啦今天就讲到这里,具体的例子的话,可以将代码复制给ai,他会给你举一个简单的例子,或者是自己根据代码推导一遍哦,(最好先看完上一篇的缩放点击注意力的内容)