04-残差连接与Pre-LN:让大模型的深度网络成为可能

5 阅读7分钟

深度网络的困境

在前面的章节中,我们学习了注意力机制、位置编码和MLP层。现在让我们把它们组合成一个完整的Transformer层:

步骤1:多头注意力X1=MultiHeadAttention(X)步骤2:MLP前馈网络X2=MLP(X1)\begin{aligned} &\text{步骤1:多头注意力} \\ &X_1 = \text{MultiHeadAttention}(X) \\ \\ &\text{步骤2:MLP前馈网络} \\ &X_2 = \text{MLP}(X_1) \end{aligned}

问题来了:如果我们要堆叠很多层(比如GPT-3有96层),会发生什么?

梯度消失与梯度爆炸

深度神经网络的训练依赖反向传播:梯度从输出层一层层往回传,更新每一层的参数。

链式法则

LW1=LX96X96X95X95X94X2X1X1W1\frac{\partial L}{\partial W_1} = \frac{\partial L}{\partial X_{96}} \cdot \frac{\partial X_{96}}{\partial X_{95}} \cdot \frac{\partial X_{95}}{\partial X_{94}} \cdots \frac{\partial X_2}{\partial X_1} \cdot \frac{\partial X_1}{\partial W_1}

这是一个连乘!

梯度消失

如果每一层的梯度都小于1(比如0.9):

0.9960.00003(几乎为0!)0.9^{96} \approx 0.00003 \quad \text{(几乎为0!)}
  • 底层(靠近输入)的梯度变得极小
  • 参数几乎不更新
  • 模型无法有效学习

梯度爆炸

如果每一层的梯度都大于1(比如1.1):

1.1968533(爆炸!)1.1^{96} \approx 8533 \quad \text{(爆炸!)}
  • 梯度变得极大
  • 参数更新幅度过大
  • 训练不稳定,模型发散

历史事实

在ResNet(2015)之前,训练超过20层的网络都很困难。直接堆叠更多层,效果反而变差!

问题的本质

信息流的退化

在深度网络中,信息需要经过很多层的变换才能传递:

Xf1(X)f2(f1(X))f3(f2(f1(X)))X \to f_1(X) \to f_2(f_1(X)) \to f_3(f_2(f_1(X))) \to \cdots
  • 每经过一层,信息都会被"扭曲"和"压缩"
  • 层数越深,原始信息越难保留
  • 梯度也面临同样的问题

直观类比

想象一个传话游戏:

  • 第1个人对第2个人说:"今天天气真好"
  • 第2个人理解后传给第3个人:"今天不错"
  • 第3个人传给第4个人:"挺好"
  • ...
  • 第96个人听到的可能是:"好"(信息几乎丢失!)

这就是为什么需要残差连接!

残差连接(Residual Connection)

核心思想

残差连接的想法非常简单:在变换的同时,保留原始信息的"高速通道"

没有残差连接

Xout=f(Xin)X_{\text{out}} = f(X_{\text{in}})

信息必须经过函数 ff 的变换。

有残差连接

Xout=Xin+f(Xin)X_{\text{out}} = X_{\text{in}} + f(X_{\text{in}})
  • XinX_{\text{in}}:直接传递的原始信息(恒等映射,identity)
  • f(Xin)f(X_{\text{in}}):学习到的"残差"(residual,即修正/补充)
  • 两者相加:原始信息 + 修正

关键洞察

函数 ff 不需要学习完整的映射,只需要学习"差异"或"修正"!

数学原理

1. 梯度流的改善

有残差连接时,反向传播的链式法则变为:

XoutXin=Xin[Xin+f(Xin)]=I+f(Xin)Xin\frac{\partial X_{\text{out}}}{\partial X_{\text{in}}} = \frac{\partial}{\partial X_{\text{in}}} \left[ X_{\text{in}} + f(X_{\text{in}}) \right] = I + \frac{\partial f(X_{\text{in}})}{\partial X_{\text{in}}}

其中 II 是单位矩阵(恒等映射的梯度)。

关键:即使 fXin\frac{\partial f}{\partial X_{\text{in}}} 很小甚至为0,梯度仍然至少有 II(值为1)!

LXin=LXout(I+fXin)\frac{\partial L}{\partial X_{\text{in}}} = \frac{\partial L}{\partial X_{\text{out}}} \cdot \left( I + \frac{\partial f}{\partial X_{\text{in}}} \right)

梯度可以直接通过恒等映射传递,不会消失!

2. 多层残差连接的累积效果

假设有 nn 层,每层都有残差连接:

