从零训练大模型之模型搭建

106 阅读13分钟

前言

进过前两篇,我们已经完成数据集清洗以及BPE词表的训练,并且使用清洗后的数据集构建出 DataSet 类和 custom_collate_fn 方法。这一篇,我们终于要开始搭建模型了,如最开始所说,我们计划的就是使用 pytorch 现有的API进行搭建,而不使用 hugging face 等提供的高级API。所以接下来我们开始搭建 Decoder-only Transformer 模型。

Decoder-Only Transformer——在大型语言模型领域取得了巨大成功,例如著名的 GPT 系列模型。顾名思义,这种架构仅使用 Transformer 的解码器部分。它非常适合自回归(Auto-regressive)任务,即根据已生成的序列预测下一个词元,这正是语言模型的核心任务。

由于我们的目标是构建一个能生成文本(如扩写、问答)的模型,并且考虑到从零开始的复杂性,Decoder-Only架构是一个理想的选择。

构建Decoder-Only Transformer核心模块

一个完整的Decoder-only Transformer模型可以拆解为以下几个核心组件:

  • 词嵌入层 (Token Embedding):将输入的词元(token ID)转换为固定维度的向量。
  • 位置编码层 (Positional Encoding):为模型注入词元在序列中的位置信息。
  • 因果自注意力层 (Causal Self-Attention):模型的核心,让每个词元关注序列中它自身以及它之前的词元(不能看到未来!)。
  • 前馈神经网络 (Feed-Forward Network):对注意力层的输出进行进一步的非线性变换。
  • 解码器模块 (Decoder Block):组合自注意力层和前馈网络,并加入残差连接和层归一化。
  • 整体模型架构 (Decoder-only Transformer):堆叠多个解码器模块,并添加最终的输出层。

1. 词嵌入层 (Token Embedding)

计算机不直接理解文字,我们需要将文本切分成词元(token),再将每个词元ID映射成一个稠密的向量,这个过程就是词嵌入。

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, padding_idx: Optional[int] = None):
        super(TokenEmbedding, self).__init__()
        # nn.Embedding: PyTorch提供的标准嵌入层。
        # vocab_size: 词汇表的大小,即有多少个不同的词元。
        # embed_dim: 每个词元嵌入后的向量维度。
        # padding_idx (可选): 如果指定,该索引对应的嵌入向量会初始化为0,并且在训练中通常不更新。
        #                    这对于处理变长序列时的填充非常有用。
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.embed_dim = embed_dim # 保存嵌入维度,方便后续使用

    def forward(self, tokens: torch.Tensor):
        # tokens: 输入的词元ID序列,形状通常是 (batch_size, seq_len)
        # 输出: 词嵌入序列,形状是 (batch_size, seq_len, embed_dim)
        
        # 乘以 math.sqrt(self.embed_dim) 是一种常见的缩放技巧。
        # 在原始Transformer论文 "Attention Is All You Need" 中被提及,
        # 有助于在后续层(如点积注意力)中保持适当的方差,防止梯度过小或过大,
        # 使得模型训练更稳定。
        return self.embedding(tokens) * math.sqrt(self.embed_dim)

2. 位置编码层 (Positional Encoding)

Transformer模型本身(尤其是自注意力机制)并不包含序列中词元的顺序信息。为了让模型理解词语的先后关系,我们需要显式地加入位置编码。这里我们使用经典的正弦/余弦位置编码。

这种编码方式的优点是:

  • 对于每个位置,它都能生成一个唯一的编码。
  • 模型可以学习到相对位置信息,因为固定偏移量 k 后的位置编码 PE(pos+k) 可以表示为 PE(pos) 的线性函数。
  • 它可以扩展到比训练时遇到的序列更长的序列。
