5分钟速成半监督医学图像分割

230 阅读8分钟

5分钟速成半监督医学图像分割

本文所涉及所有资源均在传知代码平台可获取

概述

这里我将介绍一篇MICCAI 2023的一篇医学图像分割的文章《Decoupled Consistency for Semi-supervised Medical Image Segmentation》。这篇文章提出了一种新的解耦一致性半监督医学图像分割框架。该框架充分利用预测数据,将预测数据解耦为用于各种功能的数据,并最大限度地发挥每种功能的优势。

一、论文思路

对于半监督的医学分割任务,传统的伪标签方法会过滤掉低置信度的像素,而一致性正则化并没有充分利用高置信度和低置信度数据的优势。因此,这两种方法都不能充分利用无标签数据。这篇文章提出了一种新的解耦一致性半监督医学图像分割框架。首先,利用动态阈值将预测数据解耦为一致部分和不一致部分。对于一致部分,使用交叉伪监督的方法进行优化。对于不一致部分,进一步将其解耦为可能靠近决策边界的不可靠数据和更有可能出现在高密度区域的引导数据。不可靠数据将朝着引导数据的方向进行优化,这种操作为方向一致性。此外,为了充分利用数据,我们将特征图纳入训练过程并计算特征一致性的损失。

二、模型介绍

这篇文章的模型图如下图所示:

在这里插入图片描述

如该图所示,DC-Net包含一个编码器和两个一致的解码器,对于A解码器,用双线性插值进行上采样,对于B解码器使用反卷积进行上采样。对于有标签的数据,计算它们与真实值之间的损失 Lseg,对于一致部分,我们计算交叉伪监督损失 Lcps,对于不一致部分,我们计算方向一致性损失 Ldc,对于特征图,我们计算特征一致性损失 Lf。

三、细节分析

动态一致性阈值

FlexMatch [23] 证明了在训练的早期阶段,为了提高无标签数据的利用率并促进伪标签的多样化,γ 应该相对较小。随着训练的进行,γ 应该保持一个稳定的伪标签比例。因此,这篇文章的一致性阈值定义如下:

在这里插入图片描述

其中 B 是批量大小,λ 是随着训练增加的权重系数,我们设定 λ = t/tmax。为了采集更多的无标签数据,我们对 pA 和 pB 进行阈值评估,并选择较小的阈值作为我们的一致性阈值。我们将 λt 初始化为 1/C,其中 C 表示类别数。最终一致性阈值 γt 被定义并调整如下

在这里插入图片描述

分解一致性

这篇文章将不一致部分解耦为不可靠数据和引导数据,其中不可靠数据可能出现在决策边界,而引导数据更有可能出现在高密度区域。这两部分具有相同的索引信息,不同之处在于引导数据比不可靠数据更有信心。基于平滑假设,这两部分的输出应该是一致的,并且位于高密度区域。因此,我们应该集中优化决策边界周围的像素,以使其更接近高密度区域。我们首先通过锐化这些像素的置信度,使高置信度像素更接近高密度区域。以下是锐化过程:

在这里插入图片描述

其中 o 表示模型的输出,T ∈ (0, 1) 表示锐化温度。在实验中,我们设定 T = 0.5。通过比较 SpA 和 SpB,可以得到高置信度部分 hSpA, hSpB和低置信度部分 lSpA, lSpB。我们采用 L2 损失作为这一部分的损失函数。需要注意的是,只优化低置信度部分,不对高置信度部分的梯度进行反向传播。因此,方向一致性损失可以写成:

在这里插入图片描述

对于一致性部分:类似于 CPS,采用交叉伪监督的方法进行优化。具体细节如下:

在这里插入图片描述

PLA 和 PLB 代表相应的伪标签。 对于特征部分,这篇文章将特征图纳入训练过程中以进一步利用数据。为了减少计算量,我们对特征图进行了平均映射以降低其维度

在这里插入图片描述

映射过程如下:

在这里插入图片描述

其中,p > 1,fm 代表第 m 层的特征图,fmi 表示 fm 在通道维度上的第 i 个切片,f¯m 表示相应的映射结果。在实验中,我们设定 p = 2。我们的特征一致性损失如下:

在这里插入图片描述

其中 N 是 f¯mi 的像素数,n 是网络层数,f¯emi 和 f¯dmi 分别代表解码器和编码器第 m 层特征图的第 i 个像素。在本文中,仅使用了解码器 B 的特征图来计算损失。 最后,全局的损失为上面提到的分割loss和其他三个loss之和:

在这里插入图片描述

其中 Lseg 是应用于少量有标签数据的 Dice 损失。超参数 β 被设定为一个与迭代相关的预热函数 [29],β = e(−5(1− t/tmax)^2),实验中 λ = 1 − β,α = 0.3.

四、复现过程

在这里插入图片描述

这是我们下载下来的源码目录,然后这里作者提供了一个ACDC数据集在10%标注数据样本条件下的预训练模型,这里我们直接用这个模型进行测试,这里有一个问题就是作者将模型放在了ACDC_7这个目录下,我们只需要将其移动到ACDC_mcnet_kd_DCNet_7_labeled这个目录下就行了,然后运行test_acdc.py文件,就可以得到想要的结果。下图是实现的结果:

