【深度学习Day8】何恺明的“短路”智慧:手撕 ResNet,终结深度学习的“退化”魔咒

26 阅读11分钟

摘要:借助 BN 层解决梯度消失问题后,我们很容易产生一种思路 —— 通过增加网络层数来提升模型的特征提取能力。但如果把网络从 20 层直接扩展到 56 层,就会遇到一个反直觉的现象:56 层网络的训练集准确率,反而会降低。这可不是过拟合(过拟合是“训得好考得差”),而是深层网络独有的新问题——网络退化 (Degradation) 。今天,咱不绕弯子,直接聚焦 CV 领域的经典突破 —— 聊聊何恺明如何凭借简单的x+F(x)残差公式,让神经网络突破 “20 层天花板”,迈入 “百层时代”;还会手把手用 PyTorch 实现 ResNet-18,拆解它为何能成为 YOLO、Mask R-CNN、ViT 的主流 Backbone,甚至加入 MATLAB 老鸟能看懂的 “残差可视化” 实操,让你不仅会用 ResNet,还能把原理讲透、从容应对面试中的相关问题!

关键词:PyTorch, ResNet, 残差块, Shortcut, 瓶颈层, 网络退化, 恒等映射

1. 调参侠崩溃现场:网络越深,脑子越“傻”?

作为MATLAB老鸟,我这辈子信过两个“真理”:一是zscore归一化能解决一切数据问题,二是“参数越多拟合能力越强”——用MATLAB做多项式拟合,10次多项式总能把数据贴得死死的,哪怕过拟合也只是“训得好考得差”。但在深度学习里,层数多了,反而学不会了;

我第一反应是“梯度又没了?”——赶紧打印各层梯度,BN层加持下梯度分布贼健康;又怀疑“过拟合?”——训练误差都没降下去,谈啥过拟合?

1.1 何恺明的“神级洞察”:不是学不会,是“学啥都不做”太难了

普通网络的学习目标是:让每一层都学出“恒等映射”(H(x)=xH(x)=x)——说白了,新增的层得做到“啥也不干”,才能保证56层至少不输给20层。

但问题来了:让一堆非线性层(Conv+ReLU)去学“啥也不干”,比让一个厨子学“炒一盘没味道的菜”还难!厨子随便加勺盐就破功,网络随便调个权重就偏离H(x)=xH(x)=x,层数越多,偏离越狠,最后直接“退化”。

MATLAB老鸟类比:这就像你用MATLAB写嵌套循环,套10层还能保证逻辑没错,套50层哪怕每一层只多一个小bug,最后输出直接乱码——普通网络就是“无容错嵌套循环”,ResNet就是“加了短路的循环”,哪怕中间层写错了,短路能直接绕过去,保证底线。

2. 天才级解法:Shortcut“短路”,给网络留条“后路”

何恺明的神操作就一个核心:既然学H(x)=xH(x)=x难,那咱就学H(x)xH(x)-x

2.1 残差块的“底层逻辑”:从“硬刚”到“偷懒”

普通网络:强制学 H(x)=xH(x) = x(相当于让你默写《兰亭集序》,错一个字就重写);

ResNet:加一条Shortcut旁路,让网络学 F(x)=H(x)xF(x) = H(x) - x,最终输出 H(x)=F(x)+xH(x) = F(x) + x (相当于让你在《兰亭集序》复印件上批注,没要改的就直接交原件)。

最妙的是:如果新增的层没啥用,网络直接把F(x)F(x)的权重全设为0,输出就等于输入xx——56层网络直接“退化”成20层,准确率至少不会降!这就给深层网络留了“保底”,层数再多也不怕“越学越傻”。

2.2 用MATLAB直观验证:残差学习有多简单

咱用MATLAB模拟“普通学习”和“残差学习”的难度差异:

% MATLAB模拟:学H(x)=x vs 学F(x)=H(x)-x
rng(42); % 固定随机种子
x = linspace(0, 10, 100); % 输入数据
Hx = x; % 目标:恒等映射

% 普通学习:直接拟合H(x)=x(用5层全连接,模拟深层网络)
net_plain = feedforwardnet([10,10,10,10,10]); % 5层全连接
net_plain = train(net_plain, x, Hx);
y_plain = net_plain(x);
mse_plain = mean((y_plain - Hx).^2); % 拟合误差

% 残差学习:拟合F(x)=H(x)-x=0(目标更简单)
net_res = feedforwardnet([10,10,10,10,10]);
net_res = train(net_res, x, zeros(size(x)));
y_res = net_res(x) + x; % 残差+输入
mse_res = mean((y_res - Hx).^2); % 拟合误差

fprintf('普通学习MSE:%.6f\n', mse_plain); % 输出≈0.00021(误差大)
fprintf('残差学习MSE:%.6f\n', mse_res);   % 输出≈0.00001(误差小10倍)

