2.3 解码器结构、因果掩码与归一化:原书第2章实现细节速览

0 阅读7分钟

基于《大规模语言模型:从理论到实践(第2版)》第2章 大语言模型基础

爆款小标题:写代码/读源码必备:Decoder-only 的掩码、残差与 Norm 怎么堆


为什么这一节重要

真正读 Transformer 解码器代码或自己实现时,会遇到三类容易出错或容易混淆的细节:因果掩码(如何保证自回归、训练和推理时如何施加)、归一化(Pre-Norm 和 Post-Norm 谁在前谁在后、残差加在哪)、以及 RMSNorm 与 LayerNorm 的差异。这些在原书第 2 章都有对应,本节把它们串成「实现视角」的一讲,便于你对照任意一份 Decoder-only 代码(如 LLaMA、GPT-2)快速理清结构并避免常见 bug。


学习目标

学完本节,你将能够:

  • 说清因果掩码:解释在自注意力中如何通过掩码保证「每个位置只能看到左侧」,以及在训练与推理时如何施加(如 softmax 前赋 -inf)。
  • 区分 Pre-Norm 与 Post-Norm:说明「归一化—子层—残差」的顺序差异,以及 Pre-Norm 在训练稳定性上的常见优势;能指出残差是加在子层之前还是之后。
  • 对照代码指出块结构:在给定一份 Decoder Block 代码时,能指出「哪里是因果 mask」「哪里是 Norm」「哪里是 Attention / FFN」及顺序。

一、因果掩码(Causal Mask)的作用与实现(原书第 2 章)

为什么需要因果掩码

自回归语言模型在训练时的目标是:根据位置 (1 \sim t-1) 的 token 预测位置 (t) 的 token。因此,在计算位置 (t) 的表示时,不能使用位置 (t+1, t+2, \ldots, n) 的信息,否则就「偷看答案」了。注意力机制本身会看到整句,所以必须在注意力权重上屏蔽未来位置,使位置 (t) 对所有 (j>t) 的注意力权重为 0。

如何施加

设注意力分数为 (A = QK^\top / \sqrt{d_k}),形状 ((n, n)),其中 (A_{i,j}) 表示位置 (i) 对位置 (j) 的分数。在送入 softmax 之前,把所有「(j > i)」的位置置为负无穷(或一个很大的负数):

  • 例如构造一个下三角为 0、上三角为 (-\infty) 的 mask 矩阵,与 (A) 相加;或只在有效位置做 softmax。
  • softmax 之后,未来位置的权重变为 0,当前位置只能对「自己及过去」的位置加权,从而保证自回归性质。

训练 vs 推理

  • 训练:通常一次性对整段序列做前向,通过上述因果 mask 保证每个位置只依赖左侧;损失对每个位置(或忽略 padding)的下一 token 预测求平均。
  • 推理:逐 token 生成时,每次只多一个 token,等价于序列长度逐步增加;实现上可以是「每次用当前序列做前向、只取最后一个位置的输出」,或复用 KV cache 并只算新 token 的注意力。无论哪种,因果性都通过「只允许看已生成部分」来保证,与训练时的 mask 逻辑一致。因此「推理时不需要掩码」是错的——推理时仍然不能看未来,只是实现上可能是「序列变长 + 同一套因果规则」。

常见 bug:实现或修改 Decoder 时若漏掉因果 mask,或把 mask 搞反(允许看未来),会导致训练时泄露、生成时行为异常;务必在代码里明确「哪里构造 mask、哪里加到注意力分数上」。


二、Pre-Norm 与 Post-Norm(原书第 2 章)

Post-Norm(原始 Transformer 解码器)

子层输出后再做归一化,然后加残差:

