卷积神经网络及其经典模型(8) ResNet

87 阅读7分钟

1. ResNet网络介绍

我们知道要提升网络性能,除了更好的硬件和更大的数据集以外,最主要的办法就是增加网络的深度和宽度,而增加网络的深度和宽度带来最直接的问题就是网络参数剧增,使得模型容易过拟合以及难以训练。在VGG中,通过使用 3 * 3的卷积核替代大卷积核实现将深度提升到了19层;在GoogLeNet中,通过引入Inception结构,实现将深度提升到22层,在BN-Inception中达到30层,还有很多例子都表明,越深的网络性能越好。

但是直接简单粗暴地堆叠深度,也导致了一个臭名昭著的问题:梯度爆炸/消失,使得网络无法收敛,不过通过适当的权值初始化和Batch Normalization 可以加快网络收敛。但是,当网络收敛后,又暴露出了一个问题,就是网络退化。当网络深度变深后,准确率开始达到饱和,然后迅速退化,并且这种现象不是由梯度消失和过拟合造成的。在论文中给出了如下的一张图,分别表示20层模型和56层模型在测试集和训练集的误差。从图中我们也可以看出,迭代到后面时,误差也是在波动的,说明梯度并没有消失,参数还在更新,并且无论在训练集测试集上,56层的错误率都远高于20层模型的,这也说明不是由过拟合导致的。(过拟合的现象是,在训练集上的误差小,但是在测试集上误差很大)

image-20220411140631434

要解决网络退化问题,作者首先提出了一个解决办法,就是先构造一个浅模型,然后将得到浅模型的输出和原输入汇总(记为A)作为下一个浅模型的输入,下一个浅模型将输入A计算得到输出与输入A本身汇总成为下一个浅模型的输入,依次下去,如下图所示。这样做的好处是,将输入引入进来,可以使得模型训练后的效果至少不会比初始的输出差,最多也就是和输入相同,解决了网络退化的问题。进而作者提出了残差模型,具体关于残差 模块为什么可以解决网络退化问题,在下一节会进行介绍。

image-20220411152400376

值得一提的是,ResNet来自的论文 Deep Residual Learning for Image Recognition 作者全部来自中国,其中第一作者何恺明是03年广东高考状元,单单这篇ResNet论文的Google引用达到了10万多,著名的参数初始化方法Kaiming初始化方法就是由他提出的,在PyTorch中也进行了实现(torch.nn.init.kaiming_uniform等),妥妥的超级大牛。

2. 残差模块

残差模块如下图所示,其中左边用于小型的ResNet网络,右边的用于大型的ResNet网络,之所以要增加1 * 1的卷积核,是为了通过升维和降维减少参数量和计算量。

image-20220411203614071

残差模块解决网络退化的机理:

  • 深层梯度回传顺畅: 恒等映射这一路的梯度为1,把深层梯度注入底层,防止梯度消失。

  • 传统线性结构网络难以拟合恒等映射:  什么都不做有时很重要,无论什么样的网络模型都很难做到输入和输出相同,而残差模块可以让模型自行选择是否要更新,同时弥补了高度非线性造成的不可逆的信息损失。

  • ResNet反向传播传回的梯度相关性好:  随着网络的加深,相邻像素反向传播来的梯度相关性就越来越低,最后基本无关,变成随机扰动。所谓的相关性,指的是,在图片中,一个像素周围的像素肯定是有一定联系的,比如耳朵上的像素周围也大概率是耳朵。残差模块的引入使得梯度相关性的衰减大幅减少,保持了梯度的相关性。

  • ResNet相当于几个浅层网络的集成:  如下图左所示,三个串联的残差模块可以看成多个浅层神经网络的组合,也就是说,对于 n 个残差模块,有  个潜在的路径(这和Dropout的原理很像)。并且在测试阶段,去掉某几个残差块,几乎不影响性能,如下图右所示,去掉  不会对模型造成很严重的影响,只是和  相关的路无法通过了而已。

    image-20220411195429802

 

3. 模型结构

image-20220411203812097

图片来自论文,介绍了不同深度的ResNet结构。下 图是34层的ResNet结构,对照表格分析可以得知,输入图片先经过一个 7 * 7的卷积层,然后是一个 3 * 3的最大池化层,接着在表格中分成了4个卷积结构,第一个卷积结构中是三个残差结构,该残差结构由两个 3 * 3 的卷积层组成;第二个卷积结构由 4 个残差结构组成,该残差结构由两个 3 * 3 的卷积核组成,依次类推,最后经过一个平均池化层,一个全连接层,一个softmax层输出结果。

4. 代码实现

论文中一共介绍了两种残差模块,区别如下所示,一种是基础残差模块,用 BasicBlock 类实现,另外一个是Bottleneck模块,用 Bottleneck 类实现,在表格中我们可以发现,对于不同的深度,虽然采用的残差模块结构都是以下二者之一。但是残差模块的输入输出维度不一定是相同的,具体来说,有些残差模块输入和输出维度相同,而有些输出时输入的两倍,这是为了数据的升维。所以,在实现残差模块时还需要一个标记,表示该模块是否进行升维。残差模块实现后,ResNet模型就是残差结构的堆叠。

image-20220411203614071

import torch
 
 
class BasicBlock(nn.Module): # 普通残差结构
    expansion = 1
 
    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample
 
    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
 
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
 
        out = self.conv2(out)
        out = self.bn2(out)
 
        out += identity
        out = self.relu(out)
 
        return out
 
 
class Bottleneck(nn.Module):
    """
    注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。
    但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,
    这么做的好处是能够在top1上提升大概0.5%的准确率。
    可参考Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
    """
    expansion = 4
 
    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()
 
        width = int(out_channel * (width_per_group / 64.)) * groups
 
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
 
    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
 
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
 
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
 
        out = self.conv3(out)
        out = self.bn3(out)
 
        out += identity
        out = self.relu(out)
 
        return out
 
 
class ResNet(nn.Module):
 
    def __init__(self,
                 block,  # 残差结构的类别
                 blocks_num:list, # 列表,每层残差结构的个数
                 num_classes=1000,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64
 
        self.groups = groups
        self.width_per_group = width_per_group
 
        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)
 
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
 
    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))
 
        layers = []
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion
 
        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))
 
        return nn.Sequential(*layers)
 
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
 
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
 
        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)
 
        return x
 
 
def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)
 
 
def resnet50(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet50-19c8e357.pth
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)
 
 
def resnet101(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)
 
 
def resnext50_32x4d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
    groups = 32
    width_per_group = 4
    return ResNet(Bottleneck, [3, 4, 6, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)
 
 
def resnext101_32x8d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth
    groups = 32
    width_per_group = 8
    return ResNet(Bottleneck, [3, 4, 23, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)