结果一目了然:学“0”比学“x”简单10倍——这就是ResNet能解决退化的核心:把难学的“恒等映射”,换成易学的“残差映射”

3. PyTorch手撕ResNet-18:从残差块到完整网络(附避坑包)

ResNet的核心是“积木化”:ResNet-18/34用BasicBlock(2个3×3卷积),ResNet-50/101用Bottleneck(1×1→3×3→1×1,省参数)。咱先啃最基础的BasicBlock,这是面试“手撕代码”高频题!

3.1 先搞定核心:BasicBlock残差块(带详细注释+易错点)

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

plt.rcParams['font.sans-serif'] = ['SimHei']  # 指定默认中文字体(黑体)
plt.rcParams['axes.unicode_minus'] = False    
plt.rcParams['font.family'] = 'sans-serif'

class BasicBlock(nn.Module):
    """ResNet18/34的核心积木:2个3×3卷积+Shortcut短路"""
    expansion = 1  # 输出通道倍数(Bottleneck这里是4,后面讲)

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        
        # --- 主路(Main Path):学残差F(x) ---
        # 第一层卷积:可能下采样(stride=2),接BN+ReLU
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, 
            kernel_size=3, stride=stride, padding=1, bias=False  # BN替bias干活
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        
        # 第二层卷积:尺寸不变,只接BN(先不加ReLU!)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # --- 短路(Shortcut Path):直接传x,维度不匹配就用1×1卷积调整 ---
        self.shortcut = nn.Sequential()
        # 两种需要调整的情况(面试必答!):
        # 1. stride≠1:特征图尺寸变了,x得跟着缩;
        # 2. 通道数不匹配:x的通道得跟着换;
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        
        # 核心时刻:残差相加!先加再ReLU(易错点:别搞反顺序)
        out += self.shortcut(x) 
        out = F.relu(out)
        
        return out

# 🔍 老鸟实操:验证残差块维度是否匹配(新手必做)
def test_basic_block():
    # 模拟CIFAR-10输入:Batch=1, 3通道, 32×32
    x = torch.randn(1, 3, 32, 32)
    # 测试1:stride=1,通道3→64(需要shortcut调整通道)
    block1 = BasicBlock(3, 64, stride=1)
    out1 = block1(x)
    print(f"stride=1输出尺寸:{out1.shape}")  # 预期:[1,64,32,32]
    
    # 测试2:stride=2,通道64→128(shortcut调整通道+尺寸)
    block2 = BasicBlock(64, 128, stride=2)
    out2 = block2(out1)
    print(f"stride=2输出尺寸:{out2.shape}")  # 预期:[1,128,16,16]

test_basic_block()

3.2 易错点敲黑板(我踩过的坑,你别踩)

  1. bias=False:接BN层就关bias,否则BN的β参数和conv的bias会“打架”,纯浪费计算;
  2. ReLU的位置:必须是relu(conv2 + shortcut),而非relu(conv2) + shortcut——后者会破坏残差相加的梯度传递;
  3. shortcut的1×1卷积:别用池化代替!池化会丢特征,1×1卷积既能调尺寸,又能保特征,参数量还少(3×3卷积参数量是1×1的9倍)。

3.3 组装ResNet-18:像搭乐高一样简单

ResNet-18的结构是“[2,2,2,2]”——4个Stage,每个Stage含2个BasicBlock,加上初始层和分类头,刚好18层卷积(别杠:有些统计算全连接层,咱按CV届通用标准来)。

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 64  # 初始通道数
        
        # --- 初始层(Stem):适配CIFAR-10小图,不用VGG的7×7大卷积 ---
        # 为啥不用7×7?32×32的图用7×7卷积,stride=2直接缩到14×14,特征丢光了
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(64)
        
        # --- 4个残差Stage(核心)---
        self.layer1 = self._make_layer(block, 64,  num_blocks[0], stride=1)  # 32×32
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)  # 16×16
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)  # 8×8
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)  # 4×4
        
        # --- 分类头:自适应池化+全连接(老鸟最爱)---
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))  # 适配任意输入尺寸
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """批量创建残差块:第一个块下采样,后面全是stride=1"""
        strides = [stride] + [1]*(num_blocks-1)  # 比如[2,1]:第一个块缩尺寸
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion  # 更新输入通道
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)  # 拉平:[B,512,1,1] → [B,512]
        out = self.fc(out)
        return out

# 🔥 快速创建ResNet-18/34
def ResNet18(num_classes=10):
    return ResNet(BasicBlock, [2,2,2,2], num_classes)

def ResNet34(num_classes=10):
    return ResNet(BasicBlock, [3,4,6,3], num_classes)

# 🎯 测试完整ResNet-18
def test_resnet18():
    model = ResNet18(num_classes=10)
    x = torch.randn(2, 3, 32, 32)  # Batch=2,CIFAR-10输入
    out = model(x)
    print(f"ResNet-18输出尺寸:{out.shape}")  # 预期:[2,10]
    print(f"模型总参数量:{sum(p.numel() for p in model.parameters())/1e6:.2f}M")  # ≈11.17M

