引言
nnUnet 是有监督的医学图像分割绕不开的话题,其卓越的性能和简易的方法,为相关研究者提供了一项强有力的工具。然而,由于高度封装性,在原先代码中嵌入自定义网络进行训练,并不是十分方便,本文旨在分享一点在使用 nnUnet 训练自定义网络过程中的一点经验,可能存在纰漏,欢迎在讨论区交流!
一、配置环境
1.1 硬件需求
nnUnet 的建议环境是Linux,若使用Windows,需修改路径相关代码(斜杠和反斜杠的替换),很麻烦(不推荐)。博主是在Ubuntu环境中使用Pycharm进行 nnUnet 的学习
1.2 软件需求
nnUnet 官方推荐的使用方法是在命令行,但这不方便初学者学习。为了使用Pycharm的调试功能,需修改两个文件的代码 nnunetv2/paths.py 和 nnunetv2/run/run_training.py
1.2.1 数据集路径
位于 nnunetv2/paths.py 文件中,将三个变量路径修改为自己的路径。custom_ 是博主自己定义的文件,大家可以随意实现
from custom_ import custom_config
base = custom_config['base']
preprocessing_output_dir = custom_config['preprocessing_output_dir']
network_training_output_dir_base = custom_config['network_training_output_dir_base']
1.2.2 程序入口
位于 nnunetv2/run/run_training.py 文件中,这里nnUnet训练代码的入口。由于不是命令行调用方式,需要将parser.add_argument的传入参数修改,添加 “-” 并设置 default 值。
parser = argparse.ArgumentParser()
parser.add_argument("-network", default='2d')
parser.add_argument("-network_trainer", default='nnUNetTrainerV2')
parser.add_argument("-task", default='666', help="can be task name or task id")
parser.add_argument("-fold", default='0', help='0, 1, ..., 5 or 'all'')
二、构建网络
2.1 前置知识
nnUnet 默认使用深监督,意味着自定义网络输出应为一个列表形式。然而,在网络推理时,我们只需要最高分辨率的输出,不需要多层次输出。在nnUnet官方实现中,使用 deep_supervision 参数控制是否多层次输出。综上所述,自定义网络需要满足两个条件:
- 支持多层次输出
- 使用变量deep_supervision控制输出类型
2.2 嵌入自定义网络
2.2.1 包装网络
这里提供一种对已有网络包装的方法,仅供参考
import torch.nn as nn
class custom_net(nn.Module):
def __init__(self,):
super(custom_net, self).__init__()
self.deep_supervision = True
# 使用你自己的网络
self.model = None
def forward(self, x):
output = self.model(x)
if self.deep_supervision:
return [output, ]
else:
return output
2.2.2 嵌入主框架
将自定义网络嵌套进主框架。打开文件 nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
替换函数 build_network_architecture
def build_network_architecture(self, plans_manager:PlansManager,
dataset_json,
configuration_manager:ConfigurationManager,
num_input_channels,
enable_deep_supervision: bool = True) -> nn.Module:
from dynamic_network_architectures.initialization.weight_init import InitWeights_He
model = custom_net()
model.apply(InitWeights_He(1e-2))
return model
参考资料
[1] nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation
[2] nnUnet