Transformer 模型结构
根据前面介绍的所有组件,结合起来,就是一个完整的 Transformer 结构了。
上图为论文《Attention is all you need》原文配图,LayerNorm 是放在 Attention 层之后,也就是“Post-Norm”结构,但是在其发布的源码中,LayerNormer 是放在 Attention 层之前,也就是“Pre-Norm”。实际中,Pre-Norm 结构可以使 loss 更稳定,所以目前 LLM 一般采用“Pre-Norm”,即输入先归一化,Attention 层输入更稳定。“Post-Norm”的话,Attention 输出可能很大。
class Transformer(nn.Module):
'''整体模型'''
def __init__(self, args):
super().__init__()
# 必须输入词表大小和 block size
assert args.vocab_size is not None
assert args.block_size is not None
self.args = args
self.transformer = nn.ModuleDict(dict(
wte = nn.Embedding(args.vocab_size, args.n_embd),
wpe = PositionalEncoding(args),
drop = nn.Dropout(args.dropout),
encoder = Encoder(args),
decoder = Decoder(args),
))
# 最后的线性层,输入是 n_embd,输出是词表大小
self.lm_head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
# 初始化所有的权重
self.apply(self._init_weights)
# 查看所有参数的数量
print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
'''统计所有参数的数量'''
def get_num_params(self, non_embedding=False):
# non_embedding: 是否统计 embedding 的参数
n_params = sum(p.numel() for p in self.parameters())
# 如果不统计 embedding 的参数,就减去
if non_embedding:
n_params -= self.transformer.wte.weight.numel()
return n_params
'''初始化权重'''
def _init_weights(self, module):
# 线性层和 Embedding 层初始化为正则分布
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
'''前向计算函数'''
def forward(self, idx, targets=None):
# 输入为 idx,维度为 (batch size, sequence length, 1);targets 为目标序列,用于计算 loss
device = idx.device
b, t = idx.size()
assert t <= self.args.block_size, f"不能计算该序列,该序列长度为 {t}, 最大序列长度只有 {self.args.block_size}"
# 通过 self.transformer
# 首先将输入 idx 通过 Embedding 层,得到维度为 (batch size, sequence length, n_embd)
print("idx",idx.size())
# 通过 Embedding 层
tok_emb = self.transformer.wte(idx)
print("tok_emb",tok_emb.size())
# 然后通过位置编码
pos_emb = self.transformer.wpe(tok_emb)
# 再进行 Dropout
x = self.transformer.drop(pos_emb)
# 然后通过 Encoder
print("x after wpe:",x.size())
enc_out = self.transformer.encoder(x)
print("enc_out:",enc_out.size())
# 再通过 Decoder
x = self.transformer.decoder(x, enc_out)
print("x after decoder:",x.size())
if targets is not None:
# 训练阶段,如果我们给了 targets,就计算 loss
# 先通过最后的 Linear 层,得到维度为 (batch size, sequence length, vocab size)
logits = self.lm_head(x)
# 再跟 targets 计算交叉熵
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
else:
# 推理阶段,我们只需要 logits,loss 为 None
# 取 -1 是只取序列中的最后一个作为输出
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
loss = None
return logits, loss
=== Transformer整体结构 ===
1. Transformer的组成部分:
输入层:
- wte: 词嵌入层
- wpe: 位置编码层
- drop: Dropout层
编码器:
- encoder: 编码器
解码器:
- decoder: 解码器
输出层:
- lm_head: 线性层(输出词表概率)
2. 数据流动:
输入 idx (batch_size, seq_len)
↓
wte: 词嵌入 (batch_size, seq_len, n_embd)
↓
wpe: 位置编码 (batch_size, seq_len, n_embd)
↓
drop: Dropout (batch_size, seq_len, n_embd)
↓
encoder: 编码器 (batch_size, seq_len, n_embd)
↓
decoder: 解码器 (batch_size, seq_len, n_embd)
↓
lm_head: 线性层 (batch_size, seq_len, vocab_size)
↓
logits (batch_size, seq_len, vocab_size)
=== __init__方法详解 ===
1. 参数检查:
assert args.vocab_size is not None
assert args.block_size is not None
作用:
- 确保vocab_size(词表大小)已设置
- 确保block_size(最大序列长度)已设置
- 如果没有设置,会报错
2. 创建组件:
self.transformer = nn.ModuleDict(dict(
wte = nn.Embedding(args.vocab_size, args.n_embd),
wpe = PositionalEncoding(args),
drop = nn.Dropout(args.dropout),
encoder = Encoder(args),
decoder = Decoder(args),
))
解释:
- wte: 词嵌入层
* 输入: token索引 (batch_size, seq_len)
* 输出: 词向量 (batch_size, seq_len, n_embd)
* 参数数量: vocab_size × n_embd
- wpe: 位置编码层
* 输入: 词向量 (batch_size, seq_len, n_embd)
* 输出: 添加位置编码后的向量 (batch_size, seq_len, n_embd)
- drop: Dropout层
* 输入: 向量 (batch_size, seq_len, n_embd)
* 输出: Dropout后的向量 (batch_size, seq_len, n_embd)
* 作用: 防止过拟合
- encoder: 编码器
* 输入: 向量 (batch_size, seq_len, n_embd)
* 输出: 编码后的向量 (batch_size, seq_len, n_embd)
- decoder: 解码器
* 输入: 向量 + 编码器输出
* 输出: 解码后的向量 (batch_size, seq_len, n_embd)
3. 输出层:
self.lm_head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
解释:
- 输入: 解码器输出 (batch_size, seq_len, n_embd)
- 输出: 词表概率 (batch_size, seq_len, vocab_size)
- bias=False: 不使用偏置
- 参数数量: n_embd × vocab_size
4. 权重初始化:
self.apply(self._init_weights)
解释:
- 对所有线性层和Embedding层进行初始化
- 使用正态分布初始化
- mean=0.0, std=0.02
5. 参数统计:
print('number of parameters: %.2fM' % (self.get_num_params()/1e6,))
解释:
- 统计所有参数的数量
- 除以1e6,转换为百万(M)
- 例如: 10M = 1000万参数
=== forward方法详解 ===
1. 输入参数:
def forward(self, idx, targets=None):
参数:
- idx: 输入序列
* 形状: (batch_size, seq_len)
* 内容: token索引
- targets: 目标序列(可选)
* 形状: (batch_size, seq_len)
* 内容: 目标token索引
* 用途: 计算loss
2. 参数检查:
device = idx.device
b, t = idx.size()
assert t <= self.args.block_size
解释:
- device: 获取设备(CPU或GPU)
- b: batch_size
- t: seq_len(序列长度)
- 检查序列长度是否超过最大长度
3. 词嵌入:
tok_emb = self.transformer.wte(idx)
数据变化:
- 输入: idx (batch_size, seq_len)
- 输出: tok_emb (batch_size, seq_len, n_embd)
例子:
- idx: [[1, 2, 3], [4, 5, 6]]
- tok_emb: [[[0.1, 0.2, ...], [0.3, 0.4, ...], [0.5, 0.6, ...]], ...]
- 每个token索引转换为一个n_embd维的向量
4. 位置编码:
pos_emb = self.transformer.wpe(tok_emb)
数据变化:
- 输入: tok_emb (batch_size, seq_len, n_embd)
- 输出: pos_emb (batch_size, seq_len, n_embd)
作用:
- 给每个位置添加位置信息
- tok_emb + 位置编码 = pos_emb
5. Dropout:
x = self.transformer.drop(pos_emb)
数据变化:
- 输入: pos_emb (batch_size, seq_len, n_embd)
- 输出: x (batch_size, seq_len, n_embd)
作用:
- 随机丢弃一些神经元
- 防止过拟合
6. 编码器:
enc_out = self.transformer.encoder(x)
数据变化:
- 输入: x (batch_size, seq_len, n_embd)
- 输出: enc_out (batch_size, seq_len, n_embd)
作用:
- 编码输入序列
- 提取特征
7. 解码器:
x = self.transformer.decoder(x, enc_out)
数据变化:
- 输入1: x (batch_size, seq_len, n_embd)
- 输入2: enc_out (batch_size, seq_len, n_embd)
- 输出: x (batch_size, seq_len, n_embd)
作用:
- 解码输入序列
- 结合编码器输出
8. 输出层:
if targets is not None:
# 训练阶段
logits = self.lm_head(x)
loss = F.cross_entropy(...)
else:
# 推理阶段
logits = self.lm_head(x[:, [-1], :])
loss = None
训练阶段:
- 输入: x (batch_size, seq_len, n_embd)
- 输出: logits (batch_size, seq_len, vocab_size)
- 计算loss: 交叉熵损失
推理阶段:
- 输入: x[:, [-1], :] (batch_size, 1, n_embd)
- 只取最后一个时间步
- 输出: logits (batch_size, 1, vocab_size)
- loss = None
9. 返回值:
return logits, loss
logits: 词表概率
loss: 损失(训练阶段有值,推理阶段为None)
=== 完整的数据流动示例 ===
1. 参数:
batch_size: 2
seq_len: 5
vocab_size: 10000
n_embd: 512
2. 数据流动:
步骤1: 输入
- idx: (2, 5)
- 例如: [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
步骤2: 词嵌入
- tok_emb: (2, 5, 512)
- 每个token索引转换为512维向量
步骤3: 位置编码
- pos_emb: (2, 5, 512)
- tok_emb + 位置编码
步骤4: Dropout
- x: (2, 5, 512)
- 随机丢弃一些神经元
步骤5: 编码器
- enc_out: (2, 5, 512)
- 编码输入序列
步骤6: 解码器
- x: (2, 5, 512)
- 解码输入序列 + 编码器输出
步骤7: 输出层
- logits: (2, 5, 10000)
- 每个位置输出词表概率
3. 训练 vs 推理:
训练阶段:
- 输出: logits (2, 5, 10000)
- loss: 交叉熵损失
- 用途: 更新模型参数
推理阶段:
- 输出: logits (2, 1, 10000)
- loss: None
- 用途: 生成下一个词
4. 关键点:
- wte: 词嵌入层,将token索引转换为向量
- wpe: 位置编码层,添加位置信息
- encoder: 编码器,提取特征
- decoder: 解码器,生成输出
- lm_head: 输出层,输出词表概率
- 训练时: 输出所有位置的logits,计算loss
- 推理时: 只输出最后一个位置的logits,用于生成
输入 idx (batch_size, seq_len)
↓
wte: 词嵌入 (batch_size, seq_len, n_embd)
↓
wpe: 位置编码 (batch_size, seq_len, n_embd)
↓
drop: Dropout (batch_size, seq_len, n_embd)
↓
encoder: 编码器 (batch_size, seq_len, n_embd)
↓
decoder: 解码器 (batch_size, seq_len, n_embd)
↓
lm_head: 线性层 (batch_size, seq_len, vocab_size)
↓
logits (batch_size, seq_len, vocab_size)
关键组件
训练 vs 推理
各模块作用
1、词嵌入层:将 token 索引转换为向量
2、位置编码:添加位置信息
3、编码器:提取特征
4、解码器:生成输出
5、输出层:输出词表概率
6、权重初始化:正态分布