Concept
A residual connection means that the input to the next layer is the sum of the previous layer's output and the previous layer's input:
Without a residual connection: input to layer 2 = output y of layer 1
With a residual connection: input to layer 2 = output y of layer 1 + input x of layer 1
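In code the difference is a single addition. A minimal sketch, assuming PyTorch and using one nn.Linear as a stand-in for an arbitrary layer F (the name F_layer is mine):

import torch
import torch.nn as nn

F_layer = nn.Linear(4, 4)       # stands in for an arbitrary layer F
x = torch.randn(2, 4)           # input to layer 1

y_without = F_layer(x)          # without residual: layer 2 receives only F(x)
y_with = F_layer(x) + x         # with residual: layer 2 receives F(x) + x
print(y_without.shape, y_with.shape)  # both torch.Size([2, 4]); shapes must match to add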
Purpose
Before looking at what residual connections do, consider an example:
Without a residual connection
Layer 1: y = F(x)
Assume: F(x) = 0.5 × x
1. Forward pass:
Input: x = 10
Output: y = 0.5 × 10 = 5
2. Backward pass:
Assume: ∂L/∂y = 1
Compute: ∂L/∂x = ∂L/∂y × ∂y/∂x
∂y/∂x = 0.5
∂L/∂x = 1 × 0.5 = 0.5
The problem: the gradient decays from 1 to 0.5
With a residual connection
Formula: y = F(x) + x
Assume: F(x) = 0.5 × x
1. Forward pass:
Input: x = 10
F(x) = 0.5 × 10 = 5
Output: y = 5 + 10 = 15
2. Backward pass:
Assume: ∂L/∂y = 1
Compute: ∂L/∂x = ∂L/∂y × ∂y/∂x
∂y/∂x = ∂F(x)/∂x + ∂x/∂x
∂y/∂x = 0.5 + 1 = 1.5
∂L/∂x = 1 × 1.5 = 1.5
The key: ∂x/∂x = 1, so the gradient passes through directly!
Comparison
Without residual connection: ∂L/∂x = 0.5
With residual connection: ∂L/∂x = 1.5
Ratio: 1.5 / 0.5 = 3.0
Conclusion
- Without a residual connection: the gradient decays by 50%
- With a residual connection: the gradient is amplified by 50%
- Reason: the "+1" from the residual connection lets the gradient pass through directly
As this example shows, when a network has many layers, residual connections counteract the vanishing-gradient problem and preserve the original information, which speeds up training convergence.
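These numbers are easy to check with automatic differentiation. A minimal sketch, assuming PyTorch, that reproduces ∂L/∂x = 0.5 without the residual and ∂L/∂x = 1.5 with it:

import torch

x = torch.tensor(10.0, requires_grad=True)

# Without residual: y = F(x) = 0.5 * x
y = 0.5 * x
y.backward()        # upstream gradient ∂L/∂y = 1
print(x.grad)       # tensor(0.5000)

x.grad = None       # reset before the second check

# With residual: y = F(x) + x
y = 0.5 * x + x
y.backward()
print(x.grad)       # tensor(1.5000)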
Implementation
"""
残差连接的实现和演示
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
"""
残差块
结构:
- 线性变换
- 激活函数
- 线性变换
- 残差连接: output = F(x) + x
"""
def __init__(self, d_model):
super().__init__()
self.linear1 = nn.Linear(d_model, d_model)
self.activation = nn.ReLU()
self.linear2 = nn.Linear(d_model, d_model)
def forward(self, x):
"""
前向传播
Args:
x: 输入张量 (batch_size, seq_len, d_model)
Returns:
output: 输出张量 (batch_size, seq_len, d_model)
"""
# 保存输入(用于残差连接)
identity = x
# 神经网络层
out = self.linear1(x)
out = self.activation(out)
out = self.linear2(out)
# 残差连接
output = out + identity
return output


def demo_residual_connection():
    """
    Walk through the residual connection computation step by step.
    """
    print("\n=== Residual Connection: Step-by-Step ===\n")
    # Create input data
    batch_size = 2
    seq_len = 3
    d_model = 4
    x = torch.randn(batch_size, seq_len, d_model)
    print(f"1. Input data:")
    print(f"   Shape: {x.shape}")  # (2, 3, 4)
    print(f"   Data:")
    print(f"   {x}\n")
    # Create the residual block
    residual_block = ResidualBlock(d_model)
    residual_block.eval()
    print(f"2. Residual block processing:")
    # First linear transform
    out1 = residual_block.linear1(x)
    print(f"   Step 1: linear transform 1")
    print(f"   Input shape: {x.shape}")      # (2, 3, 4)
    print(f"   Output shape: {out1.shape}")  # (2, 3, 4)
    print(f"   Data:")
    print(f"   {out1}\n")
    # Activation
    out2 = residual_block.activation(out1)
    print(f"   Step 2: ReLU activation")
    print(f"   Input shape: {out1.shape}")   # (2, 3, 4)
    print(f"   Output shape: {out2.shape}")  # (2, 3, 4)
    print(f"   Data:")
    print(f"   {out2}\n")
    # Second linear transform
    out3 = residual_block.linear2(out2)
    print(f"   Step 3: linear transform 2")
    print(f"   Input shape: {out2.shape}")   # (2, 3, 4)
    print(f"   Output shape: {out3.shape}")  # (2, 3, 4)
    print(f"   Data:")
    print(f"   {out3}\n")
    # Residual connection
    output = out3 + x
    print(f"   Step 4: residual connection (output = F(x) + x)")
    print(f"   F(x) shape: {out3.shape}")    # (2, 3, 4)
    print(f"   x shape: {x.shape}")          # (2, 3, 4)
    print(f"   Output shape: {output.shape}")  # (2, 3, 4)
    print(f"   Data:")
    print(f"   {output}\n")
    print(f"3. Key observations:")
    print(f"   - The input x is added directly to the output")
    print(f"   - Gradients can flow straight back through x")
    print(f"   - Even if F(x) learns something wrong, the information in x is preserved")


def demo_with_without_residual():
    """
    Compare a network with a residual connection against one without.
    """
    print("\n=== With vs. Without Residual Connection ===\n")
    # Create input
    x = torch.randn(1, 3, 4)
    print(f"1. Input data:")
    print(f"   Shape: {x.shape}")  # (1, 3, 4)
    print(f"   Data: {x[0]}\n")
    # Network without residual connection
    print(f"2. Network without residual connection:")

    class NoResidual(nn.Module):
        def __init__(self, d_model):
            super().__init__()
            self.linear1 = nn.Linear(d_model, d_model)
            self.activation = nn.ReLU()
            self.linear2 = nn.Linear(d_model, d_model)

        def forward(self, x):
            out = self.linear1(x)
            out = self.activation(out)
            out = self.linear2(out)
            return out

    no_residual = NoResidual(4)
    no_residual.eval()
    output_no_residual = no_residual(x)
    print(f"   Input: {x[0].tolist()}")
    print(f"   Output: {output_no_residual[0].tolist()}")
    print(f"   Difference: {(output_no_residual[0] - x[0]).tolist()}\n")
    # Network with residual connection
    print(f"3. Network with residual connection:")
    with_residual = ResidualBlock(4)
    with_residual.eval()
    output_with_residual = with_residual(x)
    print(f"   Input: {x[0].tolist()}")
    print(f"   Output: {output_with_residual[0].tolist()}")
    print(f"   Difference: {(output_with_residual[0] - x[0]).tolist()}\n")
    print(f"4. Comparison:")
    print(f"   Without residual: the output can differ greatly from the input")
    print(f"   With residual: the output stays closer to the input (original information preserved)")


def demo_gradient_flow():
    """
    Show how residual connections affect gradient flow.
    """
    print("\n=== Effect of Residual Connections on Gradient Flow ===\n")
    # Create input
    x = torch.randn(1, 4, requires_grad=True)
    print(f"1. Input data:")
    print(f"   Shape: {x.shape}")  # (1, 4)
    print(f"   Data: {x.data}\n")
    # Without residual connections
    print(f"2. Gradient flow without residual connections:")

    class NoResidualDeep(nn.Module):
        def __init__(self, d_model):
            super().__init__()
            self.layers = nn.ModuleList([
                nn.Linear(d_model, d_model) for _ in range(5)
            ])
            self.activation = nn.ReLU()

        def forward(self, x):
            for layer in self.layers:
                x = self.activation(layer(x))
            return x

    no_residual_deep = NoResidualDeep(4)
    no_residual_deep.eval()
    # Forward pass
    output_no_residual = no_residual_deep(x)
    loss_no_residual = output_no_residual.sum()
    # Backward pass
    loss_no_residual.backward()
    grad_no_residual = x.grad.clone()
    print(f"   Input gradient: {grad_no_residual}")
    print(f"   Gradient norm: {grad_no_residual.norm().item():.6f}\n")
    # Reset the gradient
    x.grad.zero_()
    # With residual connections
    print(f"3. Gradient flow with residual connections:")

    class ResidualDeep(nn.Module):
        def __init__(self, d_model):
            super().__init__()
            self.layers = nn.ModuleList([
                ResidualBlock(d_model) for _ in range(5)
            ])

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return x

    residual_deep = ResidualDeep(4)
    residual_deep.eval()
    # Forward pass
    output_residual = residual_deep(x)
    loss_residual = output_residual.sum()
    # Backward pass
    loss_residual.backward()
    grad_residual = x.grad.clone()
    print(f"   Input gradient: {grad_residual}")
    print(f"   Gradient norm: {grad_residual.norm().item():.6f}\n")
    print(f"4. Comparison:")
    print(f"   Without residual: gradient norm = {grad_no_residual.norm().item():.6f}")
    print(f"   With residual: gradient norm = {grad_residual.norm().item():.6f}")
    print(f"   Ratio: {grad_residual.norm().item() / grad_no_residual.norm().item():.2f}")
    print(f"   Conclusion: residual connections keep gradients stable")


if __name__ == "__main__":
    # Demonstrate the residual connection computation
    demo_residual_connection()
    # Compare with vs. without residual
    demo_with_without_residual()
    # Demonstrate gradient flow
    demo_gradient_flow()
Running the script produces output like the following (the exact numbers vary from run to run, since the inputs are random):

=== Residual Connection: Step-by-Step ===

1. Input data:
   Shape: torch.Size([2, 3, 4])
   Data:
   tensor([[[ 1.2994,  0.1954,  0.3492, -0.7819],
            [-0.6024, -2.0126, -1.0499, -0.8324],
            [ 0.6048, -0.3601,  1.0400, -1.3421]],
           [[-0.4784, -0.7506,  0.2447,  0.3918],
            [-0.1755,  0.8695, -0.9142,  0.0345],
            [-0.9269, -0.6067,  1.1459, -1.7005]]])

2. Residual block processing:
   Step 1: linear transform 1
   Input shape: torch.Size([2, 3, 4])
   Output shape: torch.Size([2, 3, 4])
   Data:
   tensor([[[ 0.7658, -0.3300,  0.8028,  1.0989],
            [ 0.0258,  0.1114,  1.2917,  0.0621],
            [ 0.6713, -0.1527,  0.6998,  0.5079]],
           [[ 0.2758, -0.1290,  0.0225, -0.2566],
            [-0.2895,  0.2140,  0.2322,  1.0172],
            [-0.0051,  0.3609,  0.3521, -0.1336]]], grad_fn=<ViewBackward0>)

   Step 2: ReLU activation
   Input shape: torch.Size([2, 3, 4])
   Output shape: torch.Size([2, 3, 4])
   Data:
   tensor([[[0.7658, 0.0000, 0.8028, 1.0989],
            [0.0258, 0.1114, 1.2917, 0.0621],
            [0.6713, 0.0000, 0.6998, 0.5079]],
           [[0.2758, 0.0000, 0.0225, 0.0000],
            [0.0000, 0.2140, 0.2322, 1.0172],
            [0.0000, 0.3609, 0.3521, 0.0000]]], grad_fn=<ReluBackward0>)

   Step 3: linear transform 2
   Input shape: torch.Size([2, 3, 4])
   Output shape: torch.Size([2, 3, 4])
   Data:
   tensor([[[-0.0765,  0.1532,  0.7805,  0.3352],
            [-0.3025, -0.2275,  0.0328,  0.3363],
            [-0.1505,  0.1198,  0.5120,  0.0365]],
           [[-0.3037,  0.2851,  0.1828, -0.3490],
            [-0.2573,  0.2496,  0.4275,  0.4289],
            [-0.3283,  0.0641, -0.0141,  0.0448]]], grad_fn=<ViewBackward0>)

   Step 4: residual connection (output = F(x) + x)
   F(x) shape: torch.Size([2, 3, 4])
   x shape: torch.Size([2, 3, 4])
   Output shape: torch.Size([2, 3, 4])
   Data:
   tensor([[[ 1.2229,  0.3487,  1.1297, -0.4468],
            [-0.9049, -2.2400, -1.0171, -0.4961],
            [ 0.4543, -0.2403,  1.5520, -1.3056]],
           [[-0.7821, -0.4655,  0.4276,  0.0427],
            [-0.4327,  1.1191, -0.4867,  0.4634],
            [-1.2552, -0.5426,  1.1318, -1.6556]]], grad_fn=<AddBackward0>)

3. Key observations:
   - The input x is added directly to the output
   - Gradients can flow straight back through x
   - Even if F(x) learns something wrong, the information in x is preserved
=== With vs. Without Residual Connection ===

1. Input data:
   Shape: torch.Size([1, 3, 4])
   Data: tensor([[ 0.0318, -0.1835,  1.0661, -1.0343],
           [-0.4208, -1.5080,  1.2142, -0.0450],
           [ 1.1789, -0.8850,  0.3372,  0.1020]])

2. Network without residual connection:
   Input: [[0.03177139163017273, -0.1834668666124344, 1.0660555362701416, -1.0343255996704102], [-0.4207763671875, -1.508028507232666, 1.2141791582107544, -0.044989679008722305], [1.1788984537124634, -0.8850030303001404, 0.33715730905532837, 0.10200556367635727]]
   Output: [[-0.5901569128036499, 0.21348334848880768, 0.2549939453601837, 0.2239282727241516], [-0.6048353314399719, 0.10298468172550201, 0.07541395723819733, 0.3999461531639099], [-0.6313197612762451, 0.0886533260345459, 0.25830820202827454, 0.4168025851249695]]
   Difference: [[-0.621928334236145, 0.39695021510124207, -0.8110616207122803, 1.258253812789917], [-0.18405896425247192, 1.6110131740570068, -1.1387652158737183, 0.4449358284473419], [-1.8102182149887085, 0.9736563563346863, -0.07884910702705383, 0.3147970139980316]]

3. Network with residual connection:
   Input: [[0.03177139163017273, -0.1834668666124344, 1.0660555362701416, -1.0343255996704102], [-0.4207763671875, -1.508028507232666, 1.2141791582107544, -0.044989679008722305], [1.1788984537124634, -0.8850030303001404, 0.33715730905532837, 0.10200556367635727]]
   Output: [[0.11138532310724258, 0.03259594738483429, 1.2016692161560059, -1.1722708940505981], [-0.11688193678855896, -1.1672778129577637, 1.546790361404419, -0.3607446849346161], [1.4412567615509033, -0.793891429901123, 0.4376524090766907, 0.06655491143465042]]
   Difference: [[0.07961393147706985, 0.21606281399726868, 0.13561367988586426, -0.137945294380188], [0.30389443039894104, 0.34075069427490234, 0.33261120319366455, -0.3157550096511841], [0.26235830783843994, 0.09111160039901733, 0.1004951000213623, -0.03545065224170685]]

4. Comparison:
   Without residual: the output can differ greatly from the input
   With residual: the output stays closer to the input (original information preserved)
=== Effect of Residual Connections on Gradient Flow ===

1. Input data:
   Shape: torch.Size([1, 4])
   Data: tensor([[ 1.9369,  0.6621, -0.3083, -0.9251]])

2. Gradient flow without residual connections:
   Input gradient: tensor([[ 0.0006, -0.0003,  0.0005, -0.0010]])
   Gradient norm: 0.001344

3. Gradient flow with residual connections:
   Input gradient: tensor([[0.9176, 1.0480, 1.3874, 1.0185]])
   Gradient norm: 2.214157

4. Comparison:
   Without residual: gradient norm = 0.001344
   With residual: gradient norm = 2.214157
   Ratio: 1647.45
   Conclusion: residual connections keep gradients stable
So when can a residual connection be skipped? Based on the above, the typical cases are:
1. A single attention layer: the network has only one attention layer;
2. Shallow networks: fewer than 5 layers;
3. LSTMs: the LSTM cell state already plays a role similar to a residual connection, so no extra residual connection is needed;
4. Simple tasks: e.g. a simple classifier.
Application
In practice, residual connections usually work better when combined with LayerNorm: LayerNorm stabilizes the gradient magnitude, while the residual connection keeps the gradient path open, so gradients flow more smoothly.
A concrete example:
Assume:
- the values of the input x lie in [-10, 10]
- the values of F(x) lie in [-100, 100]
1. No LayerNorm + no residual:
- output: y = F(x)
- value range: [-100, 100]
- gradients are unstable
2. No LayerNorm + residual:
- output: y = F(x) + x
- value range: [-110, 110]
- an even larger range, even less stable
3. LayerNorm + no residual:
- output: y = LayerNorm(F(x))
- value range: roughly [-1, 1]
- values are stable, but information may be lost
4. LayerNorm + residual:
- output: y = LayerNorm(F(x)) + x
- value range of LayerNorm(F(x)): roughly [-1, 1]
- value range of x: [-10, 10]
- value range of y: roughly [-11, 11]
- values are relatively stable and the information in x is preserved
- gradients are stable and flow smoothly
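A minimal sketch of variant 4, assuming PyTorch's nn.LayerNorm (the class name ResidualNormBlock is mine; the norm placement follows the formula above, while other placements, such as normalizing the sum F(x) + x, are also used in practice):

import torch
import torch.nn as nn


class ResidualNormBlock(nn.Module):
    """Residual connection combined with LayerNorm: y = LayerNorm(F(x)) + x."""
    def __init__(self, d_model):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_model)
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        out = self.linear2(self.activation(self.linear1(x)))  # F(x)
        return self.norm(out) + x  # normalize F(x), then add the residual


x = torch.randn(2, 3, 8)
block = ResidualNormBlock(8)
print(block(x).shape)  # torch.Size([2, 3, 8])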