2. MMSegmentation训练自己的数据集

1,397 阅读7分钟

1. 框架特点

  • 算法丰富:378预训练模型、27篇论文复现
  • 模型化设置:配置简单、容易拓展
  • 统一超参:大量消融实验、支持公平对比
  • 使用方便:训练工具、测试工具、推理工具

预训练文件地址

2. MMSegmentation 结构

2.1 configs

模型结构配置文件、数据集相关配置文件、学习率&优化器配置文件、运行环境配置文件。

mmsegmentation  
   |— configs                          # 配置文件 
   |     |— __base__                   ## 基配置文件 
   |     |     |— datasets             ### 数据集相关配置文件(*) 
   |     |     |— models               ### 模型相关配置文件(*)
   |     |     |— schedules            ### 训练日程如优化器,学习率等相关配置文件(*) 
   |     |     |— default_runtime.py   ### 运行相关的默认的设置(*) 
   |     |— swin                       ## 各个分割模型的配置文件,会引用 __base__ 的配置并做修改(*)  
   |     |— ...                         

_base_/models

模型结构配置文件:以 pspnet_r50-d8.py 为例。

注:pspnet_r50-d8.py 是指 backbone:resnet50;head:pspnet。

图片.png

  • 模型主体结构:

图片.png

  • backbone:

图片.png

  • decoder_head:

图片.png

  • auxilary_head:辅助解码头通过低层次特征(in_index=2)作为输入去产生一个 loss,并进行反向传播,可以鼓励主干网络学习更好的低层特征。主解码头和辅助解码头产生的 loss 共同优化可以得到更好的分割效果。当然我们可以看到辅助解码头的 loss 权重比较低。

图片.png

_base_/datasets

数据集配置文件:以 cityscapes.py 为例。

微信截图_20230723194415.png

微信截图_20230723194544.png

_base_/schedules

优化器配置文件:以 schedule_20k.py 为例。

微信截图_20230723200617.png

swin、ann、apcnet...

总配置文件:train.py 中的必须参数 config。会加载base中的包括模型配置文件地址、数据集配置文件地址、优化器配置文件地址等:以 unet/unet-s5-d16_deeplabv3_4xb4-40k_drive-64x64.py 为例。

_base_ = [
    '../_base_/models/deeplabv3_unet_s5-d16.py', '../_base_/datasets/drive.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
]
crop_size = (64, 64)
data_preprocessor = dict(size=crop_size)
model = dict(
    data_preprocessor=data_preprocessor,
    test_cfg=dict(crop_size=(64, 64), stride=(42, 42)))

2.2 mmseg

models:模型各个组件实现代码。

datasets:数据加载器构建代码。

— mmseg  
   |     |— evaluation                 ## 评估模型代码 
   |     |     |— metrics              ### 评估模型性能代码 
   |     |— datasets                   ## 数据集相关代码 
   |     |     |— pipelines            ### 数据预处理代码 
   |     |     |— samplers             ### 数据集采样代码 
   |     |     |— ade.py               ### 各个数据集准备需要的代码(*)
   |     |     |— ... 
   |     |— models                     ## 分割模型具体实现代码 
   |     |     |— backbones            ### 主干网络 
   |     |     |— decode_heads         ### 解码头 
   |     |     |— losses               ### 损失函数 
   |     |     |— necks                ### 颈 
   |     |     |— segmentors           ### 构建完整分割网络的代码 
   |     |     |— utils                ### 构建模型时的辅助工具 
   |     |— apis                       ## high level 用户接口,在这里调用 ./mmseg/ 内各个组件 
   |     |     |— train.py             ### 训练接口 
   |     |     |— test.py              ### 测试接口 
   |     |     |— ... 
   |     |— ops                        ## cuda 算子(即将迁移到 mmcv 中) 
   |     |— utils                      ## 辅助工具 

2.3 tools

训练、测试脚本;模型、数据集转换的脚本。

— tools 
   |     |— model_converters           ## 各个主干网络预训练模型转 key 脚本 
   |     |— datasets_convert           ## 各个数据集准备转换脚本 
   |     |— train.py                   ## 训练脚本 
   |     |— test.py                    ## 测试脚本 
   |     |— ...                       

3. 训练自己的数据集

参考官方文档

step 1. 运行通和自己数据集类似的官方数据集

我这边数据集和官方数据集 CHASE DB1 比较相似,Image 和 Mask 都是图片。

因此先运行通 CHASE DB1 数据集:

