从0开始学AI:Transformer,原来是这回事!

0 阅读7分钟

Transformer 模型结构

根据前面介绍的所有组件,结合起来,就是一个完整的 Transformer 结构了。

图 Transformer模型结构

上图为论文《Attention is all you need》原文配图,LayerNorm 是放在 Attention 层之后,也就是“Post-Norm”结构,但是在其发布的源码中,LayerNormer 是放在 Attention 层之前,也就是“Pre-Norm”。实际中,Pre-Norm 结构可以使 loss 更稳定,所以目前 LLM 一般采用“Pre-Norm”,即输入先归一化,Attention 层输入更稳定。“Post-Norm”的话,Attention 输出可能很大。

class Transformer(nn.Module):
   '''整体模型'''
    def __init__(self, args):
        super().__init__()
        # 必须输入词表大小和 block size
        assert args.vocab_size is not None
        assert args.block_size is not None
        self.args = args
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(args.vocab_size, args.n_embd),
            wpe = PositionalEncoding(args),
            drop = nn.Dropout(args.dropout),
            encoder = Encoder(args),
            decoder = Decoder(args),
        ))
        # 最后的线性层,输入是 n_embd,输出是词表大小
        self.lm_head = nn.Linear(args.n_embd, args.vocab_size, bias=False)

        # 初始化所有的权重
        self.apply(self._init_weights)

        # 查看所有参数的数量
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    '''统计所有参数的数量'''
    def get_num_params(self, non_embedding=False):
        # non_embedding: 是否统计 embedding 的参数
        n_params = sum(p.numel() for p in self.parameters())
        # 如果不统计 embedding 的参数,就减去
        if non_embedding:
            n_params -= self.transformer.wte.weight.numel()
        return n_params

    '''初始化权重'''
    def _init_weights(self, module):
        # 线性层和 Embedding 层初始化为正则分布
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    '''前向计算函数'''
    def forward(self, idx, targets=None):
        # 输入为 idx,维度为 (batch size, sequence length, 1);targets 为目标序列,用于计算 loss
        device = idx.device
        b, t = idx.size()
        assert t <= self.args.block_size, f"不能计算该序列,该序列长度为 {t}, 最大序列长度只有 {self.args.block_size}"

        # 通过 self.transformer
        # 首先将输入 idx 通过 Embedding 层,得到维度为 (batch size, sequence length, n_embd)
        print("idx",idx.size())
        # 通过 Embedding 层
        tok_emb = self.transformer.wte(idx)
        print("tok_emb",tok_emb.size())
        # 然后通过位置编码
        pos_emb = self.transformer.wpe(tok_emb) 
        # 再进行 Dropout
        x = self.transformer.drop(pos_emb)
        # 然后通过 Encoder
        print("x after wpe:",x.size())
        enc_out = self.transformer.encoder(x)
        print("enc_out:",enc_out.size())
        # 再通过 Decoder
        x = self.transformer.decoder(x, enc_out)
        print("x after decoder:",x.size())

        if targets is not None:
            # 训练阶段,如果我们给了 targets,就计算 loss
            # 先通过最后的 Linear 层,得到维度为 (batch size, sequence length, vocab size)
            logits = self.lm_head(x)
            # 再跟 targets 计算交叉熵
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # 推理阶段,我们只需要 logits,loss 为 None
            # 取 -1 是只取序列中的最后一个作为输出
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss

=== Transformer整体结构 ===

1. Transformer的组成部分:
   
   输入层:
   - wte: 词嵌入层
   - wpe: 位置编码层
   - drop: Dropout层
   
   编码器:
   - encoder: 编码器
   
   解码器:
   - decoder: 解码器
   
   输出层:
   - lm_head: 线性层(输出词表概率)

2. 数据流动:
   
   输入 idx (batch_size, seq_len)
   ↓
   wte: 词嵌入 (batch_size, seq_len, n_embd)
   ↓
   wpe: 位置编码 (batch_size, seq_len, n_embd)
   ↓
   drop: Dropout (batch_size, seq_len, n_embd)
   ↓
   encoder: 编码器 (batch_size, seq_len, n_embd)
   ↓
   decoder: 解码器 (batch_size, seq_len, n_embd)
   ↓
   lm_head: 线性层 (batch_size, seq_len, vocab_size)
   ↓
   logits (batch_size, seq_len, vocab_size)