class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim: int, dropout: float = 0.1, max_len: int = 5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout) # Dropout层,用于正则化,防止过拟合

        # 创建一个足够长的位置编码矩阵 pe,形状为 (max_len, embed_dim)
        # max_len 是模型能处理的最大序列长度
        pe = torch.zeros(max_len, embed_dim)
        
        # 生成位置索引 (0, 1, ..., max_len-1),并增加一个维度变为 (max_len, 1)
        # 例如,如果 max_len = 5, position = [[0], [1], [2], [3], [4]]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        # 计算除法项,用于缩放不同频率的正弦和余弦波
        # div_term 的思想是让不同维度上的波长呈几何级数变化
        # torch.arange(0, embed_dim, 2) 会生成 [0, 2, ..., embed_dim-2]
        # -math.log(10000.0) / embed_dim 是一个缩放因子
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
        
        # 使用正弦函数填充偶数索引的维度 (0, 2, 4, ...)
        pe[:, 0::2] = torch.sin(position * div_term)
        # 使用余弦函数填充奇数索引的维度 (1, 3, 5, ...)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # 增加一个 batch 维度,变为 (1, max_len, embed_dim),以方便后续与输入 (batch, seq_len, embed_dim) 相加时的广播
        pe = pe.unsqueeze(0)
        
        # 将 pe 注册为模型的 buffer。
        # buffer 是模型状态的一部分(会随模型一起保存和加载),但不是可训练参数(即梯度不会流过它们)。
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor):
        # x: 词嵌入序列,形状 (batch_size, seq_len, embed_dim)
        
        # 从预计算的 pe 中取出与当前输入序列长度 (x.size(1)) 相匹配的部分,
        # 并将其加到输入 x 上。这里利用了PyTorch的广播机制。
        # self.pe[:, :x.size(1), :] 会截取 pe 的前 seq_len 个位置编码
        x = x + self.pe[:, :x.size(1), :]
        
        # 应用 dropout 后返回
        # 输出: 带有位置信息的嵌入序列,形状 (batch_size, seq_len, embed_dim)
        return self.dropout(x)

3.因果自注意力层 (Causal Self-Attention)

这是Transformer的魔法核心!自注意力允许模型在处理一个词元时,权衡序列中其他词元的重要性。在Decoder-only模型中,这种注意力必须是“因果”的,即一个词元在计算其表示时,只能关注它自己和它之前的词元,绝不能“偷看”未来的词元。这是因为在生成文本时,我们是逐词预测的。

PyTorch的nn.MultiheadAttention模块为我们提供了强大的支持。

class CausalSelfAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.1):
        super(CausalSelfAttention, self).__init__()
        # embed_dim 必须能被 num_heads 整除
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        
        # nn.MultiheadAttention: PyTorch实现的多头注意力机制。
        # embed_dim: 总的输入/输出特征维度。
        # num_heads: 注意力头的数量。多头允许模型在不同子空间中共同学习信息。
        # dropout: 应用于注意力权重图的dropout概率。
        # batch_first=True: 指定输入和输出张量的形状为 (batch_size, seq_len, embed_dim),更符合常见习惯。
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)

    def forward(self, x: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None):
        # x: 输入序列,形状 (batch_size, seq_len, embed_dim)
        # key_padding_mask (可选): 形状 (batch_size, seq_len)。
        #   用于指示输入序列中的哪些位置是填充(padding)的。
        #   如果一个位置为True,则该位置在注意力计算中会被忽略。

        device = x.device
        seq_len = x.size(1)
        causal_mask = generate_square_subsequent_mask(seq_len, device=device)

        # 对于自注意力 (Self-Attention),query, key, value 都来自同一个输入 x。
        # is_causal=True (PyTorch >= 1.12): 这是关键!
        #   当设置为True时,`MultiheadAttention`会自动应用一个上三角掩码(causal mask),
        #   确保每个位置的query只能关注到key序列中当前及之前的位置。
        #   这样就实现了因果性,防止模型看到未来的信息。
        # attn_mask=causal_mask: 额外的自定义掩码(比如结合padding mask),可以更复杂地构造。
        # key_padding_mask: 传入我们之前生成的padding掩码,屏蔽padding token的影响。
        # need_weights=False: 如果我们不需要返回注意力权重图(通常在训练和推理中为了效率不返回),可以设为False。
        #                   如果需要分析注意力,可以设为True,此时会多一个attn_weights输出。
        attn_output, _ = self.mha(query=x, key=x, value=x,
                                  attn_mask=causal_mask,
                                  key_padding_mask=key_padding_mask,
                                  is_causal=True,
                                  need_weights=False)
        
        # attn_output 形状: (batch_size, seq_len, embed_dim)
        return attn_output
