多头注意力机制计算序列中每对位置之间的注意力。它由多个“注意力头”组成,这些头捕获输入序列的不同方面。
简介
MultiHeadAttention类封装了在Transformer模型中常用的多头注意力机制。它处理将输入分割成多个注意力头,对每个头应用注意力,然后将结果组合。通过这样做,模型可以在不同尺度上捕获输入数据中的各种关系,提高模型的表达能力。
全部代码
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
# 确保模型维度(d_model)可以被注意力头数整除
assert d_model % num_heads == 0, "d_model必须能被num_heads整除"
# 初始化维度
self.d_model = d_model # 模型的维度
self.num_heads = num_heads # 注意力头的数量
self.d_k = d_model // num_heads # 每个头的键、查询和值的维度
# 用于转换输入的线性层
self.W_q = nn.Linear(d_model, d_model) # 查询转换
self.W_k = nn.Linear(d_model, d_model) # 键转换
self.W_v = nn.Linear(d_model, d_model) # 值转换
self.W_o = nn.Linear(d_model, d_model) # 输出转换
def scaled_dot_product_attention(self, Q, K, V, mask=None):
# 计算注意力分数
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# 如果提供了掩码,则应用它(有助于防止对某些部分如填充的注意力)
if mask is not None:
attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
# 应用softmax以获得注意力概率
attn_probs = torch.softmax(attn_scores, dim=-1)
# 乘以值以获得最终输出
output = torch.matmul(attn_probs, V)
return output
def split_heads(self, x):
# 重塑输入以进行多头注意力
batch_size, seq_length, d_model = x.size()
return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
def combine_heads(self, x):
# 将多个头重新组合成原始形状
batch_size, _, seq_length, d_k = x.size()
return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
def forward(self, Q, K, V, mask=None):
# 应用线性变换并分割头
Q = self.split_heads(self.W_q(Q))
K = self.split_heads(self.W_k(K))
V = self.split_heads(self.W_v(V))
# 执行缩放点积注意力
attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
# 组合头并应用输出变换
output = self.W_o(self.combine_heads(attn_output))
return output
类定义和初始化
- 定义为PyTorch的nn.Module的子类
- d_model: 输入的维度。
- num_heads: 分割输入的注意力头数。
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads)
缩放点积注意力
- 计算注意力分数:attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)。这里,注意力分数是通过取查询(Q)和键(K)的点积,然后通过键的维度(d_k)的平方根进行缩放来计算的。
- 应用掩码:如果提供了掩码,则将其应用于注意力分数以掩盖特定值。
- 计算注意力权重:注意力分数通过softmax函数传递,以将它们转换为总和为1的概率。
- 计算输出:注意力的最终输出是通过将注意力权重乘以值(V)来计算的。
def scaled_dot_product_attention(self, Q, K, V, mask=None)
分割头
这个方法将输入x重塑为形状(batch_size, num_heads, seq_length, d_k)。它使模型能够同时处理多个注意力头,允许并行计算。
def split_heads(self, x)
组合头
在分别对每个头应用注意力后,这个方法将结果组合回batch_size, seq_length, d_model形状的单个张量。这为进一步处理准备了结果。
def combine_heads(self, x)
前向方法
前向方法是实际计算发生的地方:
- 应用线性变换:首先使用初始化中定义的权重将查询(Q)、键(K)和值(V)通过线性变换。
- 分割头:使用split_heads方法将转换后的Q、K、V分割成多个头。
- 应用缩放点积注意力:在分割的头上调用scaled_dot_product_attention方法。
- 组合头:使用combine_heads方法将每个头的结果组合回单个张量。
- 应用输出变换:最后,组合的张量通过输出线性变换。
def forward(self, Q, K, V, mask=None)