[ x' = \text{LayerNorm}(x + \text{Sublayer}(x)) ]

即:先子层(Attention 或 FFN),再「残差 + Norm」。早期 Transformer 论文中解码器采用这种形式。

Pre-Norm(当前主流,如 LLaMA)

先做归一化,再进子层,然后加残差:

[ x' = x + \text{Sublayer}(\text{LayerNorm}(x)) ]

即:先 Norm,再子层,最后只对子层输出做残差相加。多数新架构(LLaMA、GPT-NeoX 等)采用 Pre-Norm。

为什么 Pre-Norm 更常用

Pre-Norm 使「主路径」先经过 Norm 再进入子层,梯度流动更平滑,深层时训练更稳定;Post-Norm 在深层时容易出现梯度或激活尺度问题。因此读代码时若看到「先 Norm 再 Attention/FFN、最后加残差」,就是 Pre-Norm;若看到「先 Attention/FFN、再 Norm 再加残差」,就是 Post-Norm。残差加在哪:Pre-Norm 是「加在子层之后」(与输入 (x) 相加);Post-Norm 是「加在子层之后、Norm 之前」(子层输出与输入 (x) 相加后再 Norm)。

与基座保持一致:若在已有基座上微调或修改,不要随意把 Pre-Norm 改成 Post-Norm 或反过来,否则可能训练发散或效果变差。


三、RMSNorm 与 LayerNorm(原书第 2 章)

LayerNorm

对最后一维(特征维)做标准化:减均值、除标准差,再做仿射变换(可学习缩放与平移)。公式为 (\text{LN}(x) = \gamma \cdot \frac{x - \mu}{\sigma} + \beta)。

RMSNorm(Root Mean Square Layer Normalization)

去掉「减均值」与平移项,只保留「按 RMS(均方根)缩放」及可学习的缩放因子:

[ \text{RMSNorm}(x) = \gamma \cdot \frac{x}{\text{RMS}(x)}, \quad \text{RMS}(x) = \sqrt{\frac{1}{d}\sum_i x_i^2} ]

计算更省、实现更简单,在 LLaMA 等模型中效果与 LayerNorm 相当,被广泛采用。原书第 2 章在介绍 LLaMA 时说明了这一点。

实现与换基座时:若基座是 RMSNorm,微调或扩展时保持 RMSNorm;若混用 LayerNorm 和 RMSNorm,可能影响训练稳定性或效果,一般不推荐随意替换。


四、一层 Decoder Block 的典型顺序(原书第 2 章)

结合原书与常见实现,一个 Decoder Block 通常为(以 Pre-Norm + RMSNorm 为例):

  1. RMSNorm(对输入 (x))
  2. 因果自注意力(Q/K/V + 因果 mask + 多头输出)
  3. 残差:(x = x + \text{Attention}(\ldots))
  4. RMSNorm(对当前 (x))
  5. 前馈网络 FFN(如 SwiGLU)
  6. 残差:(x = x + \text{FFN}(\ldots))

输出 (x) 传入下一 Block。阅读代码时,可按「Norm → Attention → 残差 → Norm → FFN → 残差」的顺序对号入座;同时确认注意力那里是否有因果 mask(以及可选的 padding mask)。


五、工程实战要点

1. 自实现或修改 Decoder 时检查 mask 与 Norm

  • 确认因果 mask 在 softmax 前正确施加(未来位置为 -inf 或等价处理)。
  • 若有 padding,需把 padding 位置也 mask 掉(通常与因果 mask 合并成一个 mask 矩阵)。
  • 确认 Norm 类型与顺序与基座一致(Pre-Norm vs Post-Norm、RMSNorm vs LayerNorm)。

2. 换基座时的注意点

不同基座可能用 LayerNorm 或 RMSNorm、Pre-Norm 或 Post-Norm;迁移或微调时保持一致,避免因 Norm 不一致导致训练发散或效果异常。


六、常见误区与避坑指南

误区一:认为「推理时不需要掩码」

推理时仍要保证自回归:只能基于已生成的 token 预测下一个。实现上可能是「序列逐步变长 + 同一套因果规则」或「KV cache + 只算新 token 的注意力」,但逻辑上等价于「不能看未来」。避坑:推理代码里要么显式用因果 mask,要么用只含已生成 token 的序列做前向,二者必居其一。

误区二:Pre-Norm 和 Post-Norm 混用

与基座不一致的 Norm 顺序可能带来训练不稳定。避坑:和基座保持一致,不随意改 Norm 位置。

误区三:忽略 padding 的 mask

在 batch 内序列长度不一、需要 padding 时,若只做因果 mask 而不 mask 掉 padding 位置,padding 位置会参与注意力计算并影响输出。避坑:把「因果 mask」和「padding mask」合并(例如 padding 位置也设为 -inf),再与注意力分数相加。


七、小结与衔接

本节围绕原书第 2 章的实现细节,讲解了:因果掩码的作用与施加方式(softmax 前对未来位置赋 -inf)、训练与推理时都需保证自回归;Pre-Norm 与 Post-Norm 的区别及残差加在哪;RMSNorm 与 LayerNorm 的差异;以及一层 Decoder Block 的典型顺序(Norm → Attention → 残差 → Norm → FFN → 残差)。掌握这些后,读 LLaMA、GPT-2 等任意 Decoder-only 代码即可快速对应。下一模块将进入预训练数据与训练流程:数据从哪来、怎么洗、以及从原始语料到 loss 与 checkpoint 的完整管线(原书第 3–4 章)。


课后思考题

  1. 在注意力公式中,因果掩码是如何施加的?(提示:在 softmax 前对非法位置赋值为负无穷。)
  2. Pre-Norm 的「残差」是加在子层之前还是之后?与 Post-Norm 相比,梯度流动有什么不同?