Transformer多头注意力

184 阅读3分钟

多头注意力机制计算序列中每对位置之间的注意力。它由多个“注意力头”组成,这些头捕获输入序列的不同方面。

multiHead.png

简介

MultiHeadAttention类封装了在Transformer模型中常用的多头注意力机制。它处理将输入分割成多个注意力头,对每个头应用注意力,然后将结果组合。通过这样做,模型可以在不同尺度上捕获输入数据中的各种关系,提高模型的表达能力。

全部代码

class MultiHeadAttention(nn.Module):  
    def __init__(self, d_model, num_heads):  
        super(MultiHeadAttention, self).__init__()  
        # 确保模型维度(d_model)可以被注意力头数整除  
        assert d_model % num_heads == 0, "d_model必须能被num_heads整除"  
  
        # 初始化维度  
        self.d_model = d_model  # 模型的维度  
        self.num_heads = num_heads  # 注意力头的数量  
        self.d_k = d_model // num_heads  # 每个头的键、查询和值的维度  
  
        # 用于转换输入的线性层  
        self.W_q = nn.Linear(d_model, d_model)  # 查询转换  
        self.W_k = nn.Linear(d_model, d_model)  # 键转换  
        self.W_v = nn.Linear(d_model, d_model)  # 值转换  
        self.W_o = nn.Linear(d_model, d_model)  # 输出转换  
  
    def scaled_dot_product_attention(self, Q, K, V, mask=None):  
        # 计算注意力分数  
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)  
  
        # 如果提供了掩码,则应用它(有助于防止对某些部分如填充的注意力)  
        if mask is not None:  
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)  
  
        # 应用softmax以获得注意力概率  
        attn_probs = torch.softmax(attn_scores, dim=-1)  
  
        # 乘以值以获得最终输出  
        output = torch.matmul(attn_probs, V)  
        return output  
  
    def split_heads(self, x):  
        # 重塑输入以进行多头注意力  
        batch_size, seq_length, d_model = x.size()  
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)  
  
    def combine_heads(self, x):  
        # 将多个头重新组合成原始形状  
        batch_size, _, seq_length, d_k = x.size()  
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)  
  
    def forward(self, Q, K, V, mask=None):  
        # 应用线性变换并分割头  
        Q = self.split_heads(self.W_q(Q))  
        K = self.split_heads(self.W_k(K))  
        V = self.split_heads(self.W_v(V))  
  
        # 执行缩放点积注意力  
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)  
  
        # 组合头并应用输出变换  
        output = self.W_o(self.combine_heads(attn_output))  
        return output  

类定义和初始化

  1. 定义为PyTorch的nn.Module的子类
  2. d_model: 输入的维度。
  3. num_heads: 分割输入的注意力头数。
class MultiHeadAttention(nn.Module):  
    def __init__(self, d_model, num_heads) 

缩放点积注意力

  1. 计算注意力分数:attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)。这里,注意力分数是通过取查询(Q)和键(K)的点积,然后通过键的维度(d_k)的平方根进行缩放来计算的。
  2. 应用掩码:如果提供了掩码,则将其应用于注意力分数以掩盖特定值。
  3. 计算注意力权重:注意力分数通过softmax函数传递,以将它们转换为总和为1的概率。
  4. 计算输出:注意力的最终输出是通过将注意力权重乘以值(V)来计算的。
def scaled_dot_product_attention(self, Q, K, V, mask=None) 

分割头

这个方法将输入x重塑为形状(batch_size, num_heads, seq_length, d_k)。它使模型能够同时处理多个注意力头,允许并行计算。

def split_heads(self, x)  

组合头

在分别对每个头应用注意力后,这个方法将结果组合回batch_size, seq_length, d_model形状的单个张量。这为进一步处理准备了结果。

def combine_heads(self, x)  

前向方法

前向方法是实际计算发生的地方:

  1. 应用线性变换:首先使用初始化中定义的权重将查询(Q)、键(K)和值(V)通过线性变换。
  2. 分割头:使用split_heads方法将转换后的Q、K、V分割成多个头。
  3. 应用缩放点积注意力:在分割的头上调用scaled_dot_product_attention方法。
  4. 组合头:使用combine_heads方法将每个头的结果组合回单个张量。
  5. 应用输出变换:最后,组合的张量通过输出线性变换。
def forward(self, Q, K, V, mask=None)