# 生成因果掩码(上三角掩码)的辅助函数
def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
    """为因果注意力生成一个方阵掩码。"""
    # torch.triu 创建一个上三角矩阵。diagonal=1 表示不包括主对角线。
    # 我们希望屏蔽未来的token,即上三角部分为 True。
    # 这会生成一个 sz x sz 的矩阵,其中上三角(不含对角线)为True,其余为False。
    # 例如 sz=3:
    # [[False,  True,  True],
    #  [False, False,  True],
    #  [False, False, False]]
    # 这正是 attn_mask 所需的格式,True 表示该位置在注意力计算中应被忽略。
    mask = torch.triu(torch.ones(sz, sz, device=device, dtype=torch.bool), diagonal=1)
    return mask

4.前馈神经网络 (Feed-Forward Network, FFN)

在注意力层之后,每个位置的输出会独立地通过一个简单的前馈神经网络。这个网络通常由两个线性层和一个非线性激活函数组成。

FFN的作用是增加模型的非线性表达能力,对注意力机制捕捉到的信息进行更复杂的加工。

class FeedForward(nn.Module):
    def __init__(self, embed_dim: int, ff_dim: int, dropout: float = 0.1):
        super(FeedForward, self).__init__()
        # 第一个线性层,将 embed_dim 扩展到 ff_dim (通常 ff_dim 是 embed_dim 的2到4倍)
        self.linear1 = nn.Linear(embed_dim, ff_dim)
        self.dropout = nn.Dropout(dropout) # Dropout层
        # 第二个线性层,将 ff_dim 缩减回 embed_dim
        self.linear2 = nn.Linear(ff_dim, embed_dim)
        # 激活函数,ReLU 是Transformer原始论文中使用的,GELU也是现代LLM中的常用选择
        self.activation = nn.ReLU() # 或者 nn.GELU()

    def forward(self, x: torch.Tensor):
        # x: 输入形状 (batch_size, seq_len, embed_dim)
        x = self.linear1(x)      # (batch_size, seq_len, ff_dim)
        x = self.activation(x)   # (batch_size, seq_len, ff_dim)
        x = self.dropout(x)      # (batch_size, seq_len, ff_dim)
        x = self.linear2(x)      # (batch_size, seq_len, embed_dim)
        # 输出形状: (batch_size, seq_len, embed_dim)
        return x

5. 解码器模块 (Decoder Block)

一个解码器模块将上述的因果自注意力层和前馈网络组合起来。关键的是,在每个子层(自注意力、FFN)的输出之后,都会使用残差连接(Residual Connection)和层归一化(Layer Normalization)。

  • 残差连接 (x + sublayer(x)):允许梯度直接流过网络,有助于训练更深的模型,缓解梯度消失问题。
  • 层归一化 (LayerNorm(x)):对每个样本在特征维度上进行归一化,有助于稳定训练过程,加速收敛。
