Learning AI from Scratch: So That's What Residual Connections Are!


Concept

A residual connection means that, in a neural network, the input to the next layer is the sum of the previous layer's output and the previous layer's original input.

Without a residual connection: input to layer 2 = output y of layer 1

With a residual connection: input to layer 2 = output y of layer 1 + input x of layer 1 (see the sketch below)
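
In code, the difference is just one extra addition. Here is a minimal PyTorch sketch (layer1 and layer2 are stand-in linear layers for illustration, not part of the implementation shown later):

import torch
import torch.nn as nn

layer1 = nn.Linear(4, 4)
layer2 = nn.Linear(4, 4)

x = torch.randn(1, 4)          # input to layer 1
y = layer1(x)                  # output of layer 1

out_plain = layer2(y)          # without residual: layer 2 only sees y
out_residual = layer2(y + x)   # with residual: layer 2 sees y + x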

Purpose

Before explaining what residual connections do, let's look at an example:

Without a residual connection

Layer 1: y = F(x)

Assume: F(x) = 0.5 × x

1. Forward pass:

Input: x = 10

Output: y = 0.5 × 10 = 5

2. Backward pass:

Assume: ∂L/∂y = 1

Compute: ∂L/∂x = ∂L/∂y × ∂y/∂x

∂y/∂x = 0.5

∂L/∂x = 1 × 0.5 = 0.5

Problem: the gradient decays from 1 to 0.5.

With a residual connection

Formula: y = F(x) + x

Assume: F(x) = 0.5 × x

1. Forward pass:

Input: x = 10

F(x) = 0.5 × 10 = 5

Output: y = 5 + 10 = 15

2. Backward pass:

Assume: ∂L/∂y = 1

Compute: ∂L/∂x = ∂L/∂y × ∂y/∂x

∂y/∂x = ∂F(x)/∂x + ∂x/∂x

∂y/∂x = 0.5 + 1 = 1.5

∂L/∂x = 1 × 1.5 = 1.5

Key point: ∂x/∂x = 1, so the gradient passes straight through the identity path!

Comparison

Without residual connection: ∂L/∂x = 0.5

With residual connection: ∂L/∂x = 1.5

Ratio: 1.5 / 0.5 = 3.0
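
These hand calculations can be checked directly with PyTorch autograd. A minimal sketch of the same toy example (the 0.5 factor is the assumed F(x) from above, not a trained layer):

import torch

x = torch.tensor(10.0, requires_grad=True)

# Without a residual connection: y = F(x) = 0.5 * x
y = 0.5 * x
y.backward()            # implicitly uses ∂L/∂y = 1
print(x.grad)           # tensor(0.5000)

x.grad = None           # reset the gradient before the second run

# With a residual connection: y = F(x) + x
y = 0.5 * x + x
y.backward()
print(x.grad)           # tensor(1.5000)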

Conclusion

  • Without a residual connection: the gradient is attenuated by 50%

  • With a residual connection: the gradient is amplified by 50%

  • Reason: the "+1" contributed by the residual connection lets the gradient pass through directly

The example above shows why this matters once a network has many layers: without residual connections, each layer's attenuation compounds multiplicatively, so the gradient reaching the early layers vanishes. Residual connections keep a direct path for the gradient and preserve the original input information, which speeds up training convergence.
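
A quick back-of-the-envelope sketch of how this compounds with depth (a toy calculation; it assumes every one of 10 layers has the same local derivative ∂F/∂x = 0.5 as in the example above, which real networks do not):

depth = 10

# Without residual connections the local derivatives multiply:
grad_plain = 0.5 ** depth
print(grad_plain)       # 0.0009765625 -- essentially vanished

# With residual connections each factor becomes (∂F/∂x + 1):
grad_residual = (0.5 + 1.0) ** depth
print(grad_residual)    # 57.6650390625 -- the "+1" keeps the product from collapsing toward zero

In real networks the per-layer derivatives vary, but the pattern holds: the identity path contributes a "+1" at every layer, so the gradient always has a route back that is not repeatedly scaled down. (The large number on the residual side also hints at why residual connections are usually paired with normalization, discussed in the Application section below.)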

Implementation

"""
Implementation and demonstration of residual connections
"""
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """
    Residual block

    Structure:
    - linear transform
    - activation function
    - linear transform
    - residual connection: output = F(x) + x
    """
    def __init__(self, d_model):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_model)
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(d_model, d_model)

    def forward(self, x):
        """
        Forward pass

        Args:
            x: input tensor (batch_size, seq_len, d_model)

        Returns:
            output: output tensor (batch_size, seq_len, d_model)
        """
        # Save the input (for the residual connection)
        identity = x

        # Neural network layers, i.e. F(x)
        out = self.linear1(x)
        out = self.activation(out)
        out = self.linear2(out)

        # Residual connection
        output = out + identity

        return output

def demo_residual_connection():
    """
    Step-by-step demonstration of the residual connection computation
    """
    print("\n=== Residual Connection: Step-by-Step Demo ===\n")

    # Create input data
    batch_size = 2
    seq_len = 3
    d_model = 4

    x = torch.randn(batch_size, seq_len, d_model)
    print(f"1. Input data:")
    print(f"   Shape: {x.shape}")  # (2, 3, 4)
    print(f"   Values:")
    print(f"   {x}\n")

    # Create the residual block
    residual_block = ResidualBlock(d_model)
    residual_block.eval()

    print(f"2. Residual block processing:")

    # First linear transform
    out1 = residual_block.linear1(x)
    print(f"   Step 1: linear transform 1")
    print(f"   Input shape: {x.shape}")  # (2, 3, 4)
    print(f"   Output shape: {out1.shape}")  # (2, 3, 4)
    print(f"   Values:")
    print(f"   {out1}\n")

    # Activation function
    out2 = residual_block.activation(out1)
    print(f"   Step 2: ReLU activation")
    print(f"   Input shape: {out1.shape}")  # (2, 3, 4)
    print(f"   Output shape: {out2.shape}")  # (2, 3, 4)
    print(f"   Values:")
    print(f"   {out2}\n")

    # Second linear transform
    out3 = residual_block.linear2(out2)
    print(f"   Step 3: linear transform 2")
    print(f"   Input shape: {out2.shape}")  # (2, 3, 4)
    print(f"   Output shape: {out3.shape}")  # (2, 3, 4)
    print(f"   Values:")
    print(f"   {out3}\n")

    # Residual connection
    output = out3 + x
    print(f"   Step 4: residual connection (output = F(x) + x)")
    print(f"   F(x) shape: {out3.shape}")  # (2, 3, 4)
    print(f"   x shape: {x.shape}")  # (2, 3, 4)
    print(f"   Output shape: {output.shape}")  # (2, 3, 4)
    print(f"   Values:")
    print(f"   {output}\n")

    print(f"3. Key observations:")
    print(f"   - The input x is added directly to the output")
    print(f"   - Gradients can flow back directly through x")
    print(f"   - Even if F(x) learns something wrong, the information in x is preserved")

def demo_with_without_residual():
    """
    Compare a network with and without a residual connection
    """
    print("\n=== With vs. Without Residual Connection ===\n")

    # Create input
    x = torch.randn(1, 3, 4)
    print(f"1. Input data:")
    print(f"   Shape: {x.shape}")  # (1, 3, 4)
    print(f"   Values: {x[0]}\n")

    # Network without a residual connection
    print(f"2. Network without residual connection:")
    class NoResidual(nn.Module):
        def __init__(self, d_model):
            super().__init__()
            self.linear1 = nn.Linear(d_model, d_model)
            self.activation = nn.ReLU()
            self.linear2 = nn.Linear(d_model, d_model)

        def forward(self, x):
            out = self.linear1(x)
            out = self.activation(out)
            out = self.linear2(out)
            return out

    no_residual = NoResidual(4)
    no_residual.eval()
    output_no_residual = no_residual(x)

    print(f"   Input: {x[0].tolist()}")
    print(f"   Output: {output_no_residual[0].tolist()}")
    print(f"   Difference: {(output_no_residual[0] - x[0]).tolist()}\n")

    # Network with a residual connection
    print(f"3. Network with residual connection:")
    with_residual = ResidualBlock(4)
    with_residual.eval()
    output_with_residual = with_residual(x)

    print(f"   Input: {x[0].tolist()}")
    print(f"   Output: {output_with_residual[0].tolist()}")
    print(f"   Difference: {(output_with_residual[0] - x[0]).tolist()}\n")

    print(f"4. Comparison:")
    print(f"   Without residual: the output can differ a lot from the input")
    print(f"   With residual: the output stays closer to the input (original information preserved)")

def demo_gradient_flow():
    """
    Demonstrate the effect of residual connections on gradient flow
    """
    print("\n=== Effect of Residual Connections on Gradient Flow ===\n")

    # Create input
    x = torch.randn(1, 4, requires_grad=True)
    print(f"1. Input data:")
    print(f"   Shape: {x.shape}")  # (1, 4)
    print(f"   Values: {x.data}\n")

    # Without residual connections
    print(f"2. Gradient flow without residual connections:")
    class NoResidualDeep(nn.Module):
        def __init__(self, d_model):
            super().__init__()
            self.layers = nn.ModuleList([
                nn.Linear(d_model, d_model) for _ in range(5)
            ])
            self.activation = nn.ReLU()

        def forward(self, x):
            for layer in self.layers:
                x = self.activation(layer(x))
            return x

    no_residual_deep = NoResidualDeep(4)
    no_residual_deep.eval()

    # Forward pass
    output_no_residual = no_residual_deep(x)
    loss_no_residual = output_no_residual.sum()

    # Backward pass
    loss_no_residual.backward()
    grad_no_residual = x.grad.clone()
    print(f"   Input gradient: {grad_no_residual}")
    print(f"   Gradient norm: {grad_no_residual.norm().item():.6f}\n")

    # Reset the gradient
    x.grad.zero_()

    # With residual connections
    print(f"3. Gradient flow with residual connections:")
    class ResidualDeep(nn.Module):
        def __init__(self, d_model):
            super().__init__()
            self.layers = nn.ModuleList([
                ResidualBlock(d_model) for _ in range(5)
            ])

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return x

    residual_deep = ResidualDeep(4)
    residual_deep.eval()

    # Forward pass
    output_residual = residual_deep(x)
    loss_residual = output_residual.sum()

    # Backward pass
    loss_residual.backward()
    grad_residual = x.grad.clone()
    print(f"   Input gradient: {grad_residual}")
    print(f"   Gradient norm: {grad_residual.norm().item():.6f}\n")

    print(f"4. Comparison:")
    print(f"   Without residual: gradient norm = {grad_no_residual.norm().item():.6f}")
    print(f"   With residual: gradient norm = {grad_residual.norm().item():.6f}")
    print(f"   Ratio: {grad_residual.norm().item() / grad_no_residual.norm().item():.2f}")
    print(f"   Conclusion: residual connections keep gradients stable")

if __name__ == "__main__":
    # Demonstrate the residual connection computation step by step
    demo_residual_connection()

    # Compare with vs. without a residual connection
    demo_with_without_residual()

    # Demonstrate gradient flow
    demo_gradient_flow()

Sample output:

=== Residual Connection: Step-by-Step Demo ===

1. Input data:
   Shape: torch.Size([2, 3, 4])
   Values:
   tensor([[[ 1.2994,  0.1954,  0.3492, -0.7819],
         [-0.6024, -2.0126, -1.0499, -0.8324],
         [ 0.6048, -0.3601,  1.0400, -1.3421]],

        [[-0.4784, -0.7506,  0.2447,  0.3918],
         [-0.1755,  0.8695, -0.9142,  0.0345],
         [-0.9269, -0.6067,  1.1459, -1.7005]]])

2. Residual block processing:
   Step 1: linear transform 1
   Input shape: torch.Size([2, 3, 4])
   Output shape: torch.Size([2, 3, 4])
   Values:
   tensor([[[ 0.7658, -0.3300,  0.8028,  1.0989],
         [ 0.0258,  0.1114,  1.2917,  0.0621],
         [ 0.6713, -0.1527,  0.6998,  0.5079]],

        [[ 0.2758, -0.1290,  0.0225, -0.2566],
         [-0.2895,  0.2140,  0.2322,  1.0172],
         [-0.0051,  0.3609,  0.3521, -0.1336]]], grad_fn=<ViewBackward0>)

   Step 2: ReLU activation
   Input shape: torch.Size([2, 3, 4])
   Output shape: torch.Size([2, 3, 4])
   Values:
   tensor([[[0.7658, 0.0000, 0.8028, 1.0989],
         [0.0258, 0.1114, 1.2917, 0.0621],
         [0.6713, 0.0000, 0.6998, 0.5079]],

        [[0.2758, 0.0000, 0.0225, 0.0000],
         [0.0000, 0.2140, 0.2322, 1.0172],
         [0.0000, 0.3609, 0.3521, 0.0000]]], grad_fn=<ReluBackward0>)

   Step 3: linear transform 2
   Input shape: torch.Size([2, 3, 4])
   Output shape: torch.Size([2, 3, 4])
   Values:
   tensor([[[-0.0765,  0.1532,  0.7805,  0.3352],
         [-0.3025, -0.2275,  0.0328,  0.3363],
         [-0.1505,  0.1198,  0.5120,  0.0365]],

        [[-0.3037,  0.2851,  0.1828, -0.3490],
         [-0.2573,  0.2496,  0.4275,  0.4289],
         [-0.3283,  0.0641, -0.0141,  0.0448]]], grad_fn=<ViewBackward0>)

   Step 4: residual connection (output = F(x) + x)
   F(x) shape: torch.Size([2, 3, 4])
   x shape: torch.Size([2, 3, 4])
   Output shape: torch.Size([2, 3, 4])
   Values:
   tensor([[[ 1.2229,  0.3487,  1.1297, -0.4468],
         [-0.9049, -2.2400, -1.0171, -0.4961],
         [ 0.4543, -0.2403,  1.5520, -1.3056]],

        [[-0.7821, -0.4655,  0.4276,  0.0427],
         [-0.4327,  1.1191, -0.4867,  0.4634],
         [-1.2552, -0.5426,  1.1318, -1.6556]]], grad_fn=<AddBackward0>)

3. Key observations:
   - The input x is added directly to the output
   - Gradients can flow back directly through x
   - Even if F(x) learns something wrong, the information in x is preserved

=== With vs. Without Residual Connection ===

1. Input data:
   Shape: torch.Size([1, 3, 4])
   Values: tensor([[ 0.0318, -0.1835,  1.0661, -1.0343],
        [-0.4208, -1.5080,  1.2142, -0.0450],
        [ 1.1789, -0.8850,  0.3372,  0.1020]])

2. Network without residual connection:
   Input: [[0.03177139163017273, -0.1834668666124344, 1.0660555362701416, -1.0343255996704102], [-0.4207763671875, -1.508028507232666, 1.2141791582107544, -0.044989679008722305], [1.1788984537124634, -0.8850030303001404, 0.33715730905532837, 0.10200556367635727]]
   Output: [[-0.5901569128036499, 0.21348334848880768, 0.2549939453601837, 0.2239282727241516], [-0.6048353314399719, 0.10298468172550201, 0.07541395723819733, 0.3999461531639099], [-0.6313197612762451, 0.0886533260345459, 0.25830820202827454, 0.4168025851249695]]
   Difference: [[-0.621928334236145, 0.39695021510124207, -0.8110616207122803, 1.258253812789917], [-0.18405896425247192, 1.6110131740570068, -1.1387652158737183, 0.4449358284473419], [-1.8102182149887085, 0.9736563563346863, -0.07884910702705383, 0.3147970139980316]]

3. Network with residual connection:
   Input: [[0.03177139163017273, -0.1834668666124344, 1.0660555362701416, -1.0343255996704102], [-0.4207763671875, -1.508028507232666, 1.2141791582107544, -0.044989679008722305], [1.1788984537124634, -0.8850030303001404, 0.33715730905532837, 0.10200556367635727]]
   Output: [[0.11138532310724258, 0.03259594738483429, 1.2016692161560059, -1.1722708940505981], [-0.11688193678855896, -1.1672778129577637, 1.546790361404419, -0.3607446849346161], [1.4412567615509033, -0.793891429901123, 0.4376524090766907, 0.06655491143465042]]
   Difference: [[0.07961393147706985, 0.21606281399726868, 0.13561367988586426, -0.137945294380188], [0.30389443039894104, 0.34075069427490234, 0.33261120319366455, -0.3157550096511841], [0.26235830783843994, 0.09111160039901733, 0.1004951000213623, -0.03545065224170685]]

4. Comparison:
   Without residual: the output can differ a lot from the input
   With residual: the output stays closer to the input (original information preserved)

=== Effect of Residual Connections on Gradient Flow ===

1. Input data:
   Shape: torch.Size([1, 4])
   Values: tensor([[ 1.9369,  0.6621, -0.3083, -0.9251]])

2. Gradient flow without residual connections:
   Input gradient: tensor([[ 0.0006, -0.0003,  0.0005, -0.0010]])
   Gradient norm: 0.001344

3. Gradient flow with residual connections:
   Input gradient: tensor([[0.9176, 1.0480, 1.3874, 1.0185]])
   Gradient norm: 2.214157

4. Comparison:
   Without residual: gradient norm = 0.001344
   With residual: gradient norm = 2.214157
   Ratio: 1647.45
   Conclusion: residual connections keep gradients stable

So when can you get by without a residual connection? Situations where one is usually unnecessary:

1. A single attention layer: the model has only one attention layer;

2. Shallow networks: fewer than about 5 layers;

3. LSTMs: the cell state already plays a role similar to a residual connection, so no extra one is needed (see the note after this list);

4. Simple tasks: e.g., a simple classifier.
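
A note on point 3: the LSTM cell-state update is c_t = f_t × c_{t-1} + i_t × g_t (element-wise, with forget gate f_t, input gate i_t, and candidate state g_t). Because the previous cell state enters additively, it has a nearly direct path to the next time step; when the forget gate stays close to 1, gradients flow back through c_{t-1} almost unchanged, which is why the cell state behaves like a built-in residual connection.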

Application

In practice, residual connections usually work best together with LayerNorm: LayerNorm stabilizes the magnitude of the values (and hence the gradients), while the residual connection keeps the gradient path open, so gradients flow more smoothly.

A concrete example:

Assume:

- The input x takes values in [-10, 10]

- F(x) takes values in [-100, 100]

1. No LayerNorm + no residual:

- Output: y = F(x)

- Value range: [-100, 100]

- Gradients are unstable

2. No LayerNorm + with residual:

- Output: y = F(x) + x

- Value range: [-110, 110]

- The range is even larger, so even less stable

3. With LayerNorm + no residual:

- Output: y = LayerNorm(F(x))

- Value range: roughly [-1, 1]

- Values are stable, but some of the original information may be lost

4. With LayerNorm + with residual:

- Output: y = LayerNorm(F(x)) + x

- LayerNorm(F(x)) lies roughly in [-1, 1]

- x lies in [-10, 10]

- So y lies roughly in [-11, 11]

- Values are relatively stable, and the information in x is preserved

- Gradients are stable and flow smoothly (see the sketch below)
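
To make option 4 concrete, here is a minimal PyTorch sketch following the y = LayerNorm(F(x)) + x arrangement used in this example (the class name ResidualNormBlock is just for illustration; real Transformer implementations place the LayerNorm in different spots, e.g. LayerNorm(x + F(x)) in post-norm or x + F(LayerNorm(x)) in pre-norm, but the residual idea is the same):

import torch
import torch.nn as nn

class ResidualNormBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_model)
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # F(x): the same two-layer sub-network as ResidualBlock above
        out = self.linear2(self.activation(self.linear1(x)))
        # Normalize F(x) to a stable scale, then add the original input back
        return self.norm(out) + x

block = ResidualNormBlock(4)
y = block(torch.randn(2, 3, 4))
print(y.shape)  # torch.Size([2, 3, 4])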