Transformer构建解码器块

184 阅读2分钟

decoder.png

简介

DecoderLayer类定义了Transformer解码器的单层。它由多头自注意力机制、多头交叉注意力机制(关注编码器的输出)、位置感知前馈神经网络以及相应的残差连接、层归一化和dropout层组成。这种组合使解码器能够根据编码器的表示生成有意义的输出,同时考虑目标序列和源序列。与编码器一样,通常会堆叠多个解码器层以形成完整的解码器部分。

全部代码

class DecoderLayer(nn.Module):  
    def __init__(self, d_model, num_heads, d_ff, dropout):  
        super(DecoderLayer, self).__init__()  
        self.self_attn = MultiHeadAttention(d_model, num_heads)  
        self.cross_attn = MultiHeadAttention(d_model, num_heads)  
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)  
        self.norm1 = nn.LayerNorm(d_model)  
        self.norm2 = nn.LayerNorm(d_model)  
        self.norm3 = nn.LayerNorm(d_model)  
        self.dropout = nn.Dropout(dropout)  
  
    def forward(self, x, enc_output, src_mask, tgt_mask):  
        attn_output = self.self_attn(x, x, x, tgt_mask)  
        x = self.norm1(x + self.dropout(attn_output))  
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)  
        x = self.norm2(x + self.dropout(attn_output))  
        ff_output = self.feed_forward(x)  
  
       x = self.norm3(x + self.dropout(ff_output))  
        return x  

类定义/初始化

参数:

  1. d_model: 输入的维度。
  2. num_heads: 多头注意力中的注意力头数。
  3. d_ff: 前馈网络中内层的维度。
  4. dropout: 用于正则化的dropout率。

组件:

  1. self.self_attn: 目标序列的多头自注意力机制。
  2. self.cross_attn: 多头注意力机制,用于关注编码器的输出。
  3. self.feed_forward: 位置感知前馈神经网络。
  4. self.norm1, self.norm2, self.norm3: 层归一化组件。
  5. self.dropout: 用于正则化的dropout层。
class DecoderLayer(nn.Module):  
    def __init__(self, d_model, num_heads, d_ff, dropout):  
        super(DecoderLayer, self).__init__()  
        self.self_attn = MultiHeadAttention(d_model, num_heads)  
        self.cross_attn = MultiHeadAttention(d_model, num_heads)  
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)  
        self.norm1 = nn.LayerNorm(d_model)  
        self.norm2 = nn.LayerNorm(d_model)  
        self.norm3 = nn.LayerNorm(d_model)  
        self.dropout = nn.Dropout(dropout) 

前向方法

输入:

  1. x: 解码器层的输入。
  2. enc_output: 对应编码器的输出(用于交叉注意力步骤)。
  3. src_mask: 源掩码,用于忽略编码器输出的某些部分。
  4. tgt_mask: 目标掩码,用于忽略解码器输入的某些部分。

处理步骤:

  1. 目标序列上的自注意力:输入x通过自注意力机制处理。
  2. 加法 & 归一化(自注意力后):自注意力的输出添加到原始x,然后是dropout和使用norm1的归一化。
  3. 编码器输出上的交叉注意力:前一步归一化输出通过交叉注意力机制处理,该机制关注编码器的输出enc_output。
  4. 加法 & 归一化(交叉注意力后):交叉注意力的输出添加到该阶段的输入,然后是dropout和使用norm2的归一化。
  5. 前馈网络:前一步的输出通过前馈网络传递。
  6. 加法 & 归一化(前馈后):前馈输出添加到该阶段的输入,然后是dropout和使用norm3的归一化。
  7. 输出:处理后的张量作为解码器层的输出返回。
def forward(self, x, enc_output, src_mask, tgt_mask):  
    attn_output = self.self_attn(x, x, x, tgt_mask)  
    x = self.norm1(x + self.dropout(attn_output))  
    attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)  
    x = self.norm2(x + self.dropout(attn_output))  
    ff_output = self.feed_forward(x)  
    x = self.norm3(x + self.dropout(ff_output))  
    return x