创建完整的Transformer网络

79 阅读3分钟

all.png

简介

Transformer类汇集了Transformer模型的各个组件,包括嵌入、位置编码、编码器层和解码器层。它为训练和推理提供了方便的接口,封装了多头注意力、前馈网络和层归一化的复杂性。

这个实现遵循标准的Transformer架构,适用于机器翻译、文本摘要等序列到序列任务。掩码的包含确保模型遵守序列内的因果依赖关系,忽略填充标记并防止来自未来标记的信息泄露。

这些连续的步骤使Transformer模型能够有效地处理输入序列并产生相应的输出序列。

全部代码

class Transformer(nn.Module):  
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):  
        super(Transformer, self).__init__()  
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)  
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)  
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)  
  
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])  
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])  
  
        self.fc = nn.Linear(d_model, tgt_vocab_size)  
        self.dropout = nn.Dropout(dropout)  
  
    def generate_mask(self, src, tgt):  
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)  
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)  
        seq_length = tgt.size(1)  
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()  
        tgt_mask = tgt_mask & nopeak_mask  
        return src_mask, tgt_mask  
  
    def forward(self, src, tgt):  
        src_mask, tgt_mask = self.generate_mask(src, tgt)  
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))  
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))  
  
        enc_output = src_embedded  
        for enc_layer in self.encoder_layers:  
            enc_output = enc_layer(enc_output, src_mask)  
  
        dec_output = tgt_embedded  
        for dec_layer in self.decoder_layers:  
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)  
  
        output = self.fc(dec_output)  
        return output  

类定义/初始化

参数:

  1. src_vocab_size: 源词汇表大小。
  2. tgt_vocab_size: 目标词汇表大小。
  3. d_model: 模型嵌入的维度。
  4. num_heads: 多头注意力机制中的注意力头数。
  5. num_layers: 编码器和解码器的层数。
  6. d_ff: 前馈网络中内层的维度。
  7. max_seq_length: 位置编码的最大序列长度。
  8. dropout: 用于正则化的dropout率。

组件:

  1. self.encoder_embedding: 源序列的嵌入层。
  2. self.decoder_embedding: 目标序列的嵌入层。
  3. self.positional_encoding: 位置编码组件。
  4. self.encoder_layers: 编码器层的列表。
  5. self.decoder_layers: 解码器层的列表。
  6. self.fc: 将输出映射到目标词汇表大小的最终全连接(线性)层。
  7. self.dropout: Dropout层。
class Transformer(nn.Module):  
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):  
        super(Transformer, self).__init__()  
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)  
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)  
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)  
  
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])  
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])  
  
        self.fc = nn.Linear(d_model, tgt_vocab_size)  
        self.dropout = nn.Dropout(dropout)  

生成掩码方法

用于为源和目标序列创建掩码,确保在训练期间忽略填充标记,并且目标序列中的未来标记不可见。

def generate_mask(self, src, tgt): 

前向方法

此方法定义了Transformer的前向传递,采用源和目标序列,并产生输出预测。

  1. 输入嵌入和位置编码:首先使用各自的嵌入层嵌入源和目标序列,然后添加它们的位置编码。
  2. 编码器层:源序列通过编码器层传递,最终编码器输出表示处理过的源序列。
  3. 解码器层:目标序列和编码器的输出通过解码器层传递,得到解码器的输出。
  4. 最终线性层:解码器的输出使用全连接(线性)层映射到目标词汇表大小。
  5. 输出:最终输出是一个张量,表示模型对目标序列的预测。
def forward(self, src, tgt):  
        src_mask, tgt_mask = self.generate_mask(src, tgt)  
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))  
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))  
  
        enc_output = src_embedded  
        for enc_layer in self.encoder_layers:  
            enc_output = enc_layer(enc_output, src_mask)  
  
        dec_output = tgt_embedded  
        for dec_layer in self.decoder_layers:  
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)  
  
        output = self.fc(dec_output)  
        return output