持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第6天,点击查看活动详情
对于一个优秀的网络,在resnet提出之前,并非越深越好,当一个网络结构增加到一定深度之后,由于特征不断被压缩,导致浅层次的特征在传递到深得网络层之后,特征被无限缩小,使得检测的效果非常差。而resnet通过将浅层次的特征通过跳跃连接的方式,直接传递到更深层次的网络,而不经过中间层的不断压缩,这样即使网络再深,依然可以获取到最原始的特征。
下面通过一幅图来直观的解释一下resnet的核心思想。
图中左边展示的为resnet的残差块,之所以称为块,是因为网络结构是由很多类似的这样的块堆叠起来的,注意图中红色框部分,便是浅层次特征与深层次特征结合的部分。
整个网络的结构大概就是通过加深网络,给深层次网络传递浅层次特征使其不至于丢失特征。这样的网络结构既保留了浅层次的细节特征,又保留了深层次的语义特征,目前使用非常广泛,后面博客中介绍的unet网络结构,也是使用了残差连接的方式。
下面,通过代码来实现上面的网络结构
import torch
from torch import nn
import torch.nn.functional as F
class ResBlock(nn.Module):
'''
残差模块
'''
def __init__(self, n_chans):
super().__init__()
'''
首先定义几个网络的结构,通过继承nn.Module进行实现
'''
self.conv = nn.Conv2d(n_chans, n_chans, kernel_size=3, padding=1, bias=False)
self.batch_norm = nn.BatchNorm2d(n_chans)
torch.nn.init.kaiming_normal_(self.conv.weight, nonlinearity='relu') # 参数初始化
torch.nn.init.constant_(self.batch_norm.weight, 0.5)
torch.nn.init.zeros_(self.batch_norm.bias)
def forward(self, x):
'''
实现网络的前向传播,也就是将网络结构串起来
:param x: 网络的输入
:return:
'''
out = self.conv(x)
out = self.batch_norm(out)
out = F.relu(out)
return out + x
class NetResDepp(nn.Module):
def __init__(self, n_chans1=32, num_blocks=100):
'''
定义resnet整体结构
:param n_chans1:
:param num_blocks:
'''
super().__init__()
self.n_chans1 = n_chans1
self.num_blocks = num_blocks
self.conv = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)
# 此处是直接使用了容器,将残差块通过复制100次的形式,实现一个深度网络,而不是真的将整个残差块的代码复制100次
self.res_blocks = nn.Sequential(*[ResBlock(n_chans=n_chans1) for _ in range(10)])
self.fc1 = nn.Linear(8 * 8 * n_chans1, 32)
self.fc2 = nn.Linear(32, 2)
def forward(self, x):
out = F.relu(self.conv(x))
out = F.max_pool2d(out, 2)
out = self.resblocks(out)
out = F.max_pool2d(out, 2)
out = out.view(-1, 8 * 8 * self.n_chans1)
out = self.fc1(out)
out = self.fc2(out)
return out
model = NetResDepp(16)
print(model)
对于上述网络结构,我们可以在调整输入的维度之后进行随意的嵌入,即可发挥其作用,下面我们实例化上述网络结构,打印一下网络的内部构成。
model = NetResDepp(16)
print(model)
通过了解残差思想的原理,后面的许多网络结构都可以借助这样的思想,可以起到保留网络特征的能力。