PyTorch 版本的 UNet 模型实现的代码

1,097 阅读4分钟

周末了,通过写一个 UNet 来平静一下自己,最近看了很多论文,积累一些基础知识,不过感觉仅是加强理论还不够,理论还需要联系实践,而且也不能好高鹭远,一切先从基础开始,先写一个 UNet 吧。

如果还不了解什么是语义分割可以看一看 来详解一下计算机视觉中的语义分割任务 (1) - 掘金 (juejin.cn)

unet.png

这张图看出了 UNet 结构还是比较简单,现在网络变得越来越复杂,越来越难以实现。

引入必要的库

这里我们只会用到 torch ,暂时还不会涉及其他的库。

import torch
import torch.nn as nn

定义基础 CBL 结构

定义卷积块,这个卷积块就是标准配置、卷积加 BatchNormalization 再加上激活层,没有什么特殊的结构。

class ConvBlk(nn.Module):
  
  def __init__(self,in_channels,out_channels):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1)
    self.bn1 = nn.BatchNorm2d(out_channels)
    self.relu = nn.ReLU()
  
  def forward(self,x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    print(x.shape)

接下来我们去测试一下 ConvBlk 定义好的模块,在 pytorch 中的 nn.Module 这个模块类有点类似基础设施,我们定义层也好、块也好通常都会是继承于该基类,这个类有点类似做了一些基础工作,为我们构建大型网络提供一个可能性,基于 Module 我们就很容易地构建出复杂大型网络。

input = torch.randn((2,3,512,512))
convBlk = ConvBlk(3,64)
convBlk(input)

输入是一个 RGB 3 通道的图像,输出为 64 通道的特征图,在 CBL 模块并不会进行下采样。

torch.Size([2, 64, 512, 512])

接下来继续在 ConvBlk 堆叠一套 CBL 层,其实我们也可以将 ConvBlk 修改为 CBL 更为贴切。

接下里,我继续回到图,会发现在 paper 中的编码器中每一块都是堆叠两个卷积层。而且并不会改变特征图的大小。

class ConvBlk(nn.Module):
  
  def __init__(self,in_channels,out_channels):
    super().__init__()
    
    self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1)
    self.bn1 = nn.BatchNorm2d(out_channels)

    self.conv2 = nn.Conv2d(out_channels,out_channels,kernel_size=3,padding=1)
    self.bn2 = nn.BatchNorm2d(out_channels)

    self.relu = nn.ReLU()
  
  def forward(self,x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)

    x = self.conv2(x)
    x = self.bn2(x)
    x = self.relu(x)

    return x

接下里就是编码器结构,我们只需要在编码器块中添加一个下采样,这里下采用了 maxpooling 来进行下采样。

class encoderBlk(nn.Module):
  def __init__(self,in_channels,out_channels):
    super().__init__()

    self.cbl = CBL(in_channels,out_channels)
    self.pool = nn.MaxPool2d((2,2))

  def forward(self,x):
    x = self.cbl(x)
    output = self.pool(x)

    return x, output

在实现了之后,我们来测试一下模块,查看其输出为两个部分,一个部分

input = torch.randn((2,3,512,512))
encoder_Blk = encoderBlk(3,64)
x,output = encoder_Blk(input)
print(x.shape)
print(output.shape)
torch.Size([2, 64, 512, 512]) torch.Size([2, 64, 256, 256])

解码器

在编码器中是一个上采样过程,这里上采样采用转置卷积,在 pytorch 框架提供转置卷积,有关转置卷积在

来详解一下计算机视觉中的语义分割任务 (5)—转置卷积 - 掘金 (juejin.cn)

来详解一下计算机视觉中的语义分割任务 (6)—转置卷积 - 掘金 (juejin.cn)

class decoderBlk(nn.Module):
  def __init__(self,in_channels,out_channels):
    super().__init__()

    self.up = nn.ConvTranspose2d(in_channels,out_channels,kernel_size=2,stride=2,padding=0)

  def forward(self,x):
    x = self.up(x)
    return x

这里转置卷积核大小给 2 而且步长给 2 然后定义 batch 大小为 2、通道为 64 特征图的大小为 256 x 256