python tools/train.py /home/hwz/mmsegmentation-main/configs/unet/unet_s5-d16_deeplabv3_4xb4-40k_chase-db1-128x128.py --work-dir /home/hwz/mmsegmentation-main/mmseg_log

step 2. 根据官方数据集格式修改自己数据集格式

step 2.1 将自己的数据集修改成与 CHASE DB1 同样的文件目录

— CHASE_DB1 
   |     |— annotations     # Mask 图片                   
   |     |     |— training  
   |     |     |— validation  
   |     |— images          # 图片         
   |     |     |— training           
   |     |     |— validation             

step 2.2 修改成与 CHASE DB1 对应的图片格式

检验PIL图像的通道数和像素值:

from PIL import Image

image_path ='D:/work/Competition/mmsegmentation-main/眼球/Image_01L_1stHO.png'

image = Image.open(image_path) # 读取PIL图像
channels = image.mode # 图像的通道数
print(f"图片 '{image_path}' 是 {channels} 通道的图像。")

pixel_counts = image.getcolors() # 获取图像中的像素值和计数
for count, color in pixel_counts: # 打印不同像素值的计数和颜色值
    print(f"像素值: {color}, 计数: {count}")
  • RGB通道图片:3通道(如 640×640×3),每个通道像素范围是 [0 , 255]。

  • L通道图片:单通道(如 640×640),该通道像素范围是 [0 , 255]。

  • 1通道图片:单通道(如 640×640),该通道像素只有 0/255 两种。

经检验:牙齿Image为 RGB 通道,Mask为 L 通道;眼球Image为 RGB 通道,Mask为 1 通道。

因此需要对牙齿的 Mask 图片批量变换:

import os 
from PIL import Image 
def convert_to_binary(image_path, threshold): 
    image = Image.open(image_path).convert('L') # 打开图像并转换为灰度图像 
    binary_image = image.point(lambda x: 255 if x > threshold else 0, '1') # 根据阈值进行二值化处理 
    return binary_image 

def batch_convert_images(folder_path, output_folder, threshold): 
    if not os.path.exists(output_folder): 
        os.makedirs(output_folder) 

    image_files = os.listdir(folder_path) 
    for image_file in image_files: 
        image_path = os.path.join(folder_path, image_file) 
        if os.path.isfile(image_path): 
            binary_image = convert_to_binary(image_path, threshold) 
            output_path = os.path.join(output_folder, image_file) 
            binary_image.save(output_path) # 保存二值图像 
    print("图片批量处理完成!") 

folder_path = 'input_folder' # 输入文件夹路径 
output_folder = 'output_folder' # 输出文件夹路径 
threshold = 128 # 设置阈值,灰度值大于阈值的像素设为白色,灰度值小于等于阈值的像素设为黑色 
batch_convert_images(folder_path, output_folder, threshold)

经检验:转换后的牙齿的 Mask 图片为 1 通道,像素值只有 0 和 255。

step 3. 修改配置文件

step 3.1 创建一个新文件 mmseg/datasets/tooth.py

注:classes 和 mask 中的类别对应(虽然只有tooth这1个类别,但是background也必须算作1个类别);palette 和 mask 中类别的色彩对应(可以用ps查看每个类别的RGB色彩)。

# Copyright (c) OpenMMLab. All rights reserved.
import mmengine.fileio as fileio

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class ToothDataset(BaseSegDataset):
    METAINFO = dict(
        classes=('background', 'tooth'),
        palette=[[0, 0, 0], [255, 255, 255]])

    def __init__(self,
                 img_suffix='.png',
                 seg_map_suffix='.png',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
        assert fileio.exists(
            self.data_prefix['img_path'], backend_args=self.backend_args)

step 3.2mmseg/datasets/__init__.py 中添加语句

from .tooth import ToothDataset

# __all__中添加
__all__ = ['ToothDataset']

step 3.3 创建一个新的数据集配置文件 configs/__base__/datasets/tooth.py

# dataset settings
dataset_type = 'ToothDataset' # 与 mmseg/datasets/tooth.py中的类名是对应的
data_root = 'data/tooth/' # 自己的数据集所在的位置
img_scale = (320, 640)
crop_size = (160, 320)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='RandomResize',
        scale=img_scale,
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=img_scale, keep_ratio=True),
    # add loading annotation after ``Resize`` because ground truth
    # does not need to do resize data transform
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', backend_args=None),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]