test_resnet18()

3.4 老鸟彩蛋:ResNet18效果 + 可视化残差(让你看见“短路”的作用)

我用ResNet-18跑了CIFAR-10得到非常不错的准确率:

Epoch [20/20]
训练损失: 0.3649, 测试损失: 0.3141
训练准确率: 88.48%, 测试准确率: 89.61%
最佳测试准确率: 89.61%
本轮耗时: 157.32秒

咱用可视化工具,让你亲眼看见ResNet的“偷懒”逻辑,看看残差块的F(x)到底长啥样——如果F(x)接近0,说明网络在“偷懒”,用shortcut直接传x:

model.eval()
def generate_residual_heatmap():
    # 加载ResNet-18,取第一个残差块
    # model = ResNet18()
    dataiter = iter(testloader)
    x, _ = next(dataiter)

    block = model.layer1[0]
    x = x.to(device)
    x = model.conv1(x)
    x = block(x)


    # 输入数据(layer1的输入尺寸)
    # x = torch.randn(1, 64, 32, 32)
    shortcut_x = block.shortcut(x)
    out_conv1 = F.relu(block.bn1(block.conv1(x)))
    f_x = block.bn2(block.conv2(out_conv1))  # 残差F(x)

    # 取第一个通道,归一化后可视化
    f_x_np = f_x[0, 0, :, :].detach().cpu().numpy()
    f_x_np = (f_x_np - f_x_np.min()) / (f_x_np.max() - f_x_np.min())  # 归一化到0~1

    # 绘制热力图
    fig, ax = plt.subplots(figsize=(8, 6))
    im = ax.imshow(f_x_np, cmap='jet')

    # 添加颜色条和标题
    plt.colorbar(im, label='Residual Value (F(x))')
    plt.title('ResNet-18 First Residual Block F(x) Heatmap', fontsize=12)
    plt.axis('off')

    # 保存图片(无警告)
    plt.tight_layout()
    plt.savefig('residual_heatmap.png', dpi=150)
    plt.show()
    print("✅ 残差热力图已保存:residual_heatmap.png")
    
generate_residual_heatmap()

✅ 结果解读:如果热力图大部分区域接近0,说明网络真的在“偷懒”——靠shortcut传x就能搞定,这就是ResNet解决退化的核心!

residual_heatmap.png

4. 面试怼赢面试官:ResNet的“灵魂拷问”全解答

4.1 为啥ResNet能解决“网络退化”?(核心答案)

不是梯度消失(BN已经搞定),而是残差学习降低了优化难度:普通网络要学“恒等映射”(难),ResNet只学“残差”(易),哪怕深层学不出有用特征,也能通过F(x)=0保底,保证层数越多效果不越差。

4.2 ResNet-18 vs ResNet-50:为啥50层用Bottleneck?

ResNet-50用“瓶颈层(Bottleneck)”——1×1卷积降维→3×3卷积→1×1卷积升维,核心目的是省参数

  • 256通道的3×3卷积:256×256×3×3 = 589,824个参数;
  • 瓶颈层:256×64×1×1 + 64×64×3×3 + 64×256×1×1 = 69,632个参数;

→ 参数量直接砍到1/8,这就是ResNet-50能堆到50层还不爆显存的原因!

4.3 为啥ResNet是“御用Backbone”?

  1. 下限极高:哪怕调参稀烂,ResNet也能跑出不错的结果,稳定性吊打普通CNN;
  2. 迁移性强:ImageNet预训练的ResNet权重,能直接迁移到分类、检测、分割任务;
  3. 可扩展性强:ResNeXt(分组卷积)、WideResNet(加宽通道)、Res2Net(多尺度残差)全是基于ResNet改的,换汤不换药。

4.4 进阶题:ResNet的缺点?(面试官的“陷阱题”)

  • 3×3卷积的计算密度不如1×1卷积(移动端慎用);
  • 短路连接的梯度可能“走捷径”,导致浅层特征学习不足(后续的ConvNeXt用7×7卷积+更深的残差解决)。

📌 下期预告

现在咱手里有了ResNet这个“满级号”,但每次做新任务(比如分猫狗、分工业缺陷)都从头训几千万参数,太费时间了——毕竟CIFAR-10只有5万张图,而ImageNet有1000万张图,大神早把ResNet训到“满级”了。

下一篇,咱解锁工程界最实用的技能——迁移学习 (Transfer Learning) :不用手写ResNet代码,只需要一行代码,就能把大神训好的“功力”直接“抄”过来,哪怕你只有几百张图片,也能训出媲美ImageNet的分类器!还会教你怎么“冻住”底层、微调顶层,甚至避坑“预训练权重不匹配”的问题,让你的小数据集也能“开挂”~

欢迎关注我的专栏,见证MATLAB老鸟到算法工程师的进阶之路!