input = torch.randn((2,64,256,256))
decoder_Blk = decoderBlk(64,32)
x = decoder_Blk(input)
print(x.shape)

输出特征图通道数减半,然后特征图大小增加一倍,也就是不断将通道上的信息还原到空间上。

torch.Size([2, 32, 512, 512])
class decoderBlk(nn.Module):
  def __init__(self,in_channels,out_channels):
    super().__init__()

    self.up = nn.ConvTranspose2d(in_channels,out_channels,kernel_size=2,stride=2,padding=0)
    self.conv = CBL(out_channels +out_channels,out_channels)

  def forward(self,x,skip):
    x = self.up(x)
    x = torch.cat([x,skip],axis=1)
    x = self.conv(x)
    return x

构建模型

定义 buildUNet 类来将之前定义好的编码器和解码器组织一起,本次分享基本

class buildUNet(nn.Module):

  def __init__(self):
      super().__init__()

      """
      encoder
      """

      self.encoder_1 = encoderBlk(3,64)
      self.encoder_2 = encoderBlk(64,128)
      self.encoder_3 = encoderBlk(128,256)
      self.encoder_4 = encoderBlk(256,512)

      """
      bottleNeck
      """

      self.bottle_neck = CBL(512,1024)

  def forward(self,x):
    output_1, feature_1 = self.encoder_1(x)
    output_2, feature_2 = self.encoder_2(feature_1)
    output_3, feature_3 = self.encoder_3(feature_2)
    output_4, feature_4 = self.encoder_4(feature_3)

    print(output_1.shape,feature_1.shape)
    print(output_2.shape,feature_2.shape)
    print(output_3.shape,feature_3.shape)
    print(output_4.shape,feature_4.shape)

inputs = torch.randn((2,3,512,512))
build_net = buildUNet()
build_net(inputs)
torch.Size([2, 64, 512, 512]) torch.Size([2, 64, 256, 256]) 
torch.Size([2, 128, 256, 256]) torch.Size([2, 128, 128, 128]) 
torch.Size([2, 256, 128, 128]) torch.Size([2, 256, 64, 64]) 
torch.Size([2, 512, 64, 64]) torch.Size([2, 512, 32, 32])

构建网络,输入编码器是 RGB 3 通道的图像,然后经过第一个编码器通道数升为 64 通道,同时尺寸进行一次减半为 256 随后继续继续进行下采样,通道数从 64 升到 128 特征图大小继续减半为 128 通过编码器通道为 32, 然后经过的瓶颈层,也就是卷积将通道数升到 1024 ,然后就是经过一系列转置卷积进行上采样,随后还是需要经过一个 1x1 卷积将通道数改变为预测类别相同大小。

class buildUNet(nn.Module):

  def __init__(self):
    super().__init__()

    """
    encoder
    """

    self.encoder_1 = encoderBlk(3,64)
    self.encoder_2 = encoderBlk(64,128)
    self.encoder_3 = encoderBlk(128,256)
    self.encoder_4 = encoderBlk(256,512)

    """
    bottleNeck
    """

    self.bottle_neck = CBL(512,1024)

    """
    decoder
    """
    self.decoder_1 = decoderBlk(1024,512)
    self.decoder_2 = decoderBlk(512,256)
    self.decoder_3 = decoderBlk(256,128)
    self.decoder_4 = decoderBlk(128,64)

  def forward(self,x):
    output_1, feature_1 = self.encoder_1(x)
    output_2, feature_2 = self.encoder_2(feature_1)
    output_3, feature_3 = self.encoder_3(feature_2)
    output_4, feature_4 = self.encoder_4(feature_3)

    output_bottle_neck = self.bottle_neck(feature_4)

    print(output_bottle_neck.shape)
    print(output_4.shape)

    output_5 = self.decoder_1(output_bottle_neck,output_4)
    print(output_5.shape)
    output_6 = self.decoder_2(output_5,output_3)
    print(output_6.shape)
    output_7 = self.decoder_3(output_6,output_2)
    print(output_7.shape)
    output_8 = self.decoder_4(output_7,output_1)

    print(output_8.shape)