多头注意力机制
$W^Q, W^K, W^V$ 为权重矩阵, $X$ 为输入, 即 $Q = XW^Q,\ K = XW^K,\ V = XW^V$
对 $Q, K, V$ 三个矩阵由 (batch_size, seq_len, embedding_dim) 重塑为 (batch_size, seq_len, num_heads, head_dim) 之后继续操作
其中 embedding_dim = num_heads * head_dim 必须能被整除才可以拆分
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
其中 $\sqrt{d_k}$ 是 head_dim 大小的常量, $W^O$ 为输出投影的权重矩阵
手动实现一下大致是这样
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads, dropout=0.1):
super(MultiHeadAttention, self).__init__()
self.n_heads = n_heads
self.d_model = d_model
self.d_k = d_model // n_heads
self.Wq, self.Wk, self.Wv = (
nn.Linear(d_model, d_model),
nn.Linear(d_model, d_model), nn.Linear(d_model, d_model)
)
self.Wo = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(p=dropout)
def forward(self, q, k, v, mask=None):
q, k, v = self.Wq(q), self.Wk(k), self.Wv(v)
q, k, v = self._split_heads(q), self._split_heads(k), self._split_heads(v)
scores = torch.matmul(q, k.transpose(-2, -1)) / self.d_k
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = F.softmax(scores, dim=-1)
# (batch_size, num_heads, seq_len, seq_len)
attn_weights = self.dropout(attn_weights)
weighted_sum = torch.matmul(attn_weights, v)
# (batch_size, num_heads, seq_len, d_v)
# Concatenate heads
weighted_sum = weighted_sum.transpose(1, 2)
weighted_sum = weighted_sum.reshape(weighted_sum.size(0), -1, self.d_model)
# (batch_size, seq_len, d_model)
# Final linear transformation
output = self.Wo(weighted_sum) # (batch_size, seq_len, d_model)
return output, attn_weights
def _split_heads(self, x):
batch_size, seq_len, d_model = x.size()
d_head = d_model // self.n_heads
x = x.view(batch_size, seq_len, self.n_heads, d_head)
return x.permute(0, 2, 1, 3) # (batch_size, n_heads, seq_len, d_head)
# Smoke test: batch of 2, sequence length 512, model dim 512, 16 heads.
attn = MultiHeadAttention(d_model=512, n_heads=16, dropout=0.1)
inputs = torch.rand((2, 512, 512))
output, attn_weights = attn(inputs, inputs, inputs)
print(output.size())
print(attn_weights.size())