在这里插入图片描述

除了这个指标性能,我们还可以得到预测到的3D医学图像:

在这里插入图片描述

这里我们可以使用一个专门的3D图像可视化的软件ITK-SNAP,下面是一些可视化的结果:

在这里插入图片描述

在这里插入图片描述

这里在运行test_acdc.py文件之前,还需要做好数据集的准备,首先需要获取ACDC和PROMISE12数据集。这里可以我之前的一篇博客半监督的医学分割数据集(LA, Pancreas, ACDC和PROMISE12)分享,这篇博客提供了四个公开的数据集LA, PANCREAS, ACDC以及PROMISE12数据集,都是处理好的数据集,放对地方就可以直接运行代码。 下面我将展示一下核心代码,也添加了一些注释进行讲解说明:

output1_soft = F.softmax(output1, dim=1)
output2_soft = F.softmax(output2, dim=1)
output1_soft0 = F.softmax(output1 / 0.5, dim=1)
output2_soft0 = F.softmax(output2 / 0.5, dim=1)
# 这里是预测输出的锐化过程
with torch.no_grad():
    max_values1, _ = torch.max(output1_soft, dim=1)
    max_values2, _ = torch.max(output2_soft, dim=1)
    percent = (iter_num + 1) / max_iterations

    cur_threshold1 = (1 - percent) * cur_threshold + percent * max_values1.mean()
    cur_threshold2 = (1 - percent) * cur_threshold + percent * max_values2.mean()
    mean_max_values = min(max_values1.mean(), max_values2.mean())

    cur_threshold = min(cur_threshold1, cur_threshold2)
    cur_threshold = torch.clip(cur_threshold, 0.25, 0.95)

mask_high = (output1_soft > cur_threshold) & (output2_soft > cur_threshold)
mask_non_similarity = (mask_high == False)
# 这里是动态阈值部分的实现,这里阈值的初始值是0.25,也就是类别的倒数,然后这个值会快速地上升,最大值为0.95. 这里由这个阈值可以得到一致的高阈值区域和不一致区域。

new_output1_soft = torch.mul(mask_non_similarity, output1_soft)
new_output2_soft = torch.mul(mask_non_similarity, output2_soft)
high_output1 = torch.mul(mask_high, output1)
high_output2 = torch.mul(mask_high, output2)
high_output1_soft = torch.mul(mask_high, output1_soft)
high_output2_soft = torch.mul(mask_high, output2_soft)

pseudo_output1 = torch.argmax(output1_soft, dim=1)
pseudo_output2 = torch.argmax(output2_soft, dim=1)
pseudo_high_output1 = torch.argmax(high_output1_soft, dim=1)
pseudo_high_output2 = torch.argmax(high_output2_soft, dim=1)

max_output1_indices = new_output1_soft > new_output2_soft  # output1 距离近的像素的位置

max_output1_value0 = torch.mul(max_output1_indices, output1_soft0)
min_output2_value0 = torch.mul(max_output1_indices, output2_soft0)

max_output2_indices = new_output2_soft > new_output1_soft  # output2 距离远的像素的位置

max_output2_value0 = torch.mul(max_output2_indices, output2_soft0)
min_output1_value0 = torch.mul(max_output2_indices, output1_soft0)
# 上面这段代码就是利用一致性区域和非一致性区域的处理过程

loss_dc0 = 0
loss_cer = 0
loss_at_kd = criterion_att(encoder_features, decoder_features2)


loss_dc0 += mse_criterion(max_output1_value0.detach(), min_output2_value0)
loss_dc0 += mse_criterion(max_output2_value0.detach(), min_output1_value0)

loss_seg_dice += dice_loss(output1_soft[:labeled_bs, ...], label_batch[:labeled_bs].unsqueeze(1))
loss_seg_dice += dice_loss(output2_soft[:labeled_bs, ...], label_batch[:labeled_bs].unsqueeze(1))


if mean_max_values >= 0.95:
     loss_cer += ce_loss(output1, pseudo_output2.long().detach())
     loss_cer += ce_loss(output2, pseudo_output1.long().detach())
else:
     loss_cer += ce_loss(high_output1, pseudo_high_output2.long().detach())
     loss_cer += ce_loss(high_output2, pseudo_high_output1.long().detach())


consistency_weight = get_current_consistency_weight(iter_num // 150)
supervised_loss = loss_seg_dice
loss = supervised_loss + (1-consistency_weight) * (1000 * loss_at_kd) + consistency_weight * (1000 * loss_dc0 ) + 0.3 * loss_cer

部署方式

这里只需要按照这个要求安装所需要的pytorch和python版本即可。PyTorch 1.8.1, CUDA 10.1, and Python 3.6.13。除了这三个硬性要求之外,还需要安装一些其他的python库,但是只需要通过运行程序就可以直接缺少什么module没有安装,然后直接pip install即可。

文章代码资源点击附件获取