改造nnUnet,嵌入第三方网络

1,661 阅读2分钟

引言

nnUnet 是有监督的医学图像分割绕不开的话题,其卓越的性能和简易的方法,为相关研究者提供了一项强有力的工具。然而,由于高度封装性,在原先代码中嵌入自定义网络进行训练,并不是十分方便,本文旨在分享一点在使用 nnUnet 训练自定义网络过程中的一点经验,可能存在纰漏,欢迎在讨论区交流!

一、配置环境

1.1 硬件需求

nnUnet 的建议环境是Linux,若使用Windows,需修改路径相关代码(斜杠和反斜杠的替换),很麻烦(不推荐)。博主是在Ubuntu环境中使用Pycharm进行 nnUnet 的学习

1.2 软件需求

nnUnet 官方推荐的使用方法是在命令行,但这不方便初学者学习。为了使用Pycharm的调试功能,需修改两个文件的代码 nnunetv2/paths.pynnunetv2/run/run_training.py

1.2.1 数据集路径

位于 nnunetv2/paths.py 文件中,将三个变量路径修改为自己的路径。custom_ 是博主自己定义的文件,大家可以随意实现

from custom_ import custom_config
base = custom_config['base']
preprocessing_output_dir = custom_config['preprocessing_output_dir']
network_training_output_dir_base = custom_config['network_training_output_dir_base']

1.2.2 程序入口

位于 nnunetv2/run/run_training.py 文件中,这里nnUnet训练代码的入口。由于不是命令行调用方式,需要将parser.add_argument的传入参数修改,添加 “-” 并设置 default 值。

parser = argparse.ArgumentParser()
parser.add_argument("-network", default='2d')
parser.add_argument("-network_trainer", default='nnUNetTrainerV2')
parser.add_argument("-task", default='666', help="can be task name or task id")
parser.add_argument("-fold", default='0', help='0, 1, ..., 5 or 'all'')

二、构建网络

2.1 前置知识

nnUnet 默认使用深监督,意味着自定义网络输出应为一个列表形式。然而,在网络推理时,我们只需要最高分辨率的输出,不需要多层次输出。在nnUnet官方实现中,使用 deep_supervision 参数控制是否多层次输出。综上所述,自定义网络需要满足两个条件:

  • 支持多层次输出
  • 使用变量deep_supervision控制输出类型

2.2 嵌入自定义网络

2.2.1 包装网络

这里提供一种对已有网络包装的方法,仅供参考

import torch.nn as nn

class custom_net(nn.Module):

    def __init__(self,):
        super(custom_net, self).__init__()
        self.deep_supervision = True
        # 使用你自己的网络
        self.model = None

    def forward(self, x):
        output = self.model(x)
        if self.deep_supervision:
            return [output, ]
        else:
            return output

2.2.2 嵌入主框架

将自定义网络嵌套进主框架。打开文件 nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py

替换函数 build_network_architecture

def build_network_architecture(self, plans_manager:PlansManager,
                                   dataset_json,
                                   configuration_manager:ConfigurationManager,
                                   num_input_channels,
                                   enable_deep_supervision: bool = True) -> nn.Module:
        from dynamic_network_architectures.initialization.weight_init import InitWeights_He
        model = custom_net()
        model.apply(InitWeights_He(1e-2))
        return model

参考资料

[1] nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation
[2] nnUnet