Transformer训练验证一个翻译模型

131 阅读3分钟

数据准备

把输入或目标输入语句转换为Transformer模型的格式?这里假设有两对语句。假设输入语句最大长度为5,不足的语句用0填补。目标语句的最大长度为5,不足的语句用0填补。 开始、结束分别加上、对应索引为8和9,表示语句的开始与结束。

import torch  
from transformer import Transformer
import torch.nn as nn  
import torch.optim as optim

# 中文->英文
en = [['<b>', 'i', ' ', 'like', ' ', 'you', '<e>'], 
      ['<b>', 'i', ' ', 'hate', ' ', 'you', '<e>'],
      ['<b>', 'i', ' ', 'love', ' ', 'you', '<e>'],
      ['<b>', 'he', ' ', 'like', ' ', 'you', '<e>']]
zh = [['<b>', '我', '喜欢', '你', '<p>', '<p>', '<e>'], 
      ['<b>', '我', '讨厌', '你', '<p>', '<p>', '<e>'], 
      ['<b>', '我', '爱', '你', '<p>', '<p>', '<e>'],
      ['<b>', '他', '喜欢', '你', '<p>', '<p>', '<e>']]

# 样本字典与数据集
en_vocab_i2t = ['<p>', 'like', 'you', 'hate', 'he', ' ', 'love', '<b>', '<e>', 'i']
en_vocab_t2i = {'<p>': 0,  'i': 1, 'like': 2, 'you': 3, 'hate': 4, 'he': 5, ' ': 6, 'love': 7
    , '<b>': 8, '<e>': 9}
zh_vocab_i2t = ['<p>', '我', '喜欢', '你', '讨厌', '爱', ' ', '他', '<b>', '<e>']
zh_vocab_t2i = {'<p>': 0, '我': 1, '喜欢': 2, '你': 3, '讨厌': 4, '爱': 5, ' ': 6, '他': 7
    , '<b>': 8, '<e>': 9}

def process(en, zh):
    en_idx = [[en_vocab_t2i[token] for token in line] for line in en]
    zh_idx = [[zh_vocab_t2i[token] for token in line] for line in zh]
    return torch.tensor(en_idx), torch.tensor(zh_idx)

# 模型参数配置
src_vocab_size = len(zh_vocab_i2t)
tgt_vocab_size = len(en_vocab_i2t)
d_model = 512  
num_heads = 8  
num_layers = 6  
d_ff = 2048  
max_seq_length = 7  
dropout = 0.1  

# 生成样本数据
tgt_data, src_data = process(en, zh)

创建Transformer实例

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout) 

训练模型

criterion = nn.CrossEntropyLoss(ignore_index=0)  
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)  

def train():
    transformer.train()
    for epoch in range(50):
        optimizer.zero_grad()
        output = transformer(src_data, tgt_data[:, :-1])
        loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))  
        loss.backward()  
        optimizer.step()  
        print(f"Epoch: {epoch+1}, Loss: {loss.item()}")
train()
Epoch: 47, Loss: 0.0037290186155587435
Epoch: 48, Loss: 0.0039988732896745205
Epoch: 49, Loss: 0.003924884367734194
Epoch: 50, Loss: 0.0038314256817102432

测试翻译模型

翻译方法

首先在transformer类实现一个翻译方法,输入源序列、目标序列、结束符,输出翻译列表

def translate(self, src, tgt, end):
        src_mask, _ = self.generate_mask(src, tgt)  
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))  
  
        enc_output = src_embedded  
        for enc_layer in self.encoder_layers:  
            enc_output = enc_layer(enc_output,src_mask)  
        
        res = []
        def goNo(mTgt):
          _, tgt_mask = self.generate_mask(src, mTgt)  
          tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(mTgt)))
          dec_output = tgt_embedded
          for dec_layer in self.decoder_layers:  
              dec_output = dec_layer(dec_output, enc_output,src_mask,tgt_mask)  
    
          output = self.fc(dec_output[:, -1]) # 取数组内数组最后一位,并合并一维
          pre = output.contiguous().view(-1)
          pred = torch.argmax(pre,dim=-1)

          if pred != end  and len(res) < 50:
            res.append(pred)
            mTgt = torch.cat([mTgt,pred.clone().detach().unsqueeze(0).unsqueeze(0)], dim=1)
            goNo(mTgt)
        
        goNo(tgt)
        return res  

测试

def test():
    tgt = torch.tensor(8).unsqueeze(0).unsqueeze(0)
    x = ['<b>', '我', '喜欢','你','<p>', '<p>', '<e>']
    x = torch.tensor([zh_vocab_t2i[token] for token in x]).unsqueeze(0)
    res = []

    resList = transformer.translate(x,tgt,9)

    for index in range(len(resList)):
        res.append(en_vocab_i2t[resList[index]])

    print(' '.join(res))  
    print(resList)  

test()
like love you love hate
[tensor(1), tensor(6), tensor(2), tensor(6), tensor(3)]

总结

由于样本数据,训练方法,模型本身等问题,可能翻译的不太理想。这里展示Transformer利用注意力机制来提高模型训练速度的模型,训练测试的完整流程。

不过作为demo版,我们就不要要求太高了:),这里主要还是给大家展示一下代码的实现过程。

这次手动实现Transformer的过程确实让我踩了不少的坑,但也随之弄懂了许多论文里没有提到的技术细节,还是收获很多的。看来深入理解一个模型最好的方法还是动手实现一下,毕竟“纸上得来终觉浅”。

总之还是要多动手,光看论文的话很容易陷入“以为自己懂了,但其实基本没懂”的困境。