=== __init__方法详解 ===

1. 参数检查:
   
   assert args.vocab_size is not None
   assert args.block_size is not None
   
   作用:
   - 确保vocab_size(词表大小)已设置
   - 确保block_size(最大序列长度)已设置
   - 如果没有设置,会报错

2. 创建组件:
   
   self.transformer = nn.ModuleDict(dict(
       wte = nn.Embedding(args.vocab_size, args.n_embd),
       wpe = PositionalEncoding(args),
       drop = nn.Dropout(args.dropout),
       encoder = Encoder(args),
       decoder = Decoder(args),
   ))
   
   解释:
   - wte: 词嵌入层
     * 输入: token索引 (batch_size, seq_len)
     * 输出: 词向量 (batch_size, seq_len, n_embd)
     * 参数数量: vocab_size × n_embd
   
   - wpe: 位置编码层
     * 输入: 词向量 (batch_size, seq_len, n_embd)
     * 输出: 添加位置编码后的向量 (batch_size, seq_len, n_embd)
   
   - drop: Dropout层
     * 输入: 向量 (batch_size, seq_len, n_embd)
     * 输出: Dropout后的向量 (batch_size, seq_len, n_embd)
     * 作用: 防止过拟合
   
   - encoder: 编码器
     * 输入: 向量 (batch_size, seq_len, n_embd)
     * 输出: 编码后的向量 (batch_size, seq_len, n_embd)
   
   - decoder: 解码器
     * 输入: 向量 + 编码器输出
     * 输出: 解码后的向量 (batch_size, seq_len, n_embd)

3. 输出层:
   
   self.lm_head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
   
   解释:
   - 输入: 解码器输出 (batch_size, seq_len, n_embd)
   - 输出: 词表概率 (batch_size, seq_len, vocab_size)
   - bias=False: 不使用偏置
   - 参数数量: n_embd × vocab_size

4. 权重初始化:
   
   self.apply(self._init_weights)
   
   解释:
   - 对所有线性层和Embedding层进行初始化
   - 使用正态分布初始化
   - mean=0.0, std=0.02

5. 参数统计:
   
   print('number of parameters: %.2fM' % (self.get_num_params()/1e6,))
   
   解释:
   - 统计所有参数的数量
   - 除以1e6,转换为百万(M)
   - 例如: 10M = 1000万参数
=== forward方法详解 ===

1. 输入参数:
   
   def forward(self, idx, targets=None):
   
   参数:
   - idx: 输入序列
     * 形状: (batch_size, seq_len)
     * 内容: token索引
   - targets: 目标序列(可选)
     * 形状: (batch_size, seq_len)
     * 内容: 目标token索引
     * 用途: 计算loss

2. 参数检查:
   
   device = idx.device
   b, t = idx.size()
   assert t <= self.args.block_size
   
   解释:
   - device: 获取设备(CPU或GPU)
   - b: batch_size
   - t: seq_len(序列长度)
   - 检查序列长度是否超过最大长度

3. 词嵌入:
   
   tok_emb = self.transformer.wte(idx)
   
   数据变化:
   - 输入: idx (batch_size, seq_len)
   - 输出: tok_emb (batch_size, seq_len, n_embd)
   
   例子:
   - idx: [[1, 2, 3], [4, 5, 6]]
   - tok_emb: [[[0.1, 0.2, ...], [0.3, 0.4, ...], [0.5, 0.6, ...]], ...]
   - 每个token索引转换为一个n_embd维的向量

4. 位置编码:
   
   pos_emb = self.transformer.wpe(tok_emb)
   
   数据变化:
   - 输入: tok_emb (batch_size, seq_len, n_embd)
   - 输出: pos_emb (batch_size, seq_len, n_embd)
   
   作用:
   - 给每个位置添加位置信息
   - tok_emb + 位置编码 = pos_emb

