【附Jupyter notebook源码】基于Transformer的中英文翻译模型实现

5 阅读4分钟

【附Jupyter notebook源码】基于Transformer的中英文翻译模型实现

Transformer架构自2017年提出以来,已成为自然语言处理领域的主流模型。 Transformer起初就是为了翻译而生,本项目使用Pytorch实现了一个轻量版的中英文翻译模型。

项目整体设计

1. 数据预处理层

文本清洗策略

英文预处理主要完成大小写转换和特殊字符过滤:

def clean_english_text(text: str) -> str:
    text = text.lower().strip()
    text = re.sub(r"[^a-zA-Z0-9?.!,¿]+", " ", text)
    return text

中文处理采用字符级分词方案。考虑到数据集规模有限,未引入jieba等外部分词工具,而是直接将句子拆分为单字序列,通过空格连接后送入模型。这种方式虽然损失了部分词语边界信息,但显著降低了词汇表规模。

词汇表构建

词汇表类Vocabulary维护了两个核心映射:

  • word2idx: 词到索引的映射
  • idx2word: 索引到词的映射

实现时设置了最小词频阈值(MIN_FREQ=2),过滤掉出现次数过低的稀有词,用<unk>标记统一表示。特殊标记包括:

标记含义索引
<pad>填充0
<sos>序列开始1
<eos>序列结束2
<unk>未知词3

2. 模型核心实现

位置编码

Transformer摒弃了RNN的时序处理机制,因此需要显式注入位置信息。本项目采用正弦-余弦位置编码:

pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

偶数维度使用正弦函数,奇数维度使用余弦函数。这种设计的优势在于:模型可以学习到相对位置关系,且对训练时未见过的序列长度具有泛化能力。

掩码机制

翻译任务中需要处理两类掩码:

源序列掩码(Padding Mask)

用于屏蔽填充位置,防止注意力机制关注无意义的<pad>标记。实现上通过比较输入张量与0值生成布尔掩码。

目标序列掩码(Look-ahead Mask)

解码器在生成第t个词时,只能依赖已生成的t-1个词,不能提前"看到"后续内容。通过生成上三角矩阵实现:

torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1).bool()
模型配置

考虑到数据集规模,采用了精简的模型配置:

  • 词嵌入维度:256(原论文512)
  • 注意力头数:4(原论文8)
  • 编解码器层数:各2层(原论文6层)
  • 前馈网络维度:512(原论文2048)

参数量控制在3M以内,单张T4显卡即可在数分钟内完成训练。

3. 训练策略

优化器与学习率

采用Adam优化器,设置β1=0.9、β2=0.98,这是Transformer原论文推荐的配置。学习率初始值为1e-4,配合ReduceLROnPlateau调度器,当验证损失连续3轮未下降时自动减半。

梯度裁剪

设置梯度裁剪阈值为1.0,防止训练初期因异常样本导致的梯度爆炸问题。

torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRADIENT)
断点续训

通过保存和加载checkpoint实现训练中断后的恢复:

checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': val_loss,
    'en_vocab': en_vocab,
    'zh_vocab': zh_vocab
}
torch.save(checkpoint, MODEL_SAVE_PATH)

⚠️注意词汇表对象也需要一并保存,确保推理时词到索引的映射关系一致。

4. 推理实现

训练完成后,采用贪婪解码(Greedy Decoding)进行翻译。具体流程为:

  1. 将源句子编码为索引序列
  2. <sos>标记作为解码器初始输入
  3. 每轮迭代预测下一个概率最高的词
  4. 当生成<eos>或达到最大长度时终止
def translate(self, src, max_len: int = 50):
    tgt = torch.ones(batch_size, 1).fill_(1).long().to(src.device)
    for _ in range(max_len):
        output = self.forward(src, tgt)
        next_word = output[:, -1, :].argmax(dim=-1, keepdim=True)
        tgt = torch.cat([tgt, next_word], dim=1)
        if (next_word == 2).all():
            break
    return tgt

贪婪解码的局限在于每一步只保留当前最优选择,无法回溯修正。对于要求更高的场景,可替换为束搜索(Beam Search)。

结果

在T4双卡环境下,模型训练50轮耗时约15分钟。验证损失收敛至0.8左右,部分测试样例的翻译结果如下:

英文输入模型输出
Hello!你好。
Thank you.谢谢。
Good morning!早上好!
I love you.我爱你。

对于训练集中出现过的句式,模型能够准确翻译;对于未见过的新组合,则可能出现语法错误或语义偏差,这是小规模模型的固有局限。

⚠️完整代码已开源至GitHub,源码仅一个Jupyter Notebook文件,方便调试和学习!