一文讲清 nn.BatchNorm ​批归一化

40 阅读5分钟

我们用一个非常贴近生活的“班级成绩”例子,来向初学者讲清楚 PyTorch 中 torch.nn.BatchNorm(批归一化)的原理、应用场景和使用方法。


🎯 一、什么是 BatchNorm?一句话总结:

BatchNorm 就是在神经网络训练过程中,对每一批(batch)数据的每个特征维度,做“标准化”(减均值、除标准差),然后再加上可学习的缩放和偏移参数,让数据分布更稳定,训练更快更稳。


📚 二、用“班级成绩”打比方 —— 形象理解 BatchNorm

假设你是一个班主任,教数学、语文、英语三门课。每次考试后,你收到一个“批次”(batch)的成绩单,比如有 32 个学生(batch_size=32),每人三门课成绩。

你发现:

  • 数学平均分 90,标准差 5 → 成绩很集中
  • 语文平均分 60,标准差 20 → 成绩很分散
  • 英语平均分 75,标准差 10

👉 问题来了:三门课分数尺度和分布不同,直接输入给“校长AI模型”做学生综合评估,会导致模型训练困难(梯度爆炸/消失、收敛慢)。

✅ 解决方案:对每门课的成绩,做“班级内标准化”

也就是:

对数学成绩:每人减去本班数学平均分,再除以数学标准差
对语文成绩:每人减去本班语文平均分,再除以语文标准差
英语同理...

这样,每门课成绩都变成:均值≈0,标准差≈1 —— 这就是“批归一化”的核心!

但校长说:“标准化后成绩都集中在0附近,不好看,而且可能丢失原始信息。”

于是你加了两个“可学习的参数”:

  • γ(gamma)→ 缩放因子(比如想让数学重要,就放大)
  • β(beta)→ 偏移量(比如想整体抬高语文分数)

最终成绩 = γ × 标准化成绩 + β

🧠 这两个参数 γ 和 β 是网络自己学出来的!目的是在“稳定分布”的基础上,保留表达能力。


⚙️ 三、BatchNorm 的数学公式(配合例子)

对一个 batch 中某个特征维度(比如“数学成绩”):

  1. 计算 batch 均值:
    μ = mean(x) → 本班数学平均分

  2. 计算 batch 方差:
    σ² = var(x) → 本班数学方差

  3. 标准化:
    x̂ = (x - μ) / √(σ² + ε) → 标准化数学成绩(ε 是小常数防除零)

  4. 缩放和偏移:
    y = γ * x̂ + β → 最终输出成绩(γ, β 是可训练参数)


🌍 四、为什么用 BatchNorm?—— 应用场景

  1. 加速训练:让每一层输入分布稳定(解决 Internal Covariate Shift)
  2. 提升模型稳定性:减少对初始化和学习率的敏感
  3. 有一定正则化效果:因为每个 batch 统计量略有不同,相当于加了噪声
  4. 常用于 CNN、全连接网络,尤其是深层网络中

📌 举个实际场景:

你训练一个识别手写数字的 CNN,用了 10 层卷积。不用 BatchNorm,训练慢、容易崩溃;用了之后,收敛快、准确率高。


💻 五、PyTorch 中怎么用?—— 代码示例(配合班级成绩)

我们用一个“模拟成绩处理网络”来演示:

import torch
import torch.nn as nn

# 假设:batch_size=4个学生,3门课成绩(特征维度=3)
# 数据形状:[batch_size, num_features] = [4, 3]
scores = torch.tensor([
    [85, 60, 70],  # 学生1:数学、语文、英语
    [90, 80, 75],
    [95, 50, 80],
    [88, 70, 85]
], dtype=torch.float32)

print("原始成绩:")
print(scores)

# 创建 BatchNorm1d,特征数=3(三门课)
bn = nn.BatchNorm1d(num_features=3, affine=True)  # affine=True 表示启用 γ 和 β

# 训练模式下,BatchNorm 会用当前 batch 的均值和方差,并更新统计量
bn.train()
normalized_scores = bn(scores)

print("\n归一化后的成绩:")
print(normalized_scores)

# 查看学出来的 γ 和 β(初始值 γ=1, β=0,训练后会变化)
print("\nγ (缩放):", bn.weight.data)
print("β (偏移):", bn.bias.data)

# 查看当前统计的移动平均(训练多个 batch 后才有意义)
print("\n累计均值估计:", bn.running_mean)
print("累计方差估计:", bn.running_var)

📌 输出示例(数值会因初始化略有不同):

原始成绩:
tensor([[85., 60., 70.],
        [90., 80., 75.],
        [95., 50., 80.],
        [88., 70., 85.]])

归一化后的成绩:
tensor([[-1.1839, -0.2673, -1.3363],
        [ 0.0000,  1.3363, -0.2673],
        [ 1.1839, -1.3363,  0.8018],
        [ 0.0000,  0.2673,  0.8018]], grad_fn=<AddBackward0>)

γ (缩放): tensor([1., 1., 1.], requires_grad=True)
β (偏移): tensor([0., 0., 0.], requires_grad=True)

累计均值估计: tensor([89.5000, 65.0000, 77.5000])
累计方差估计: tensor([18.7500, 125.0000, 37.5000])

🔍 说明:

  • 每一列(每门课)都被标准化成均值≈0,标准差≈1
  • γ 和 β 初始是 1 和 0,训练后会被优化器更新
  • running_meanrunning_var 是训练过程中累积的“总体估计”,用于推理(eval模式)

🔄 六、训练 vs 推理模式

  • 训练时(train()):用当前 batch 的均值/方差,同时更新 running_mean/var
  • 推理时(eval()):用累计的 running_mean/var,保证确定性
bn.eval()  # 切换到推理模式
test_scores = torch.tensor([[87, 65, 78]], dtype=torch.float32)
output = bn(test_scores)
print("推理模式输出:", output)

✅ 七、初学者常见问题

  1. BatchNorm 放在哪?
    → 通常放在卷积层或全连接层 之后,激活函数 之前(也有放之后的,看情况)

    self.conv = nn.Conv2d(...)
    self.bn = nn.BatchNorm2d(...)
    self.relu = nn.ReLU()
    # 常用顺序: conv → bn → relu
    
  2. BatchNorm 需要多少数据?
    → batch_size 太小(如=1,2)效果差,因为统计不准。建议 ≥ 8~16

  3. BatchNorm 会改变模型表达能力吗?
    → 不会!因为有 γ 和 β,理论上可以还原原始分布,只是让训练更稳定。


🎉 总结口诀(方便记忆):

“批批标准化,减均除方差,缩放加偏移,训练快如马!”


📚 举一反三:

  • BatchNorm1d:处理 [batch, features] 或 [batch, features, seq_len](如NLP)
  • BatchNorm2d:处理图像 [batch, channels, H, W]
  • BatchNorm3d:处理视频或3D数据