5. Dropout:
   
   x = self.transformer.drop(pos_emb)
   
   数据变化:
   - 输入: pos_emb (batch_size, seq_len, n_embd)
   - 输出: x (batch_size, seq_len, n_embd)
   
   作用:
   - 随机丢弃一些神经元
   - 防止过拟合

6. 编码器:
   
   enc_out = self.transformer.encoder(x)
   
   数据变化:
   - 输入: x (batch_size, seq_len, n_embd)
   - 输出: enc_out (batch_size, seq_len, n_embd)
   
   作用:
   - 编码输入序列
   - 提取特征

7. 解码器:
   
   x = self.transformer.decoder(x, enc_out)
   
   数据变化:
   - 输入1: x (batch_size, seq_len, n_embd)
   - 输入2: enc_out (batch_size, seq_len, n_embd)
   - 输出: x (batch_size, seq_len, n_embd)
   
   作用:
   - 解码输入序列
   - 结合编码器输出

8. 输出层:
   
   if targets is not None:
       # 训练阶段
       logits = self.lm_head(x)
       loss = F.cross_entropy(...)
   else:
       # 推理阶段
       logits = self.lm_head(x[:, [-1], :])
       loss = None
   
   训练阶段:
   - 输入: x (batch_size, seq_len, n_embd)
   - 输出: logits (batch_size, seq_len, vocab_size)
   - 计算loss: 交叉熵损失
   
   推理阶段:
   - 输入: x[:, [-1], :] (batch_size, 1, n_embd)
   - 只取最后一个时间步
   - 输出: logits (batch_size, 1, vocab_size)
   - loss = None

9. 返回值:
   
   return logits, loss
   
   logits: 词表概率
   loss: 损失(训练阶段有值,推理阶段为None)
=== 完整的数据流动示例 ===

1. 参数:
   batch_size: 2
   seq_len: 5
   vocab_size: 10000
   n_embd: 512

2. 数据流动:
   
   步骤1: 输入
   - idx: (2, 5)
   - 例如: [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
   
   步骤2: 词嵌入
   - tok_emb: (2, 5, 512)
   - 每个token索引转换为512维向量
   
   步骤3: 位置编码
   - pos_emb: (2, 5, 512)
   - tok_emb + 位置编码
   
   步骤4: Dropout
   - x: (2, 5, 512)
   - 随机丢弃一些神经元
   
   步骤5: 编码器
   - enc_out: (2, 5, 512)
   - 编码输入序列
   
   步骤6: 解码器
   - x: (2, 5, 512)
   - 解码输入序列 + 编码器输出
   
   步骤7: 输出层
   - logits: (2, 5, 10000)
   - 每个位置输出词表概率

3. 训练 vs 推理:
   
   训练阶段:
   - 输出: logits (2, 5, 10000)
   - loss: 交叉熵损失
   - 用途: 更新模型参数
   
   推理阶段:
   - 输出: logits (2, 1, 10000)
   - loss: None
   - 用途: 生成下一个词

4. 关键点:
   
   - wte: 词嵌入层,将token索引转换为向量
   - wpe: 位置编码层,添加位置信息
   - encoder: 编码器,提取特征
   - decoder: 解码器,生成输出
   - lm_head: 输出层,输出词表概率
   
   - 训练时: 输出所有位置的logits,计算loss
   - 推理时: 只输出最后一个位置的logits,用于生成
输入 idx (batch_size, seq_len)
  ↓
wte: 词嵌入 (batch_size, seq_len, n_embd)
  ↓
wpe: 位置编码 (batch_size, seq_len, n_embd)
  ↓
drop: Dropout (batch_size, seq_len, n_embd)
  ↓
encoder: 编码器 (batch_size, seq_len, n_embd)
  ↓
decoder: 解码器 (batch_size, seq_len, n_embd)
  ↓
lm_head: 线性层 (batch_size, seq_len, vocab_size)
  ↓
logits (batch_size, seq_len, vocab_size)

关键组件

训练 vs 推理

各模块作用

1、词嵌入层:将 token 索引转换为向量

2、位置编码:添加位置信息

3、编码器:提取特征

4、解码器:生成输出

5、输出层:输出词表概率

6、权重初始化:正态分布