train_dataloader = dict(
    batch_size=4,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type='RepeatDataset',
        times=40000,
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            data_prefix=dict(
                img_path='images/training',
                seg_map_path='annotations/training'),
            pipeline=train_pipeline)))

val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/validation',
            seg_map_path='annotations/validation'),
        pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(type='IoUMetric', iou_metrics=['mDice'])
test_evaluator = val_evaluator

step 3.4mmseg/utils/class_names 中补充数据集元信息

def tooth_classes():
    return [
        'background','tooth'
    ]

def tooth_palette():
    return [
        [0,0,0],[255,255,255]
    ]
    
# dataset_aliases中添加
dataset_aliases={
'tooth':['tooth']
}

step 3.5 创建一个总配置文件 configs/unet/unet_s5-d16_deeplabv3_4xb4-40k_tooth-320×640.py

_base_ = [
    '../_base_/models/deeplabv3_unet_s5-d16.py',
    '../_base_/datasets/tooth.py', '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_40k.py'
]
crop_size = (160, 320)
data_preprocessor = dict(size=crop_size)
model = dict(
    data_preprocessor=data_preprocessor,
    test_cfg=dict(crop_size=(160, 320), stride=(85, 85)))

修改模型配置文件 _base_/models/deeplabv3_unet_s5-d16.py

注:num_classes按照类别修改。

# model settings
norm_cfg = dict(type='BN', requires_grad=True) # 单卡训练为BN,多卡训练为SyncBN
data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255)
model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    pretrained=None,
    backbone=dict(
        type='UNet',
        in_channels=3,
        base_channels=64,
        num_stages=5,
        strides=(1, 1, 1, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, True),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1),
        with_cp=False,
        conv_cfg=None,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='ReLU'),
        upsample_cfg=dict(type='InterpConv'),
        norm_eval=False),
    decode_head=dict(
        type='ASPPHead',
        in_channels=64,
        in_index=4,
        channels=16,
        dilations=(1, 12, 24, 36),
        dropout_ratio=0.1,
        num_classes=2, # 类别记得符合Mask的类别数
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=128,
        in_index=3,
        channels=64,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=2, # 类别记得符合Mask的类别数
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='slide', crop_size=128, stride=85))

step 4. 重新启动

python setup.py install
pip install -v -e .

step 5. 启动训练和测试

python tools/train.py /home/hwz/mmsegmentation-main/configs/unet/unet_s5-d16_deeplabv3_4xb4-40k_tooth-320×640.py --work-dir /home/hwz/mmsegmentation-main/mmseg_log

4. 推理自己的数据集

推理的代码可以对demo/image_demo.py改造

注:一定要在GPU推理,不然会报错!

# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser

from mmengine.model import revert_sync_batchnorm

from mmseg.apis import inference_model, init_model, show_result_pyplot

import os

import tqdm

def main():
    parser = ArgumentParser()
    parser.add_argument('--img', default="/home/hwz/mmsegmentation-main/data/tooth/test/image",help='Image file')
    parser.add_argument('--config',default="/home/hwz/mmsegmentation-main/configs/unet/unet_s5-d16_deeplabv3_4xb4-40k_tooth-320×640.py", help='Config file')
    parser.add_argument('--checkpoint',default="/home/hwz/mmsegmentation-main/mmseg_log/iter_40000.pth", help='Checkpoint file')
    parser.add_argument('--out-file', default="/home/hwz/mmsegmentation-main/mmseg_show", help='Path to output file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--opacity',
        type=float,
        default=0.5,
        help='Opacity of painted segmentation map. In (0, 1] range.')
    parser.add_argument(
        '--title', default='result', help='The image identifier.')
    args = parser.parse_args()

    # build the model from a config file and a checkpoint file
    model = init_model(args.config, args.checkpoint, device=args.device)
    if args.device == 'cpu':
        model = revert_sync_batchnorm(model)


    for filename in os.listdir(args.img):
        img = os.path.join(args.img, filename)
        print(img)
        out_file = os.path.join(args.out_file, filename)
        print(out_file)
        result = inference_model(model, img)
        show_result_pyplot(
            model,
            img,
            result,
            title=args.title,
            opacity=args.opacity,
            draw_gt=False,
            show=False if args.out_file is not None else True,
            out_file=out_file)


if __name__ == '__main__':
    main()

推理完,可以根据具体要求对图像进行操作:我推理出来是RGB通道图像,我将其改为1通道。

报错:

图片.png

出现这种情况就把训练的时候: train_dataloader = dict( batch_size=1,...) 将batch_size=1改成batch_size=2。