Seq2Seq(Sequence-to-Sequence)是一种用于将一个序列映射到另一个序列的深度学习架构,广泛应用于机器翻译、语音识别、文本摘要、对话系统等任务。其核心思想是使用两个神经网络:编码器(Encoder) 和解码器(Decoder) 。
下面分别介绍 Seq2Seq 的训练机制和推理机制:
一、训练机制(Training)
1. 基本结构
- 编码器(Encoder) :通常是一个 RNN(如 LSTM 或 GRU),也可以是 Transformer 的编码器。它将输入序列 x=(x1,x2,...,xT)x=(x1,x2,...,xT) 编码为一个上下文向量(context vector) cc ,通常是最后一个隐藏状态或所有隐藏状态的加权组合。
- 解码器(Decoder) :另一个 RNN(或 Transformer 解码器),以编码器输出的上下文向量 cc 作为初始状态,并逐步生成目标序列 y=(y1,y2,...,yT′)y=(y1,y2,...,yT′) 。
2. 训练方式:Teacher Forcing
- 在训练过程中,解码器在每个时间步的输入不是自己上一步预测的词,而是真实的目标序列中的前一个词(ground truth) 。这称为 Teacher Forcing。
- 例如:要生成 “I love you”,在生成 “love” 时,输入是 “I”(真实标签),而不是模型自己预测的可能错误词。
- 目标函数:通常使用交叉熵损失(Cross-Entropy Loss) ,对每个时间步的预测分布与真实标签计算损失,再求和或平均。
3. 注意力机制(可选但常用)
- 原始 Seq2Seq 将整个输入压缩成一个固定长度的向量,对长序列效果差。
- 引入 注意力机制(Attention) 后,解码器在每一步动态关注输入序列的不同部分,显著提升性能(如 Bahdanau Attention、Luong Attention)。
- Transformer 架构则完全基于自注意力(Self-Attention)和交叉注意力(Cross-Attention)实现 Seq2Seq。
二、推理机制(Inference / Decoding)
推理阶段没有真实目标序列可用,因此不能使用 Teacher Forcing。解码器必须自回归地(autoregressively) 生成输出:
1. 自回归生成
- 初始输入通常是起始符(如
<sos>)。 - 每一步将上一步模型预测的 token 作为当前输入。
- 重复此过程,直到生成结束符(如
<eos>)或达到最大长度。
2. 解码策略
由于每一步都依赖前一步的输出,存在多种解码策略来平衡生成质量与计算效率:
表格
| 策略 | 描述 | 优缺点 |
|---|---|---|
| 贪心搜索(Greedy Search) | 每步选择概率最高的 token | 快,但容易陷入局部最优,生成结果可能不连贯 |
| 束搜索(Beam Search) | 维护 top-k 个候选序列,每步扩展并保留得分最高的 k 个 | 质量优于贪心,但计算开销大;k 越大越接近 |
总结
- 训练:编码器读入源序列 → 生成上下文向量 → 解码器在 Teacher Forcing 下学习生成目标序列 → 用交叉熵损失优化。
- 推理:编码器处理输入 → 解码器自回归生成输出,配合束搜索等策略提升质量。