简介
DecoderLayer类定义了Transformer解码器的单层。它由多头自注意力机制、多头交叉注意力机制(关注编码器的输出)、位置感知前馈神经网络以及相应的残差连接、层归一化和dropout层组成。这种组合使解码器能够根据编码器的表示生成有意义的输出,同时考虑目标序列和源序列。与编码器一样,通常会堆叠多个解码器层以形成完整的解码器部分。
全部代码
class DecoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout):
super(DecoderLayer, self).__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.cross_attn = MultiHeadAttention(d_model, num_heads)
self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, enc_output, src_mask, tgt_mask):
attn_output = self.self_attn(x, x, x, tgt_mask)
x = self.norm1(x + self.dropout(attn_output))
attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
x = self.norm2(x + self.dropout(attn_output))
ff_output = self.feed_forward(x)
x = self.norm3(x + self.dropout(ff_output))
return x
类定义/初始化
参数:
- d_model: 输入的维度。
- num_heads: 多头注意力中的注意力头数。
- d_ff: 前馈网络中内层的维度。
- dropout: 用于正则化的dropout率。
组件:
- self.self_attn: 目标序列的多头自注意力机制。
- self.cross_attn: 多头注意力机制,用于关注编码器的输出。
- self.feed_forward: 位置感知前馈神经网络。
- self.norm1, self.norm2, self.norm3: 层归一化组件。
- self.dropout: 用于正则化的dropout层。
class DecoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout):
super(DecoderLayer, self).__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.cross_attn = MultiHeadAttention(d_model, num_heads)
self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
前向方法
输入:
- x: 解码器层的输入。
- enc_output: 对应编码器的输出(用于交叉注意力步骤)。
- src_mask: 源掩码,用于忽略编码器输出的某些部分。
- tgt_mask: 目标掩码,用于忽略解码器输入的某些部分。
处理步骤:
- 目标序列上的自注意力:输入x通过自注意力机制处理。
- 加法 & 归一化(自注意力后):自注意力的输出添加到原始x,然后是dropout和使用norm1的归一化。
- 编码器输出上的交叉注意力:前一步归一化输出通过交叉注意力机制处理,该机制关注编码器的输出enc_output。
- 加法 & 归一化(交叉注意力后):交叉注意力的输出添加到该阶段的输入,然后是dropout和使用norm2的归一化。
- 前馈网络:前一步的输出通过前馈网络传递。
- 加法 & 归一化(前馈后):前馈输出添加到该阶段的输入,然后是dropout和使用norm3的归一化。
- 输出:处理后的张量作为解码器层的输出返回。
def forward(self, x, enc_output, src_mask, tgt_mask):
attn_output = self.self_attn(x, x, x, tgt_mask)
x = self.norm1(x + self.dropout(attn_output))
attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
x = self.norm2(x + self.dropout(attn_output))
ff_output = self.feed_forward(x)
x = self.norm3(x + self.dropout(ff_output))
return x