X1=X0+f1(X0)X2=X1+f2(X1)=X0+f1(X0)+f2(X1)X3=X2+f3(X2)=X0+f1(X0)+f2(X1)+f3(X2)Xn=X0+i=1nfi(Xi1)\begin{aligned} X_1 &= X_0 + f_1(X_0) \\ X_2 &= X_1 + f_2(X_1) = X_0 + f_1(X_0) + f_2(X_1) \\ X_3 &= X_2 + f_3(X_2) = X_0 + f_1(X_0) + f_2(X_1) + f_3(X_2) \\ &\vdots \\ X_n &= X_0 + \sum_{i=1}^{n} f_i(X_{i-1}) \end{aligned}

发现:最终输出 XnX_n 包含原始输入 X0X_0 加上所有层的"修正"累积!

梯度传播

XnX0=I+X0i=1nfi(Xi1)\frac{\partial X_n}{\partial X_0} = I + \frac{\partial}{\partial X_0} \sum_{i=1}^{n} f_i(X_{i-1})
  • 始终有恒等项 II
  • 梯度可以直接从第 nn 层传到第 0 层
  • 不会因为层数增加而消失

3. 直观理解:多条路径

残差连接创造了指数级的路径

对于3层网络:

  • 没有残差:1条路径(f3f2f1f_3 \circ f_2 \circ f_1
  • 有残差:8条路径!
X3=X0+f1+f2+f3+f1f2+f1f3+f2f3+f1f2f3\begin{aligned} X_3 &= X_0 + f_1 + f_2 + f_3 \\ &\quad + f_1 \circ f_2 + f_1 \circ f_3 + f_2 \circ f_3 \\ &\quad + f_1 \circ f_2 \circ f_3 \end{aligned}

每一层可以选择"使用"或"跳过",形成多条并行路径,梯度可以通过任意路径流动。

Transformer中的残差连接

在Transformer的每一层中,残差连接被应用在两个地方

1. 注意力子层

X1=X+MultiHeadAttention(X)X_1 = X + \text{MultiHeadAttention}(X)
  • XX:子层的输入
  • MultiHeadAttention(X)\text{MultiHeadAttention}(X):注意力的输出(学习到的修正)
  • X1X_1:相加后的输出

2. MLP子层

X2=X1+MLP(X1)X_2 = X_1 + \text{MLP}(X_1)
  • X1X_1:MLP子层的输入
  • MLP(X1)\text{MLP}(X_1):前馈网络的输出(学习到的修正)
  • X2X_2:相加后的输出

完整的一层Transformer

步骤1:注意力+残差X1=X+MultiHeadAttention(X)步骤2:MLP+残差X2=X1+MLP(X1)\begin{aligned} \text{步骤1:注意力+残差} \quad &X_1 = X + \text{MultiHeadAttention}(X) \\ \text{步骤2:MLP+残差} \quad &X_2 = X_1 + \text{MLP}(X_1) \end{aligned}

残差连接的效果

实验证据(ResNet论文):

网络深度无残差连接有残差连接
18层✅ 能训练✅ 能训练
34层⚠️ 勉强训练✅ 能训练
50层❌ 难以训练✅ 能训练
101层❌ 无法训练✅ 能训练
152层❌ 无法训练✅ 能训练

Transformer的应用

  • GPT-3:96层
  • GPT-4:推测120+层
  • PaLM:118层

没有残差连接,这些深度模型不可能训练成功!

LayerNorm:稳定训练的另一块基石

残差连接解决了梯度流的问题,但还有一个问题:不同层、不同维度的激活值范围可能差异很大

为什么需要归一化?

问题示例

假设某一层的输出:

X=[1000.1502000.05801500.260]X = \begin{bmatrix} 100 & 0.1 & 50 \\ 200 & 0.05 & 80 \\ 150 & 0.2 & 60 \end{bmatrix}
  • 第1维:范围100-200(很大)
  • 第2维:范围0.05-0.2(很小)
  • 第3维:范围50-80(中等)

导致的问题

  1. 梯度不平衡

    • 大值维度的梯度很大
    • 小值维度的梯度很小
    • 参数更新不均衡
  2. 数值不稳定

    • softmax、sigmoid等函数对大数值敏感
    • 可能出现 e100e^{100} 导致溢出
  3. 学习效率低

    • 需要仔细调整学习率
    • 训练速度慢

LayerNorm的定义

Layer Normalization对每个样本的所有特征维度进行归一化:

LayerNorm(x)=γxμσ2+ϵ+β\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta

参数解释

  • xRdmodelx \in \mathbb{R}^{d_{\text{model}}}:输入向量(单个Token的表示)
  • μ=1dmodeli=1dmodelxi\mu = \frac{1}{d_{\text{model}}} \sum_{i=1}^{d_{\text{model}}} x_i:该向量的均值
  • σ2=1dmodeli=1dmodel(xiμ)2\sigma^2 = \frac{1}{d_{\text{model}}} \sum_{i=1}^{d_{\text{model}}} (x_i - \mu)^2:该向量的方差
  • ϵ\epsilon:防止除零的小常数(通常 10510^{-5}10610^{-6}
  • γ,βRdmodel\gamma, \beta \in \mathbb{R}^{d_{\text{model}}}:可学习的缩放和平移参数

步骤分解

步骤1:计算均值μ=1di=1dxi步骤2:计算方差σ2=1di=1d(xiμ)2步骤3:标准化x^i=xiμσ2+ϵ步骤4:缩放和平移yi=γix^i+βi\begin{aligned} \text{步骤1:计算均值} \quad &\mu = \frac{1}{d} \sum_{i=1}^{d} x_i \\ \\ \text{步骤2:计算方差} \quad &\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2 \\ \\ \text{步骤3:标准化} \quad &\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \\ \\ \text{步骤4:缩放和平移} \quad &y_i = \gamma_i \cdot \hat{x}_i + \beta_i \end{aligned}

效果

  • 步骤3后:x^\hat{x} 的均值为0,方差为1(强制标准化)
  • 步骤4:通过可学习的 γ\gammaβ\beta,让模型自己决定最优的分布

γ\gammaβ\beta 是如何学习的?

γ\gammaβ\beta可学习的参数,和模型的权重矩阵(如 WQW_QW1W_1 等)完全一样,通过反向传播和梯度下降进行训练。

1. 初始化

在训练开始前,γ\gammaβ\beta 需要初始化:

γ=[1,1,1,,1]Rdmodel(初始化为全1)β=[0,0,0,,0]Rdmodel(初始化为全0)\begin{aligned} \gamma &= [1, 1, 1, \ldots, 1] \in \mathbb{R}^{d_{\text{model}}} \quad \text{(初始化为全1)} \\ \beta &= [0, 0, 0, \ldots, 0] \in \mathbb{R}^{d_{\text{model}}} \quad \text{(初始化为全0)} \end{aligned}

为什么这样初始化?

  • 这样初始状态下:y=1x^+0=x^y = 1 \cdot \hat{x} + 0 = \hat{x}
  • 相当于直接使用标准化后的结果(均值0,方差1)
  • 模型可以从这个"中性"状态开始学习最优分布

2. 前向传播

在前向传播中,LayerNorm计算输出:

y=γx^+βy = \gamma \odot \hat{x} + \beta

其中 \odot 表示逐元素乘法。

示例dmodel=4d_{\text{model}}=4):

x^=[0.169,1.183,0.507,1.521](标准化后)γ=[1.2,0.8,1.5,0.9](训练学到的)β=[0.1,0.2,0.3,0.0](训练学到的)y1=1.2×0.169+0.1=0.303y2=0.8×(1.183)+(0.2)=1.146y3=1.5×(0.507)+0.3=0.461y4=0.9×1.521+0.0=1.369y=[0.303,1.146,0.461,1.369]\begin{aligned} \hat{x} &= [0.169, -1.183, -0.507, 1.521] \quad \text{(标准化后)} \\ \gamma &= [1.2, 0.8, 1.5, 0.9] \quad \text{(训练学到的)} \\ \beta &= [0.1, -0.2, 0.3, 0.0] \quad \text{(训练学到的)} \\ \\ y_1 &= 1.2 \times 0.169 + 0.1 = 0.303 \\ y_2 &= 0.8 \times (-1.183) + (-0.2) = -1.146 \\ y_3 &= 1.5 \times (-0.507) + 0.3 = -0.461 \\ y_4 &= 0.9 \times 1.521 + 0.0 = 1.369 \\ \\ y &= [0.303, -1.146, -0.461, 1.369] \end{aligned}

3. 反向传播

在反向传播时,损失函数 LL 的梯度会传到 γ\gammaβ\beta

Lγi=Lyiyiγi=Lyix^iLβi=Lyiyiβi=Lyi1\begin{aligned} \frac{\partial L}{\partial \gamma_i} &= \frac{\partial L}{\partial y_i} \cdot \frac{\partial y_i}{\partial \gamma_i} = \frac{\partial L}{\partial y_i} \cdot \hat{x}_i \\ \\ \frac{\partial L}{\partial \beta_i} &= \frac{\partial L}{\partial y_i} \cdot \frac{\partial y_i}{\partial \beta_i} = \frac{\partial L}{\partial y_i} \cdot 1 \end{aligned}

直观理解

  • γi\gamma_i 的梯度 = 下游梯度 × 标准化后的值 x^i\hat{x}_i
  • βi\beta_i 的梯度 = 下游梯度(直接传递)

4. 参数更新

使用优化器(如AdamW)更新参数:

γγηLγββηLβ\begin{aligned} \gamma &\leftarrow \gamma - \eta \cdot \frac{\partial L}{\partial \gamma} \\ \beta &\leftarrow \beta - \eta \cdot \frac{\partial L}{\partial \beta} \end{aligned}

其中 η\eta 是学习率。

和其他参数完全一样的训练过程!

# PyTorch中的实现
class LayerNorm(nn.Module):
    def __init__(self, d_model=768):
        super().__init__()
        # 定义可学习参数
        self.gamma = nn.Parameter(torch.ones(d_model))   # 初始化为1
        self.beta = nn.Parameter(torch.zeros(d_model))   # 初始化为0
        self.eps = 1e-6

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        mean = x.mean(dim=-1, keepdim=True)  # (batch, seq_len, 1)
        var = x.var(dim=-1, keepdim=True)    # (batch, seq_len, 1)

        # 标准化
        x_norm = (x - mean) / torch.sqrt(var + self.eps)

        # 缩放和平移(gamma和beta会自动参与梯度更新)
        output = self.gamma * x_norm + self.beta

        return output

# 查看参数
ln = LayerNorm(d_model=768)
print(f"gamma是可学习参数: {ln.gamma.requires_grad}")  # True
print(f"beta是可学习参数: {ln.beta.requires_grad}")    # True
print(f"参数数量: {ln.gamma.numel() + ln.beta.numel()}")  # 768 + 768 = 1536

# 在训练时,optimizer会自动更新这些参数
optimizer = torch.optim.AdamW(ln.parameters(), lr=1e-3)

5. 为什么需要 γ\gammaβ\beta

标准化强制把数据变成均值0方差1,但这不一定是最优的分布

问题:某些层可能需要不同的均值和方差才能更好地学习。

解决方案:通过可学习的 γ\gammaβ\beta,让模型自己决定:

  • γ\gamma:控制每个维度的"缩放"(方差)
  • β\beta:控制每个维度的"偏移"(均值)

极端情况:如果模型学到 γi=σi\gamma_i = \sigma_iβi=μi\beta_i = \mu_i(原始的方差和均值),那就等于恢复了标准化之前的分布!

yi=γix^i+βi=σixiμiσi+μi=xiy_i = \gamma_i \cdot \hat{x}_i + \beta_i = \sigma_i \cdot \frac{x_i - \mu_i}{\sigma_i} + \mu_i = x_i

这给了模型自由度:可以保留归一化的好处,也可以根据需要调整分布。

6. 参数量占比

每个LayerNorm层的参数:

参数量=dmodel×2=768×2=1,536 参数\text{参数量} = d_{\text{model}} \times 2 = 768 \times 2 = 1{,}536 \text{ 参数}

对比一个MLP层(dmodel=768d_{\text{model}}=768, dff=3072d_{\text{ff}}=3072):

MLP参数量=768×3072+3072×7684,700,000 参数\text{MLP参数量} = 768 \times 3072 + 3072 \times 768 \approx 4{,}700{,}000 \text{ 参数}

LayerNorm的参数量不到0.05%,几乎可以忽略不计!但它的作用却至关重要。

具体例子

假设 dmodel=4d_{\text{model}} = 4,某个Token的表示为:

x=[100,0.1,50,200]x = [100, 0.1, 50, 200]

步骤1:计算均值

μ=100+0.1+50+2004=87.525\mu = \frac{100 + 0.1 + 50 + 200}{4} = 87.525

步骤2:计算方差

σ2=(10087.525)2+(0.187.525)2+(5087.525)2+(20087.525)24=156.14+7644.62+1406.14+12650.144=5464.26\begin{aligned} \sigma^2 &= \frac{(100-87.525)^2 + (0.1-87.525)^2 + (50-87.525)^2 + (200-87.525)^2}{4} \\ &= \frac{156.14 + 7644.62 + 1406.14 + 12650.14}{4} \\ &= 5464.26 \end{aligned}
σ=5464.2673.92\sigma = \sqrt{5464.26} \approx 73.92

步骤3:标准化

x^1=10087.52573.920.169x^2=0.187.52573.921.183x^3=5087.52573.920.507x^4=20087.52573.921.521\begin{aligned} \hat{x}_1 &= \frac{100 - 87.525}{73.92} \approx 0.169 \\ \hat{x}_2 &= \frac{0.1 - 87.525}{73.92} \approx -1.183 \\ \hat{x}_3 &= \frac{50 - 87.525}{73.92} \approx -0.507 \\ \hat{x}_4 &= \frac{200 - 87.525}{73.92} \approx 1.521 \end{aligned}
x^=[0.169,1.183,0.507,1.521]\hat{x} = [0.169, -1.183, -0.507, 1.521]

验证:均值 0\approx 0,方差 1\approx 1

步骤4:缩放和平移(假设 γ=[1,1,1,1]\gamma=[1,1,1,1], β=[0,0,0,0]\beta=[0,0,0,0]

y=x^=[0.169,1.183,0.507,1.521]y = \hat{x} = [0.169, -1.183, -0.507, 1.521]

对比

维度原始值归一化后
11000.169
20.1-1.183
350-0.507
42001.521

所有维度现在都在相近的范围内!

LayerNorm vs BatchNorm:归一化的维度差异

在深度学习中,还有一种常见的归一化:Batch Normalization。它们的核心区别在于归一化的维度不同

直观理解:用矩阵来看

假设我们有一个batch的数据,形状为 (N,d)(N, d)

  • NN:batch大小(样本数量),比如32
  • dd:特征维度(每个样本的向量长度),比如768

数据可以表示为一个矩阵:

X=[x1,1x1,2x1,768x2,1x2,2x2,768x32,1x32,2x32,768]样本1的768维特征样本2的768维特征样本32的768维特征X = \begin{bmatrix} x_{1,1} & x_{1,2} & \cdots & x_{1,768} \\ x_{2,1} & x_{2,2} & \cdots & x_{2,768} \\ \vdots & \vdots & \ddots & \vdots \\ x_{32,1} & x_{32,2} & \cdots & x_{32,768} \end{bmatrix} \begin{array}{l} \leftarrow \text{样本1的768维特征} \\ \leftarrow \text{样本2的768维特征} \\ \\ \leftarrow \text{样本32的768维特征} \end{array}

BatchNorm(纵向归一化)

每一列(同一特征维度的所有样本)计算均值和方差:

μj=1Ni=1Nxi,j(第j维特征在所有样本上的均值)σj2=1Ni=1N(xi,jμj)2\begin{aligned} \mu_j &= \frac{1}{N} \sum_{i=1}^{N} x_{i,j} \quad \text{(第j维特征在所有样本上的均值)} \\ \sigma_j^2 &= \frac{1}{N} \sum_{i=1}^{N} (x_{i,j} - \mu_j)^2 \end{aligned}

可视化:

BatchNorm: 对每一列归一化
        维度1  维度2  维度3  ... 维度768
样本1    x     x      x    ...  x
样本2    x     x      x    ...  x
样本3    x     x      x    ...  x
...
样本32   x     x      x    ...  x
         ↓     ↓      ↓         ↓
        μ₁    μ₂     μ₃   ...  μ₇₆₈  (对列求均值)

LayerNorm(横向归一化)

每一行(单个样本的所有特征维度)计算均值和方差:

μi=1dj=1dxi,j(第i个样本的所有维度的均值)σi2=1dj=1d(xi,jμi)2\begin{aligned} \mu_i &= \frac{1}{d} \sum_{j=1}^{d} x_{i,j} \quad \text{(第i个样本的所有维度的均值)} \\ \sigma_i^2 &= \frac{1}{d} \sum_{j=1}^{d} (x_{i,j} - \mu_i)^2 \end{aligned}

可视化:

LayerNorm: 对每一行归一化
        维度1  维度2  维度3  ... 维度768
样本1    x     x      x    ...  x      → μ₁ (对行求均值)
样本2    x     x      x    ...  x      → μ₂
样本3    x     x      x    ...  x      → μ₃
...
样本32   x     x      x    ...  x      → μ₃₂

具体数值例子

假设有3个样本,每个样本4维特征:

X=[123456782468]X = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 2 & 4 & 6 & 8 \end{bmatrix}

BatchNorm计算(对每一列):

第1维:μ1=1+5+23=2.67标准化:[1,5,2][0.87,1.31,0.44]第2维:μ2=2+6+43=4.00标准化:[2,6,4][1.00,1.00,0.00]第3维:μ3=3+7+63=5.33标准化:[3,7,6][1.07,0.93,0.13]第4维:μ4=4+8+83=6.67标准化:[4,8,8][1.15,0.58,0.58]\begin{aligned} \text{第1维:} \quad &\mu_1 = \frac{1+5+2}{3} = 2.67 \\ &\text{标准化:} [1, 5, 2] \to [-0.87, 1.31, -0.44] \\ \\ \text{第2维:} \quad &\mu_2 = \frac{2+6+4}{3} = 4.00 \\ &\text{标准化:} [2, 6, 4] \to [-1.00, 1.00, 0.00] \\ \\ \text{第3维:} \quad &\mu_3 = \frac{3+7+6}{3} = 5.33 \\ &\text{标准化:} [3, 7, 6] \to [-1.07, 0.93, 0.13] \\ \\ \text{第4维:} \quad &\mu_4 = \frac{4+8+8}{3} = 6.67 \\ &\text{标准化:} [4, 8, 8] \to [-1.15, 0.58, 0.58] \end{aligned}

结果:每个特征维度在batch中被归一化

LayerNorm计算(对每一行):

样本1:μ1=1+2+3+44=2.5标准化:[1,2,3,4][1.34,0.45,0.45,1.34]样本2:μ2=5+6+7+84=6.5标准化:[5,6,7,8][1.34,0.45,0.45,1.34]样本3:μ3=2+4+6+84=5.0标准化:[2,4,6,8][1.34,0.45,0.45,1.34]\begin{aligned} \text{样本1:} \quad &\mu_1 = \frac{1+2+3+4}{4} = 2.5 \\ &\text{标准化:} [1, 2, 3, 4] \to [-1.34, -0.45, 0.45, 1.34] \\ \\ \text{样本2:} \quad &\mu_2 = \frac{5+6+7+8}{4} = 6.5 \\ &\text{标准化:} [5, 6, 7, 8] \to [-1.34, -0.45, 0.45, 1.34] \\ \\ \text{样本3:} \quad &\mu_3 = \frac{2+4+6+8}{4} = 5.0 \\ &\text{标准化:} [2, 4, 6, 8] \to [-1.34, -0.45, 0.45, 1.34] \end{aligned}

结果:每个样本内部的特征被归一化

关键区别总结

维度BatchNormLayerNorm
归一化方向纵向(跨样本)横向(跨特征)
均值/方差计算同一特征在batch中的统计同一样本的所有特征的统计
依赖关系依赖batch中的其他样本只依赖当前样本自己
batch大小影响很大(小batch效果差)无影响(每个样本独立)
训练vs推理不一致(推理用移动平均)一致(相同计算)
适用场景CV(图像、batch稳定)NLP(序列、batch不稳定)

为什么Transformer用LayerNorm?

1. 序列长度可变

NLP任务中,不同句子长度差异很大:

样本1: "你好" (2个Token)
样本2: "今天天气真好,我们一起去公园玩吧" (14个Token)

如果用BatchNorm:

  • 需要padding或truncate到相同长度
  • padding的Token会影响统计量(需要mask)
  • 实现复杂,效果不稳定

如果用LayerNorm:

  • 每个样本独立计算,长度无关
  • 无需padding的特殊处理
  • 简单高效

2. Batch统计不稳定

Transformer训练时:

  • batch大小通常较小(2-32,因为序列长)
  • 不同batch的序列长度、内容差异大
  • BatchNorm的统计量方差很大

LayerNorm避免了这个问题:每个样本自己归一化,不受batch影响。

3. 训练与推理一致

BatchNorm的推理问题

训练时:

μtrain=1Ni=1Nxi(当前batch的均值)\mu_{\text{train}} = \frac{1}{N} \sum_{i=1}^{N} x_i \quad \text{(当前batch的均值)}

推理时(batch=1):

  • 不能用单个样本的统计(方差为0!)
  • 必须使用训练时积累的移动平均
μtest=moving_average(μtrain)\mu_{\text{test}} = \text{moving\_average}(\mu_{\text{train}})

这导致训练和推理行为不一致!

LayerNorm的一致性

训练时和推理时使用相同的公式:

μ=1di=1dxi(当前样本自己的均值)\mu = \frac{1}{d} \sum_{i=1}^{d} x_i \quad \text{(当前样本自己的均值)}

完全一致,没有移动平均的复杂性。

实际应用

领域常用归一化原因
图像分类(CNN)BatchNorm固定大小、batch稳定、通道维度有意义
目标检测BatchNorm / GroupNorm固定大小,但小batch时用GroupNorm
语言模型(Transformer)LayerNorm序列长度可变、batch小
语音识别LayerNorm序列长度可变
强化学习LayerNormbatch概念弱

现代趋势

即使在CV领域,也有向LayerNorm或GroupNorm转变的趋势(如Vision Transformer),因为:

  • 更容易迁移到不同batch大小
  • 训练推理一致
  • 分布式训练更简单(不需要跨GPU同步batch统计)

Post-LN vs Pre-LN:放在哪里更好?

LayerNorm在Transformer中的位置有两种方案,效果差异很大。

Post-LN(原始Transformer)

结构:先做变换,后归一化

注意力子层:X1=LayerNorm(X+Attention(X))MLP子层:X2=LayerNorm(X1+MLP(X1))\begin{aligned} \text{注意力子层:} \quad &X_1 = \text{LayerNorm}(X + \text{Attention}(X)) \\ \text{MLP子层:} \quad &X_2 = \text{LayerNorm}(X_1 + \text{MLP}(X_1)) \end{aligned}

流程图

X → [Attention][+ (残差)][LayerNorm] → X₁
X₁ → [MLP][+ (残差)][LayerNorm] → X₂

特点

  • ✅ 原始Transformer论文的方案
  • ✅ 理论上更符合ResNet的设计
  • ❌ 训练不稳定,需要warmup
  • ❌ 深层网络(>12层)容易梯度爆炸

问题分析

残差相加后,值的范围可能很大,然后才归一化。在深层网络中,累积效应会导致:

X+f(X)X\|X + f(X)\| \gg \|X\|

梯度在反向传播时可能放大,导致训练不稳定。

Pre-LN(现代Transformer)

结构:先归一化,后做变换

注意力子层:X1=X+Attention(LayerNorm(X))MLP子层:X2=X1+MLP(LayerNorm(X1))\begin{aligned} \text{注意力子层:} \quad &X_1 = X + \text{Attention}(\text{LayerNorm}(X)) \\ \text{MLP子层:} \quad &X_2 = X_1 + \text{MLP}(\text{LayerNorm}(X_1)) \end{aligned}

流程图

X → [LayerNorm][Attention][+ (残差)] → X₁
X₁ → [LayerNorm][MLP][+ (残差)] → X₂

特点

  • ✅ 训练更稳定,不需要warmup
  • ✅ 可以训练更深的网络(100+层)
  • ✅ 梯度更平滑
  • ⚠️ 理论上可能略损失一点性能(但实践中差异很小)

为什么更稳定?

  1. 归一化在变换前

    • 每个子层的输入都经过归一化
    • 激活值范围稳定在合理区间
    • 不会因为深度增加而爆炸
  2. 残差连接更直接

    • 原始信息直接加到子层输出
    • 梯度传播路径更清晰

对比总结

特性Post-LNPre-LN
归一化位置残差相加之后子层输入之前
训练稳定性较差,需要warmup好,不需要warmup
适用深度适合浅层(<24层)适合深层(100+层)
学习率敏感度高,需仔细调整低,更鲁棒
使用模型原始Transformer, BERTGPT-2/3/4, LLaMA

现代趋势

  • GPT-2开始采用Pre-LN
  • GPT-3、GPT-4:Pre-LN
  • LLaMA系列:Pre-LN
  • 几乎所有新的大模型:Pre-LN

原因:规模越来越大(从几层到上百层),稳定性比理论上的小幅性能差异更重要。

完整的Transformer层

综合所有组件,一个完整的Transformer层(Pre-LN版本):

步骤1:LayerNorm + 多头注意力 + 残差X1=X+MultiHeadAttention(LayerNorm(X))步骤2:LayerNorm + MLP + 残差X2=X1+MLP(LayerNorm(X1))\begin{aligned} \text{步骤1:LayerNorm + 多头注意力 + 残差} \\ X_1 &= X + \text{MultiHeadAttention}(\text{LayerNorm}(X)) \\ \\ \text{步骤2:LayerNorm + MLP + 残差} \\ X_2 &= X_1 + \text{MLP}(\text{LayerNorm}(X_1)) \end{aligned}

详细展开

// 注意力子层X^=LayerNorm(X)Q,K,V=X^WQ,X^WK,X^WVAttn=softmax(QKTdk)VX1=X+Attn(残差连接)// MLP子层X^1=LayerNorm(X1)h=Activation(W1X^1+b1)MLP_out=W2h+b2X2=X1+MLP_out(残差连接)\begin{aligned} &\text{// 注意力子层} \\ &\hat{X} = \text{LayerNorm}(X) \\ &Q, K, V = \hat{X} W_Q, \hat{X} W_K, \hat{X} W_V \\ &\text{Attn} = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V \\ &X_1 = X + \text{Attn} \quad \text{(残差连接)} \\ \\ &\text{// MLP子层} \\ &\hat{X}_1 = \text{LayerNorm}(X_1) \\ &h = \text{Activation}(W_1 \hat{X}_1 + b_1) \\ &\text{MLP\_out} = W_2 h + b_2 \\ &X_2 = X_1 + \text{MLP\_out} \quad \text{(残差连接)} \end{aligned}

代码实现

import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    """
    Pre-LN版本的Transformer块
    """
    def __init__(self, d_model=768, n_heads=12, d_ff=3072):
        super().__init__()

        # 两个LayerNorm
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

        # 多头注意力
        self.attention = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            batch_first=True
        )

        # MLP(两层全连接)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        """
        Args:
            x: shape (batch, seq_len, d_model)
        Returns:
            output: shape (batch, seq_len, d_model)
        """
        # 子层1:LayerNorm → Attention → 残差
        x_norm = self.ln1(x)
        attn_out, _ = self.attention(x_norm, x_norm, x_norm)
        x = x + attn_out  # 残差连接

        # 子层2:LayerNorm → MLP → 残差
        x_norm = self.ln2(x)
        mlp_out = self.mlp(x_norm)
        x = x + mlp_out  # 残差连接

        return x

# 测试
model = TransformerBlock(d_model=768, n_heads=12, d_ff=3072)
x = torch.randn(2, 10, 768)  # (batch=2, seq_len=10, d_model=768)
output = model(x)
print(f"输入 shape: {x.shape}")
print(f"输出 shape: {output.shape}")
print(f"维度保持不变: {x.shape == output.shape}")

信息流可视化

让我们追踪一个Token通过Transformer层的完整流程:

初始输入 x: [0.5, -0.3, 0.8, ..., 0.2]  (d_model维)
           ↓
      [LayerNorm]  归一化到均值0方差1
           ↓
      [Attention]  与其他Token交互,学习上下文
           ↓
   x + Attention   残差连接,保留原始信息
           ↓
        x₁: [0.6, -0.2, 0.9, ..., 0.3]  (加入了上下文信息)
           ↓
      [LayerNorm]  再次归一化
           ↓
         [MLP]     非线性变换,学习复杂模式
           ↓
     x₁ + MLP     残差连接,保留前面的信息
           ↓
        x₂: [0.7, -0.1, 1.0, ..., 0.4]  (最终输出)

关键点

  1. 每次变换后都有残差连接,保证信息不丢失
  2. 每次变换前都有LayerNorm,保证数值稳定
  3. 最终输出融合了:原始输入 + 上下文信息 + 非线性特征

实验:残差连接的重要性

让我们通过一个简单实验看看残差连接的效果。

实验设置

训练一个10层的小型Transformer:

  • 有残差连接版本
  • 无残差连接版本
# 无残差版本(会失败)
class BadTransformerBlock(nn.Module):
    def forward(self, x):
        x_norm = self.ln1(x)
        attn_out, _ = self.attention(x_norm, x_norm, x_norm)
        x = attn_out  # 没有残差!

        x_norm = self.ln2(x)
        mlp_out = self.mlp(x_norm)
        x = mlp_out  # 没有残差!

        return x

# 有残差版本(会成功)
class GoodTransformerBlock(nn.Module):
    def forward(self, x):
        x_norm = self.ln1(x)
        attn_out, _ = self.attention(x_norm, x_norm, x_norm)
        x = x + attn_out  # 有残差!

        x_norm = self.ln2(x)
        mlp_out = self.mlp(x_norm)
        x = x + mlp_out  # 有残差!

        return x

实验结果

指标无残差有残差
训练loss收敛❌ 不收敛✅ 正常收敛
梯度范数爆炸或消失稳定
最终准确率接近随机85%+
训练稳定性发散稳定

观察到的现象(无残差版本):

  • 前几层梯度消失(接近0)
  • 后几层梯度爆炸(>1000)
  • Loss曲线剧烈震荡
  • 最终无法学习到有用的表示

结论:对于10层以上的网络,残差连接是必需的,不是可选的!

小结

  1. 深度网络的困境

    • 梯度消失:连乘导致底层梯度趋近于0
    • 梯度爆炸:连乘导致梯度指数增长
    • 信息退化:深度变换导致原始信息丢失
  2. 残差连接的作用

    • 公式:Xout=Xin+f(Xin)X_{\text{out}} = X_{\text{in}} + f(X_{\text{in}})
    • 提供梯度的"高速通道":梯度可以直接传播
    • 保留原始信息:输出始终包含输入
    • 创造多条并行路径:指数级的信息流路径
  3. LayerNorm的作用

    • 归一化每个样本的所有特征维度
    • 稳定激活值范围,防止数值问题
    • 加速训练收敛
    • 公式:LN(x)=γxμσ2+ϵ+β\text{LN}(x) = \gamma \frac{x-\mu}{\sqrt{\sigma^2+\epsilon}} + \beta
  4. Pre-LN vs Post-LN

    • Post-LN:先变换后归一化,训练不稳定
    • Pre-LN:先归一化后变换,训练稳定
    • 现代大模型(GPT-3/4、LLaMA)都用Pre-LN
    • Pre-LN让100+层的超深网络成为可能
  5. 组合效果

    • 残差连接 + LayerNorm = 稳定的深度训练
    • 没有这两项技术,就没有今天的大模型
    • GPT-3(96层)、PaLM(118层)都依赖这些技术

历史意义

  • ResNet(2015):证明残差连接的有效性
  • LayerNorm(2016):为Transformer提供稳定性
  • Pre-LN(2018-2019):让超深Transformer成为可能
  • 这些看似简单的技术,是大模型革命的基石!