【附Jupyter notebook源码】基于Transformer的中英文翻译模型实现
Transformer架构自2017年提出以来,已成为自然语言处理领域的主流模型。 Transformer起初就是为了翻译而生,本项目使用Pytorch实现了一个轻量版的中英文翻译模型。
项目整体设计
1. 数据预处理层
文本清洗策略
英文预处理主要完成大小写转换和特殊字符过滤:
def clean_english_text(text: str) -> str:
text = text.lower().strip()
text = re.sub(r"[^a-zA-Z0-9?.!,¿]+", " ", text)
return text
中文处理采用字符级分词方案。考虑到数据集规模有限,未引入jieba等外部分词工具,而是直接将句子拆分为单字序列,通过空格连接后送入模型。这种方式虽然损失了部分词语边界信息,但显著降低了词汇表规模。
词汇表构建
词汇表类Vocabulary维护了两个核心映射:
word2idx: 词到索引的映射idx2word: 索引到词的映射
实现时设置了最小词频阈值(MIN_FREQ=2),过滤掉出现次数过低的稀有词,用<unk>标记统一表示。特殊标记包括:
| 标记 | 含义 | 索引 |
|---|---|---|
<pad> | 填充 | 0 |
<sos> | 序列开始 | 1 |
<eos> | 序列结束 | 2 |
<unk> | 未知词 | 3 |
2. 模型核心实现
位置编码
Transformer摒弃了RNN的时序处理机制,因此需要显式注入位置信息。本项目采用正弦-余弦位置编码:
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
偶数维度使用正弦函数,奇数维度使用余弦函数。这种设计的优势在于:模型可以学习到相对位置关系,且对训练时未见过的序列长度具有泛化能力。
掩码机制
翻译任务中需要处理两类掩码:
源序列掩码(Padding Mask)
用于屏蔽填充位置,防止注意力机制关注无意义的<pad>标记。实现上通过比较输入张量与0值生成布尔掩码。
目标序列掩码(Look-ahead Mask)
解码器在生成第t个词时,只能依赖已生成的t-1个词,不能提前"看到"后续内容。通过生成上三角矩阵实现:
torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1).bool()
模型配置
考虑到数据集规模,采用了精简的模型配置:
- 词嵌入维度:256(原论文512)
- 注意力头数:4(原论文8)
- 编解码器层数:各2层(原论文6层)
- 前馈网络维度:512(原论文2048)
参数量控制在3M以内,单张T4显卡即可在数分钟内完成训练。
3. 训练策略
优化器与学习率
采用Adam优化器,设置β1=0.9、β2=0.98,这是Transformer原论文推荐的配置。学习率初始值为1e-4,配合ReduceLROnPlateau调度器,当验证损失连续3轮未下降时自动减半。
梯度裁剪
设置梯度裁剪阈值为1.0,防止训练初期因异常样本导致的梯度爆炸问题。
torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRADIENT)
断点续训
通过保存和加载checkpoint实现训练中断后的恢复:
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': val_loss,
'en_vocab': en_vocab,
'zh_vocab': zh_vocab
}
torch.save(checkpoint, MODEL_SAVE_PATH)
⚠️注意词汇表对象也需要一并保存,确保推理时词到索引的映射关系一致。
4. 推理实现
训练完成后,采用贪婪解码(Greedy Decoding)进行翻译。具体流程为:
- 将源句子编码为索引序列
- 以
<sos>标记作为解码器初始输入 - 每轮迭代预测下一个概率最高的词
- 当生成
<eos>或达到最大长度时终止
def translate(self, src, max_len: int = 50):
tgt = torch.ones(batch_size, 1).fill_(1).long().to(src.device)
for _ in range(max_len):
output = self.forward(src, tgt)
next_word = output[:, -1, :].argmax(dim=-1, keepdim=True)
tgt = torch.cat([tgt, next_word], dim=1)
if (next_word == 2).all():
break
return tgt
贪婪解码的局限在于每一步只保留当前最优选择,无法回溯修正。对于要求更高的场景,可替换为束搜索(Beam Search)。
结果
在T4双卡环境下,模型训练50轮耗时约15分钟。验证损失收敛至0.8左右,部分测试样例的翻译结果如下:
| 英文输入 | 模型输出 |
|---|---|
| Hello! | 你好。 |
| Thank you. | 谢谢。 |
| Good morning! | 早上好! |
| I love you. | 我爱你。 |
对于训练集中出现过的句式,模型能够准确翻译;对于未见过的新组合,则可能出现语法错误或语义偏差,这是小规模模型的固有局限。
⚠️完整代码已开源至GitHub,源码仅一个Jupyter Notebook文件,方便调试和学习!