我正在参加Trae「超级体验官」创意实践征文, 本文所使用的 Trae 免费下载链接: www.trae.ai/?utm_source…
🔥 还在为看不懂论文代码而烦恼吗?
🤖 想不想试试复现经典论文的工作?
🚀 本期推文手把手教你用Pytorch构建一套基于Transformer架构的中英翻译项目!(会附上全部代码)
🎯 你将学到:
- Transformer模型的原理与优势
- 如何使用PyTorch构建Transformer模型
- 从数据预处理到模型训练,一步步实现中英翻译
- 模型评估与优化技巧
📚 适合人群:
- 对深度学习感兴趣的开发者
- 想学习Transformer模型的同学
- 希望提升代码实践能力的朋友
⏰ 本期实战内容较多,建议先收藏再慢慢学习!前面的内容先不会涉及到Trae的体验,后面会有遇到问题后如何用Trea解决的过程
Transformer模型的介绍
Transformer 模型可以说是现今人工智能的基石之一了,它广泛应于NLP和CV领域。在进入代码部分之前,读者们需要先掌握有关于注意力机制和Transformer模型的基础知识。因为笔者精力有限并且学识浅薄,在这里对这些内容就不做过多赘述。我附赠上一些我认为非常优质的教学课程供诸位学习:
- Transformer论文带读:www.bilibili.com/video/BV1pu…
- Transformer论文地址:arxiv.org/abs/1706.03…
- 注意力机制:zh.d2l.ai/chapter_att…
项目开始--基于Transformer构建的中英翻译系统
一. 位置编码实现
在 Transformer 模型中,位置编码(Positional Encoding)的实现是模型理解序列顺序的关键。以下是逐层解析:
1. 位置编码的核心作用
传统RNN通过序列顺序处理自然获得位置信息,而Transformer的并行处理特性需要显式位置编码。位置编码需要满足:
- 唯一性:每个位置有独特编码
- 相对性:能表达位置间的相对关系
- 泛化性:能处理比训练时更长的序列
2. 数学公式解析
原论文采用的正弦/余弦位置编码公式:
其中:
pos:位置索引i:维度索引(0 ≤ i < d_model/2)d_model:模型维度
3. 代码实现精析
步骤分解:
-
初始化位置矩阵
pe = torch.zeros(max_seq_length, d_model)- 创建形状为
[max_seq_length, d_model]的零矩阵
- 创建形状为
-
生成位置索引
position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)- 生成
[0, 1, 2, ..., max_seq_length-1]的位置向量 - 通过
unsqueeze(1)转换为列向量[max_seq_length, 1]
- 生成
-
计算频率项
- 关键数学变换:
- 生成不同频率的缩放因子
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
-
填充位置编码矩阵
pe[:, 0::2] = torch.sin(position * div_term) # 偶数维度 pe[:, 1::2] = torch.cos(position * div_term) # 奇数维度- 交替使用正弦和余弦函数
- 不同维度对应不同频率的波形
-
注册为缓冲区
self.register_buffer('pe', pe.unsqueeze(0))- 将位置编码保存为不参与训练的参数
- 添加批次维度
[1, max_seq_length, d_model]
4. 关键设计思想
| 设计选择 | 优势分析 |
|---|---|
| 正弦/余弦交替 | 允许模型学习相对位置关系(可通过线性变换表示位置偏移) |
| 指数频率衰减 | 不同维度捕获不同范围的位置模式(低频长距离,高频短距离) |
| 固定编码+可学习嵌入 | 既保留确定性的位置信息,又允许模型通过嵌入学习语义特征 |
| 加法融合 | 位置信息与词嵌入线性叠加,保留各自特征的同时实现信息融合 |
5. 可视化理解
假设 d_model=512,不同维度的位置编码呈现:
- 低频维度(小i值):波长较长,捕获全局位置关系
- 高频维度(大i值):波长较短,捕获局部位置关系
6. 技术细节解析
为什么选择正弦函数?
- 相对位置表达:对于固定偏移量k,存在线性变换矩阵M使得:PEpos+k=M⋅PEposPEpos+k=M⋅PEpos
- 无限扩展性:不受预设最大长度限制,可处理任意长度序列
维度为什么要交替?
- 保证每个位置编码的每个维度都有独特的频率组合
- 增强模型对不同位置模式的捕获能力
为什么用加法而不是拼接?
- 保持维度一致性(拼接会改变输入维度)
- 实验表明加法效果更好(信息融合更自然)
7. 与其他位置编码对比
| 编码类型 | 优点 | 缺点 |
|---|---|---|
| 绝对位置编码 | 简单直观 | 无法学习相对位置关系 |
| 相对位置编码 | 显式建模位置关系 | 实现复杂,计算成本高 |
| 旋转位置编码 | 理论性质优秀 | 实现复杂度较高 |
| 学习式位置编码 | 完全可学习 | 无法处理超长序列 |
8. 完整代码实现
import torch
import torch.nn as nn
import numpy as np
import math
# 设置随机种子确保可复现性
torch.manual_seed(42)
np.random.seed(42)
# 1. 位置编码实现
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_seq_length=5000):
super().__init__()
# 创建一个零矩阵,形状为 [max_seq_length, d_model]
pe = torch.zeros(max_seq_length, d_model)
# 创建一个位置索引向量 [0, 1, 2, ..., max_seq_length-1]
# unsqueeze(1) 将形状从 [max_seq_length] 变为 [max_seq_length, 1]
position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
# 创建分母项,形状为 [d_model/2]
# 计算 10000^(2i/d_model) 中的指数项
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
# 使用正弦函数填充偶数位置
pe[:, 0::2] = torch.sin(position * div_term)
# 使用余弦函数填充奇数位置
pe[:, 1::2] = torch.cos(position * div_term)
# 添加批次维度,形状变为 [1, max_seq_length, d_model]
pe = pe.unsqueeze(0)
# 将位置编码注册为缓冲区(不参与训练)
self.register_buffer('pe', pe)
def forward(self, x):
# x 的形状为 [batch_size, seq_length, d_model]
# 返回输入加上位置编码
return x + self.pe[:, :x.size(1)]
二. 多头注意力机制
1. 多头注意力的核心作用
多头注意力是Transformer的核心创新,实现了:
- 并行化特征提取:多个注意力头并行捕捉不同类型的上下文依赖
- 解耦特征空间:将高维空间分解到多个子空间进行独立学习
- 增强模型容量:通过增加头数提升模型表达能力,而不显著增加计算量
2. 数学公式解析
标准缩放点积注意力公式:
多头注意力扩展为:
其中每个头:
3. 代码实现精析
关键步骤分解:
-
维度校验
assert d_model % num_heads == 0- 确保模型维度可均分给各个注意力头
-
线性投影层
self.W_q = nn.Linear(d_model, d_model) # 其他类似- 每个头拥有独立的可学习参数矩阵
- 实现特征空间的可控分解
-
张量重塑
Q.view(batch_size, -1, num_heads, d_k).transpose(1, 2)- 将
[batch, seq_len, d_model]转换为[batch, num_heads, seq_len, d_k] - 通过维度变换实现并行多头计算
- 将
-
缩放点积计算
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)- 缩放因子 dkdk 防止点积数值过大导致softmax梯度消失
-
掩码处理
scores.masked_fill(mask == 0, -1e9)- 在softmax前将非法位置设为极大负值,确保注意力权重趋近0
-
输出融合
output.transpose(1,2).contiguous().view(...) self.W_o(output)- 拼接各头输出
[batch, seq_len, d_model] - 最终线性变换融合多头信息
- 拼接各头输出
4. 关键设计思想
| 设计选择 | 优势分析 |
|---|---|
| 多头并行机制 | 允许模型在不同子空间学习多样化的注意力模式 |
| 共享参数架构 | 通过线性变换参数共享,控制模型复杂度 |
| 维度均分策略 | 保持各头计算量均衡,充分利用硬件并行能力 |
| 残差连接设计 | 后续配合残差连接,缓解梯度消失问题(在EncoderLayer中实现) |
5. 计算过程可视化
6. 技术细节解析
为什么需要缩放因子?
- 当较大时,点积结果方差增大,导致softmax趋向极值分布
- 经验公式:,缩放使方差保持为1
头数如何选择?
- 典型配置: 时用8头, 时用12头
- 头数越多,单头维度越小,需平衡多样性和表达能力
参数共享机制:
- 所有头共享相同的矩阵
- 实际实现中通过单个线性层+维度分割等效多组参数
7. 与其他注意力机制对比
| 注意力类型 | 计算复杂度 | 优势 | 局限性 |
|---|---|---|---|
| 多头注意力 | O(n²·d) | 并行化、多特征空间 | 内存消耗大 |
| 局部注意力 | O(n·k) | 适合长序列 | 丢失全局信息 |
| 稀疏注意力 | O(n√n) | 平衡效率与效果 | 需要特定模式设计 |
| 线性注意力 | O(n) | 理论线性复杂度 | 近似误差积累 |
8. 数学证明(精度稳定性)
定理:缩放因子 能保持精度量级稳定
证明:设 为独立随机变量,元素服从 ,则:
缩放后:
这使得 softmax 输入的方差保持稳定,避免梯度爆炸/消失。
9. 实际应用技巧
-
头数选择策略
- 在8-16头之间探索,保持
- 可通过注意力头重要性分析裁剪冗余头
-
混合精度训练
with torch.autocast(device_type='cuda', dtype=torch.float16): attn_output = self.scaled_dot_product_attention(Q, K, V)- 利用Tensor Core加速矩阵运算
-
注意力可视化
attention_maps = torch.stack([head.attention_weights for head in self.heads])- 分析不同头关注的语言现象(语法/指代/语义等)
-
高效实现优化
- 使用Flash Attention V2等优化算法,降低显存占用
通过这种分而治之的架构设计,多头注意力机制使Transformer能够:
- 同时关注不同位置的多种依赖关系
- 在长距离依赖建模上超越RNN的序列处理能力
- 为后续的残差连接和层归一化提供丰富特征表示
这种设计已成为现代深度学习架构的标准组件,不仅应用于NLP,在计算机视觉、语音处理等领域也展现出强大威力。
10. 完整代码实现
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
# 确保 d_model 可以被 num_heads 整除
assert d_model % num_heads == 0
# 保存配置参数
self.d_model = d_model # 模型的维度
self.num_heads = num_heads # 注意力头的数量
self.d_k = d_model // num_heads # 每个注意力头的维度
# 创建四个线性变换层
# 分别用于转换 查询(Q)、键(K)、值(V) 和 输出
self.W_q = nn.Linear(d_model, d_model) # 查询的线性变换
self.W_k = nn.Linear(d_model, d_model) # 键的线性变换
self.W_v = nn.Linear(d_model, d_model) # 值的线性变换
self.W_o = nn.Linear(d_model, d_model) # 输出的线性变换
def scaled_dot_product_attention(self, Q, K, V, mask=None):
# 计算注意力分数:Q 和 K 的矩阵乘法,并除以缩放因子
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# 如果提供了掩码,将掩码位置的值设为负无穷
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 对分数进行 softmax 归一化,得到注意力权重
attention_weights = torch.softmax(scores, dim=-1)
# 将注意力权重与值相乘,得到输出
output = torch.matmul(attention_weights, V)
return output
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 对 Q、K、V 进行线性变换并重塑维度
# 从 [batch_size, seq_len, d_model] 变为 [batch_size, num_heads, seq_len, d_k]
Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 计算注意力
output = self.scaled_dot_product_attention(Q, K, V, mask)
# 重组输出维度
# 从 [batch_size, num_heads, seq_len, d_k]
# 变回 [batch_size, seq_len, d_model]
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
# 通过输出线性层返回结果
return self.W_o(output)