VoxelMorph中网络结构(特征提取和空间变换)源码剖析

625 阅读4分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第2天,点击查看活动详情

VoxelMorph主要是对医学图像进行配准而设计,论文最初使用的是脑部的医学图像进行举例说明,论文中做了大量的比较实验来突出作者方法的优点。首先,论文提到了两种配准的方式,第一种是无监督的配准方法,另一种是有监督的配准方法,针对两者的优劣,作者给出了不同的解释。

首先看一下论文提出的网络结构

image.png

其中m作为moving,f作为fixed,网络的目的是将moving通过一系列的变换将其对齐到fixed,对于图中的gθ(f,m),是一个特征的提取网络,原文中使用了Unet进行特征的提取,对应到实验源码如下。

class U_Network(nn.Module):
    def __init__(self, dim, enc_nf, dec_nf, bn=None, full_size=True):
        super(U_Network, self).__init__()
        self.bn = bn
        self.dim = dim
        self.enc_nf = enc_nf
        self.full_size = full_size
        self.vm2 = len(dec_nf) == 7
        # Encoder functions
        self.enc = nn.ModuleList()
        for i in range(len(enc_nf)):
            prev_nf = 2 if i == 0 else enc_nf[i - 1]
            self.enc.append(self.conv_block(dim, prev_nf, enc_nf[i], 4, 2, batchnorm=bn))
        # Decoder functions
        self.dec = nn.ModuleList()
        self.dec.append(self.conv_block(dim, enc_nf[-1], dec_nf[0], batchnorm=bn))  # 1
        self.dec.append(self.conv_block(dim, dec_nf[0] * 2, dec_nf[1], batchnorm=bn))  # 2
        self.dec.append(self.conv_block(dim, dec_nf[1] * 2, dec_nf[2], batchnorm=bn))  # 3
        self.dec.append(self.conv_block(dim, dec_nf[2] + enc_nf[0], dec_nf[3], batchnorm=bn))  # 4
        self.dec.append(self.conv_block(dim, dec_nf[3], dec_nf[4], batchnorm=bn))  # 5

        if self.full_size:
            self.dec.append(self.conv_block(dim, dec_nf[4] + 2, dec_nf[5], batchnorm=bn))
        if self.vm2:
            self.vm2_conv = self.conv_block(dim, dec_nf[5], dec_nf[6], batchnorm=bn)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

        # One conv to get the flow field
        conv_fn = getattr(nn, 'Conv%dd' % dim)
        self.flow = conv_fn(dec_nf[-1], dim, kernel_size=3, padding=1)
        # Make flow weights + bias small. Not sure this is necessary.
        nd = Normal(0, 1e-5)
        self.flow.weight = nn.Parameter(nd.sample(self.flow.weight.shape))
        self.flow.bias = nn.Parameter(torch.zeros(self.flow.bias.shape))
        self.batch_norm = getattr(nn, "BatchNorm{0}d".format(dim))(3)

    def conv_block(self, dim, in_channels, out_channels, kernel_size=3, stride=1, padding=1, batchnorm=False):
        conv_fn = getattr(nn, "Conv{0}d".format(dim))
        bn_fn = getattr(nn, "BatchNorm{0}d".format(dim))
        if batchnorm:
            layer = nn.Sequential(
                conv_fn(in_channels, out_channels, kernel_size, stride=stride, padding=padding),# nn.Conv3d()
                bn_fn(out_channels), # nn.BatchNorm3d()
                nn.LeakyReLU(0.2))
        else:
            layer = nn.Sequential(
                conv_fn(in_channels, out_channels, kernel_size, stride=stride, padding=padding),
                nn.LeakyReLU(0.2))
        return layer

    def forward(self, src, tgt):
        # src=(1,1,160,192,160)
        # tgt=(1,1,160,192,160)
        x = torch.cat([src, tgt], dim=1)
        # x=(1,2,160,192,160)
        # Get encoder activations
        x_enc = [x]
        for i, l in enumerate(self.enc):
            x = l(x_enc[-1])
            x_enc.append(x)
        # Three conv + upsample + concatenate series
        y = x_enc[-1]
        for i in range(3):
            y = self.dec[i](y)
            y = self.upsample(y)
            y = torch.cat([y, x_enc[-(i + 2)]], dim=1)
        # Two convs at full_size/2 res
        y = self.dec[3](y)
        y = self.dec[4](y)
        # Upsample to full res, concatenate and conv
        if self.full_size:
            y = self.upsample(y)
            y = torch.cat([y, x_enc[0]], dim=1)
            y = self.dec[5](y)
        # Extra conv for vm2
        if self.vm2:
            y = self.vm2_conv(y)
        flow = self.flow(y)
        if self.bn:
            flow = self.batch_norm(flow)
        return flow

image.png

如果在代码中想进一步了解各个层的构建过程可以通过断点调试的方法逐行运行代码,上图为通过断点进行调试的界面,可以看到整个uNet结构,encode和decode的结构大小。

通过Unet特征提取之后得到的就是配准域,也就是网络结构图中的Registration Field.

对于这一部分,可以看到其中有一个往下的箭头,写着Lsmooth和Lsim,这两个分别是平滑损失和相似度损失。

在输入的Moving 3D Image(m)中,箭头指向Spatial Transform,这一部分是空间变换网络,通过将特征域与m输入到该网络中可以得到一个配准成功的图像moved。

对于Spatial Tranform,论文中提到了可以是MLP或者迷你CNN,但是在源码跟踪过程中我发现作者并没有用上述两者,而是直接使用了下面的操作。

class SpatialTransformer(nn.Module):
    def __init__(self, size, mode='bilinear'):
        super(SpatialTransformer, self).__init__()
        # Create sampling grid
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors)
        grid = torch.stack(grids)  # y, x, z
        grid = torch.unsqueeze(grid, 0)  # add batch
        grid = grid.type(torch.FloatTensor)
        self.register_buffer('grid', grid)

        self.mode = mode

    def forward(self, src, flow):
        new_locs = self.grid + flow
        shape = flow.shape[2:]

        # Need to normalize grid values to [-1, 1] for resampler
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)

        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        return F.grid_sample(src, new_locs, mode=self.mode)

image.png

所以,对于空间变换网络,作者主要进行了下面两个步骤:

第一:网格生成器 grid generater,其目的是将定位网络预测的变换参数应用于输入的特征图。

第二:一个采样器 sampler 作为插值器来构造最终输出的扭曲图像。

最后,总结论文的贡献,论文提出的这种方法是一种通用的配准方法,不仅适用于脑部的图像,而且对于肺部的医学图像同样适用。而且论文中提出的方法比原来的存在的方法更加的快速,即使在CPU上执行,速度比原来快了很多,说明该方法对于数据的处理能力非常有效,而且其dice的评分也相对较好。下面给出一组实验数据,如下,由于数据都为三维图像,而展示贴图时只能展示二维。

image.png