class DecoderBlock(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, dropout: float = 0.1):
        super(DecoderBlock, self).__init__()
        self.self_attention = CausalSelfAttention(embed_dim, num_heads, dropout)
        self.feed_forward = FeedForward(embed_dim, ff_dim, dropout)
        
        # 层归一化 (Layer Normalization)
        # Transformer 原始论文中使用的是 Post-LN 结构 (LN 在残差连接之后)。
        # Pre-LN (LN 在自注意力/前馈网络之前) 也是一种常见的变体,有时能提供更稳定的训练。
        # 这里我们遵循原始论文的 Post-LN 结构。
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        
        # Dropout 应用于子层的输出,在加入残差连接之前
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None):
        # x: 输入形状 (batch_size, seq_len, embed_dim)
        # key_padding_mask (可选): 形状 (batch_size, seq_len), 用于CausalSelfAttention
        
        # 1. 因果自注意力子层 (Multi-Head Attention)
        residual = x # 保存输入 x 用于第一个残差连接
        attn_output = self.self_attention(x, key_padding_mask=key_padding_mask)
        # 应用 dropout,然后进行残差连接 (Add) 和层归一化 (Norm)
        # x = LayerNorm(x + Dropout(Sublayer(x)))
        x = self.norm1(residual + self.dropout1(attn_output))
        
        # 2. 前馈网络子层 (Feed Forward)
        residual = x # 更新残差连接的来源为上一层的输出
        ff_output = self.feed_forward(x)
        # 再次应用 dropout,然后进行残差连接 (Add) 和层归一化 (Norm)
        x = self.norm2(residual + self.dropout2(ff_output))
        
        # 输出形状: (batch_size, seq_len, embed_dim)
        return x

6. 组装!Decoder-only Transformer 模型

现在,我们将所有部件组装起来,构建完整的DecoderOnlyTransformer模型。它主要包含:

  • 词嵌入层
  • 位置编码层
  • 一叠(num_layers个)上面定义的DecoderBlock
  • 一个最终的线性输出层,将embed_dim维的向量映射回词汇表大小,得到每个词的预测概率(logits)。
