pytorch学习笔记(四)

80 阅读2分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第6天,点击查看活动详情

对于一个优秀的网络,在resnet提出之前,并非越深越好,当一个网络结构增加到一定深度之后,由于特征不断被压缩,导致浅层次的特征在传递到深得网络层之后,特征被无限缩小,使得检测的效果非常差。而resnet通过将浅层次的特征通过跳跃连接的方式,直接传递到更深层次的网络,而不经过中间层的不断压缩,这样即使网络再深,依然可以获取到最原始的特征。

下面通过一幅图来直观的解释一下resnet的核心思想。

image.png

图中左边展示的为resnet的残差块,之所以称为块,是因为网络结构是由很多类似的这样的块堆叠起来的,注意图中红色框部分,便是浅层次特征与深层次特征结合的部分。

整个网络的结构大概就是通过加深网络,给深层次网络传递浅层次特征使其不至于丢失特征。这样的网络结构既保留了浅层次的细节特征,又保留了深层次的语义特征,目前使用非常广泛,后面博客中介绍的unet网络结构,也是使用了残差连接的方式。

下面,通过代码来实现上面的网络结构

import torch
from torch import nn
import torch.nn.functional as F


class ResBlock(nn.Module):
    '''
    残差模块
    '''
    def __init__(self, n_chans):
        super().__init__()
        '''
        首先定义几个网络的结构,通过继承nn.Module进行实现
        '''
        self.conv = nn.Conv2d(n_chans, n_chans, kernel_size=3, padding=1, bias=False)
        self.batch_norm = nn.BatchNorm2d(n_chans)
        torch.nn.init.kaiming_normal_(self.conv.weight, nonlinearity='relu')  # 参数初始化
        torch.nn.init.constant_(self.batch_norm.weight, 0.5)
        torch.nn.init.zeros_(self.batch_norm.bias)

    def forward(self, x):
        '''
        实现网络的前向传播,也就是将网络结构串起来
        :param x: 网络的输入
        :return:
        '''
        out = self.conv(x)
        out = self.batch_norm(out)
        out = F.relu(out)
        return out + x


class NetResDepp(nn.Module):
    def __init__(self, n_chans1=32, num_blocks=100):
        '''
        定义resnet整体结构
        :param n_chans1:
        :param num_blocks:
        '''
        super().__init__()
        self.n_chans1 = n_chans1
        self.num_blocks = num_blocks
        self.conv = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)
        # 此处是直接使用了容器,将残差块通过复制100次的形式,实现一个深度网络,而不是真的将整个残差块的代码复制100次
        self.res_blocks = nn.Sequential(*[ResBlock(n_chans=n_chans1) for _ in range(10)])
        self.fc1 = nn.Linear(8 * 8 * n_chans1, 32)
        self.fc2 = nn.Linear(32, 2)

    def forward(self, x):
        out = F.relu(self.conv(x))
        out = F.max_pool2d(out, 2)
        out = self.resblocks(out)
        out = F.max_pool2d(out, 2)
        out = out.view(-1, 8 * 8 * self.n_chans1)
        out = self.fc1(out)
        out = self.fc2(out)
        return out

model = NetResDepp(16)
print(model)

对于上述网络结构,我们可以在调整输入的维度之后进行随意的嵌入,即可发挥其作用,下面我们实例化上述网络结构,打印一下网络的内部构成。

model = NetResDepp(16)
print(model)

image.png

通过了解残差思想的原理,后面的许多网络结构都可以借助这样的思想,可以起到保留网络特征的能力。