在深度学习中,训练深层神经网络常常充满挑战,尤其是需要快速收敛时。而 批量规范化(Batch Normalization, 简称BN) 是一种流行且有效的技术,能够显著加速网络训练。本文将通过通俗易懂的语言,结合实例和数学推导,为你深入解析批量规范化的原理和应用。
1. 为什么需要批量规范化?
-
数据预处理的重要性
在训练神经网络时,数据通常需要标准化,使其均值为0、方差为1。批量规范化通过规范化网络中每一层的中间变量,统一数据的分布,有助于优化器更高效地工作。
-
避免变量分布偏移
深度网络的中间层变量分布可能会随着训练变化,导致网络收敛困难。批量规范化可以主动调整这些变量的分布,减少偏移带来的影响。
-
正则化效果
批量规范化在一定程度上引入了噪声,这种噪声对减轻过拟合有帮助。
2. 批量规范化的原理
2.1 标准化公式
对一个小批量输入 ,批量规范化的步骤如下:
- 计算均值:
- 计算方差:
-
规范化:
- 在规范化公式中, 是一个非常重要的常数,用于防止除以零的错误。
- 通常, 取一个非常小的值,像 或 都是常见的选择。
- 总结来说, 的作用是避免数值不稳定,确保标准化操作顺利进行。
-
拉伸与平移:
引入两个可学习参数 和 ,对标准化后的数据进行线性变换:
这样可以保证网络在标准化后依然具有足够的表达能力。
2.2 噪声的正则化作用
在小批量上计算的均值和方差会引入噪声,但这种噪声能帮助模型更好地泛化,类似于正则化的效果。
3. 批量规范化的实现
3.1 全连接层中的批量规范化
在全连接层中,批量规范化一般插入在仿射变换和激活函数之间:
- 仿射变换:简而言之,就是一种线性变换,后面加上一个平移操作。
3.2 卷积层中的批量规范化
对于卷积层,批量规范化需要对每个通道独立计算均值和方差,同时保持空间维度(高度和宽度)的一致性。公式类似,但在空间维度 上进行归一化。
3.3 训练与推理模式
在训练过程中,BN使用小批量数据的均值和方差;而在推理阶段,则使用全局均值和方差,这些值通过移动平均的方法在训练时累积。
4. 批量规范化的PyTorch实现
4.1 我们从头开始实现一个具有张量的批量规范化层
import torch
from torch import nn
def batch_norm(X: torch.Tensor, gamma, beta, moving_mean, moving_var, eps, momentum):
"""
批量归一化实现函数
参数说明:
X: 输入数据张量
- 全连接层:形状为 [batch_size, num_features]
- 卷积层:形状为 [batch_size, num_channels, height, width]
- 说明:待归一化的输入数据。
gamma: 缩放参数(可学习的)
- 全连接层:形状为 [num_features]
- 卷积层:形状为 [num_channels]
- 说明:用于调整归一化后的数据的尺度,初始化时通常为 1。
beta: 偏移参数(可学习的)
- 全连接层:形状为 [num_features]
- 卷积层:形状为 [num_channels]
- 说明:用于调整归一化后的数据的偏移量,初始化时通常为 0。
moving_mean: 全局均值(非可学习)
- 全连接层:形状为 [num_features]
- 卷积层:形状为 [num_channels]
- 说明:训练时更新为当前批次均值的移动平均值,预测时直接使用。
moving_var: 全局方差(非可学习)
- 全连接层:形状为 [num_features]
- 卷积层:形状为 [num_channels]
- 说明:训练时更新为当前批次方差的移动平均值,预测时直接使用。
eps: 防止分母为零的小常数
- 类型:浮点数(float)
- 说明:用于数值稳定性,通常设置为一个很小的值,例如 1e-5 或 1e-8。
momentum: 动量参数
- 类型:浮点数(float)
- 说明:控制全局均值和方差的更新速度,常用值为 0.9 或 0.99。
"""
# 通过is_grad_enabled来判断当前模式是训练模式还是预测模式
if not torch.is_grad_enabled():
# 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
else:
assert len(X.shape) in (2, 4)
if len(X.shape) == 2:
# 使用全连接层的情况,计算特征维上的均值和方差
mean = X.mean(dim=0) # 均值
var = ((X - mean) ** 2).mean(dim=0) # 方差
else:
# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
# 这里我们需要保持X的形状以便后面可以做广播运算
mean = X.mean(dim=(0, 2, 3), keepdim=True)
var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
# 训练模式下,用当前的均值和方差做标准化
X_hat = (X - mean) / torch.sqrt(var + eps)
# 更新移动平均的均值和方差
moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
moving_var = momentum * moving_var + (1.0 - momentum) * var
# 对归一化后的数据进行缩放和偏移
Y = gamma * X_hat + beta
return Y, moving_mean.data, moving_var.data
-
在
batch_norm函数中,len(X.shape)必须是 或 的原因是函数专为以下两种情况设计的:-
全连接层的输入 (len(X.shape) == 2)
- 输入形状:
[batch_size, num_features] - 在这种情况下,对每个特征维度(
dim=0)计算均值和方差进行归一化
- 输入形状:
-
二维卷积层的输入 (len(X.shape) == 4)
- 输入形状:
[batch_size, num_channels, height, width] - 在这种情况下,通常对每个通道维度(
dim=(0, 2, 3))计算均值和方差进行归一化 dim=(0, 2, 3)表示对所有样本(批次)和特征图的每个位置取平均,只保留通道维度的统计信息。
- 输入形状:
-
低于 2 维的数据表示不够具体,难以区分批次和特征维;超出二维卷积的应用场景,通常需要专门的归一化方法(例如 3D 卷积的批量归一化)。
-
-
在批量归一化中,
moving_mean表示全局均值,训练时通过每个小批量数据的均值更新得到。这种计算方式是 指数加权移动平均(Exponential Moving Average, EMA),其核心目的是平滑历史均值的更新,避免对单个批次数据的均值过于敏感,同时反映历史和当前批次的贡献。momentum:表示历史均值的权重(“记忆”程度)。momentum越接近 1,历史均值的权重越大,新均值的影响越小;momentum越接近 0,当前批次的均值影响越大。1−momentum:表示当前批次均值对更新的贡献权重。
4.2 现在可以创建一个正确的BatchNorm层
class BatchNorm(nn.Module):
def __init__(self, num_features, num_dims):
# num_features:完全连接层的输出数量或卷积层的输出通道数。
# num_dims:2表示完全连接层,4表示卷积层
super().__init__()
if num_dims == 2:
shape = (1, num_features)
else:
shape = (1, num_features, 1, 1)
# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0
self.gamma = nn.Parameter(torch.ones(shape))
self.beta = nn.Parameter(torch.zeros(shape))
# 非模型参数的变量初始化为0和1
self.moving_mean = torch.zeros(shape)
self.moving_var = torch.ones(shape)
def forward(self, X):
# 如果X不在内存上,将moving_mean和moving_var
# 复制到X所在显存上
if self.moving_mean.device != X.device:
self.moving_mean = self.moving_mean.to(X.device)
self.moving_var = self.moving_var.to(X.device)
# 保存更新过的moving_mean和moving_var
Y, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta,
self.moving_mean, self.moving_var,
eps=1e-5, momentum=0.9)
return Y
5. 应用实例:带BN的LeNet模型
net = nn.Sequential(
nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
nn.Linear(16 * 4 * 4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),
nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),
nn.Linear(84, 10)
)
import d2l
lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(256)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
简明实现:
net = nn.Sequential(
nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
nn.Linear(16 * 4 * 4, 120), nn.BatchNorm1d(120), nn.Sigmoid(),
nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),
nn.Linear(84, 10)
)
6. 批量规范化的优势与挑战
-
优势:
- 加速收敛:减少梯度消失和梯度爆炸问题。
- 提升稳定性:有效缓解内部协变量偏移。
- 正则化作用:在一定程度上减轻过拟合。
-
挑战:
- 计算开销:小批量计算的均值和方差会增加训练时间。
- 小批量问题:当批量大小较小时,均值和方差的估计可能不准确。
7. 批量规范化的未来发展
尽管BN已经成为深度学习的基础组件,但研究者们仍在探索更优的归一化方法。例如,层归一化(Layer Normalization)、实例归一化(Instance Normalization)和组归一化(Group Normalization)等技术在不同场景下提供了更好的性能。