class DecoderOnlyTransformer(nn.Module):
    def __init__(self, num_layers: int, vocab_size: int, embed_dim: int, 
                 num_heads: int, ff_dim: int, max_seq_len: int, 
                 dropout: float = 0.1, padding_idx: Optional[int] = 0,
                 tie_weights: bool = True): # 新增 tie_weights 参数
        super(DecoderOnlyTransformer, self).__init__()
        
        self.padding_idx = padding_idx # 保存padding_idx,用于生成padding_mask和传递给TokenEmbedding

        # 1. 词嵌入层
        self.token_embedding = TokenEmbedding(vocab_size, embed_dim, padding_idx=self.padding_idx)
        # 2. 位置编码层
        self.positional_encoding = PositionalEncoding(embed_dim, dropout, max_seq_len)
        
        # 3. Decoder Block 堆栈
        # 使用 nn.ModuleList 来正确注册 DecoderBlock 列表中的模块,使其参数能被PyTorch自动管理。
        self.decoder_blocks = nn.ModuleList(
            [DecoderBlock(embed_dim, num_heads, ff_dim, dropout) for _ in range(num_layers)]
        )
        
        # 4. 输出层 (也称为语言模型头, LM Head)
        # 将Decoder的最终输出(embed_dim维向量)映射回词汇表大小,得到每个词元的原始分数(logits)。
        # 后续可以通过Softmax将其转换为概率分布。
        self.output_layer = nn.Linear(embed_dim, vocab_size)
        
        # 5. (可选但推荐) 权重绑定 (Weight Tying)
        # 这是一个常见的技巧:共享词嵌入层 (self.token_embedding) 和输出层 (self.output_layer) 的权重矩阵。
        # 思想是:能够很好地表示一个词的向量,也应该能够很好地从隐藏状态预测出这个词。
        # 这可以显著减少模型参数数量,并有时能提高性能,特别是在词汇表较大时。
        if tie_weights:
            # 确保嵌入维度和输出层输入维度一致,这是权重绑定的前提
            if embed_dim != self.token_embedding.embedding.weight.size(1):
                raise ValueError("embed_dim must match embedding dim for weight tying")
            # 直接将输出层的权重指向嵌入层的权重
            self.output_layer.weight = self.token_embedding.embedding.weight

        # 初始化模型参数 (一个好的实践)
        self._init_weights()

    def _init_weights(self):
        # 对模型中的不同类型的层进行参数初始化是一种常见的实践,有助于模型训练。
        # 例如,Xavier/Glorot 初始化常用于线性层和嵌入层。
        for module in self.modules(): # self.modules() 会递归地返回模型中所有的模块
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight) # 使用 Xavier 均匀初始化权重
                if module.bias is not None:
                    nn.init.zeros_(module.bias) # 将偏置初始化为0
            elif isinstance(module, nn.Embedding):
                nn.init.xavier_uniform_(module.weight) # 使用 Xavier uniform 初始化嵌入权重
                if module.padding_idx is not None:
                    # 特别地,确保 padding_idx 对应的嵌入向量是零,并且在训练中(如果优化器不特殊处理)不会被更新。
                    module.weight.data[module.padding_idx].zero_() 
            elif isinstance(module, nn.LayerNorm):
                # LayerNorm 的 gamma (weight) 通常初始化为1,beta (bias) 初始化为0
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)


    def _generate_padding_mask(self, src: torch.Tensor) -> Optional[torch.Tensor]:
        # src: 输入的词元ID序列,形状 (batch_size, seq_len)
        # 功能: 生成一个布尔掩码,标记出输入序列中的padding位置。
        #       True表示对应位置是padding,应该在注意力计算中被忽略。
        if self.padding_idx is None:
            return None # 如果没有定义padding_idx,则不生成掩码
        
        src_padding_mask = (src == self.padding_idx) # 形状 (batch_size, seq_len)
        # 例如: src = [[1, 2, 0], [3, 0, 0]], padding_idx = 0
        # mask =    [[F, F, T], [F, T, T]]
        return src_padding_mask

    def forward(self, src: torch.Tensor):
        # src: 输入的词元ID序列,形状 (batch_size, seq_len)
        #      例如: [[101, 1034, 203, 0, 0], [101, 405, 589, 382, 0]] (0是padding_idx)
        
        # 1. 生成 padding 掩码 (key_padding_mask)
        #    这个掩码会传递给CausalSelfAttention,用于在计算注意力分数时忽略padding token。
        src_key_padding_mask = self._generate_padding_mask(src) 
        # 形状: (batch_size, seq_len) 或者 None

        # 2. 词嵌入 和 位置编码
        # (batch_size, seq_len) -> (batch_size, seq_len, embed_dim)
        x = self.token_embedding(src)      
        # (batch_size, seq_len, embed_dim) -> (batch_size, seq_len, embed_dim)
        x = self.positional_encoding(x) 
        
        # 3. 通过 Decoder 块堆栈
        #    每一层 DecoderBlock 都会接收 src_key_padding_mask,
        #    其内部的 CausalSelfAttention 会使用 is_causal=True 自动处理因果掩码。
        for block in self.decoder_blocks:
            x = block(x, key_padding_mask=src_key_padding_mask) 
            # x 形状保持: (batch_size, seq_len, embed_dim)
            
        # 4. 输出层
        #    将Decoder最终的隐藏状态通过线性层映射到词汇表空间,得到logits。
        # (batch_size, seq_len, embed_dim) -> (batch_size, seq_len, vocab_size)
        logits = self.output_layer(x) 
        
        return logits # 返回每个位置上,对词汇表中每个词的预测分数

总结

你已经成功地从零开始,使用PyTorch搭建了一个完整的Decoder-only Transformer模型。回顾一下,我们实现了:

  • TokenEmbedding: 将词ID转为向量。
  • PositionalEncoding: 注入位置信息。
  • CausalSelfAttention: 实现带因果约束的核心注意力机制。
  • FeedForward: 增加非线性。
  • DecoderBlock: 模块化组合,包含残差和层归一化。
  • DecoderOnlyTransformer: 整合所有组件,并加入权重绑定和参数初始化等实用技巧。

关注我的公众号不走丢

附录

GitHub链接:github.com/JimmysAIPG/…