【超分实战】拒绝灾难性遗忘!记一次原生4K医疗影像(SurgiSR4K)的模型微调踩坑实录

14 阅读13分钟

为什么学这个?

最近在做医疗影像的超分辨率(Super-Resolution, SR)任务。我们都知道,深度学习模型如果直接在一个垂直领域的小数据集上从头训练,往往很难收敛,效果也不尽如人意。标准的做法是:先在一个大规模通用数据集(如 DIV2K)上预训练,让模型学会提取直线、角点、色彩过渡等“基础高频特征”,然后再迁移到我们特定的医疗数据集(我使用的是最新开源的原生 4K 内窥镜数据集 SurgiSR4K)上进行微调(Fine-tuning)

微调听起来很简单,无非就是加载一下 .pth 权重接着跑。但在实际落地中,为了保证实验的绝对可复现性以及无损的特征迁移,我踩了不少坑。这篇文章就来复盘一下我是如何从零构建一个稳健的微调流水线的。


核心内容与步骤

我的整体思路是将“预训练”和“微调”的逻辑彻底解耦,单独编写了一个 finetune.py 脚本。

1. 严格锁定随机种子(保证可复现)

做算法实验,不可复现是大忌。在代码的开头,我写了一个死锁所有随机性的函数,确保只要传入相同的 seed,每次微调的 loss 曲线必须完全一致:

Python

def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.deterministic = True
    cudnn.benchmark = False

2. 动态数据集加载与预处理对齐

预训练时,为了追求极致的 I/O 速度,我提前将通用数据集切分成了静态的 .h5 文件。但在微调阶段,为了更好的数据扩增效果,我改为直接读取重新组织好(划分了 train/val/test)的图片文件夹,并在 DataLoader 中实时进行联合随机裁剪(Joint Random Crop)

关键点:通道严格对齐!

在仔细盘点预训练阶段的 prepare.py 脚本时,我惊出了一身冷汗,发现了一个极易致命的细节——预训练模型完全是仅在 Y 通道(亮度通道)上进行训练的!

很多初次接触超分算法(如 ESPCN、SRCNN)的开发者可能会疑惑:为什么放着好好的 RGB 三通道不用,非要大费周章地只提 Y 通道来训练?

这其实是基于**人类视觉系统(HVS)**的生理特性做出的极致优化:

  1. 人眼对“结构”比对“色彩”更敏感:人眼对图像的亮度(Y通道,几乎包含了所有的边缘、轮廓、纹理等高频结构信息)极其敏感,而对色度/饱和度(Cb、Cr通道)的微小模糊则相对迟钝。

  2. 算力与效果的完美平衡:如果同时对 RGB 三通道做超分辨率重建,计算量和显存开销会直接翻三倍。因此,业内极其经典且高效的做法是:只把最难的高频亮度特征(Y通道)交给神经网络去精雕细琢,而对于 Cb、Cr 通道,只需使用极低成本的双三次插值(Bicubic)直接放大即可。 最后将它们合并转回 RGB,肉眼几乎看不出色彩瑕疵,但速度却快得多。

明白了这一层底层算法逻辑,再回到我们的微调代码上:预训练网络的第一层卷积是为 1通道 输入量身定制的,如果微调时我没注意,直接把原图的 RGB 三通道 喂进去,由于张量维度不匹配,模型会瞬间抛出 Shape Mismatch 错误并当场崩溃。

为了规避这个大坑,我在新的 Dataset 类中严谨地重写了转换逻辑。我强制将每张图片先转为 YCbCr 色彩空间,并剥离出 Y 通道,确保喂给网络的数据与预训练时的特征空间达到绝对的、像素级的对齐:

Python

def rgb_to_y(self, img):
    """提取 Y 通道,严格对齐预训练模型的输入特征空间"""
    ycbcr = img.convert('YCbCr')
    y, cb, cr = ycbcr.split()
    return y

3. 制定微调超参数策略

基于预训练的参数(Scale=2, LR=1e-5, BatchSize=16, Epochs=200),我为微调阶段制定了如下策略:

  • 放大倍率 (Scale) = 2:绝对不能变!网络末端的上采样层权重尺寸是和 Scale 绑定的,改了直接报 Shape Mismatch。
  • 学习率 (LR) = 1e-6:将预训练的最终学习率砍半(甚至可以缩小到 1/10)。只做微小调整,适应医疗图像中组织黏膜、血管的纹理。
  • 训练轮数 (Epochs) = 80:微调起点极高,通常 50-100 轮即可在特定数据集上收敛,跑多了反而会在小数据集上过拟合。配合 Early Stopping 逻辑保存 best_finetuned.pth

遇到的坑点与排雷指南

在调试过程中,我总结了微调阶段最容易踩的 4 个大坑:

坑点一:灾难性遗忘(Catastrophic Forgetting)

  • 现象:加载权重后,头几个 Epoch 的 Loss 突然爆炸,PSNR 断崖式下跌。
  • 原因:学习率设置过大,巨大的梯度更新瞬间摧毁了模型在预训练时好不容易学到的底层通用特征。
  • 解法:微调学习率必须远小于预训练学习率。

坑点二:无脑加载 Optimizer 状态

  • 现象:Loss 降不下去,收敛方向极其诡异。
  • 原因:像 Adam 这样的优化器内部会保存历史动量(Momentum)和方差信息。如果换了全新的数据集还加载旧的优化器状态,历史动量会把模型往错误的方向“带偏”。
  • 解法:迁移数据集微调时,只加载 model.state_dict() ,不要加载 optimizer.state_dict(),让优化器以新的小学习率重新初始化。

坑点三:Patch Size(裁剪大小)缩水导致感受野丢失

  • 疑惑:微调时,输入图片的裁剪大小可以随便改吗?
  • 正解强烈建议保持一致或适度调大,绝不能变小。 模型在预训练时已经习惯了在固定大小的窗口(如 32x32)内寻找纹理关联。如果微调时切成了 16x16,视野变小,全局特征聚合能力就会失效。

坑点四:不敢调整 Batch Size

  • 疑惑:换了 4K 数据集后显存吃紧,微调可以减小 Batch Size 吗?
  • 正解完全可以。 传统的分类网络(如 ResNet)因为有 Batch Normalization (BN) 层,强行缩小 Batch 会导致统计量崩塌。但 SR 网络(如我用的 ESPCN)通常没有 BN 层,Batch Size 缩小到 8 甚至 4 并不影响内部特征分布,只需把学习率同步调小一点防止梯度震荡即可。

收获与总结

这次实战让我深刻体会到,深度学习的微调绝不是简单的 torch.load。它要求我们不仅要洞悉底层网络结构(如 BN 层的有无、感受野的大小),还要对数据流水线(如预处理通道、归一化方式)有近乎苛刻的像素级把控。

SurgiSR4K 这样的高质量医疗 4K 数据集非常难得,希望这套“无损微调”的方法论也能帮到正在做医疗影像算法的同行们。

以下是为你准备的文章附录部分。你可以直接将这部分追加到刚才那篇博客的末尾。

我已经将我们在讨论中得出的最佳实践(如 scale=2、学习率 5e-6、提取 Y 通道等)全部更新到了这份最终版本的代码中。


附录:finetune.py 完整源码

这里放出我用于微调的完整脚本,代码中包含了严格的随机种子控制、Y 通道提取、联合数据增强以及最佳模型保存逻辑。大家可以基于此代码直接跑在自己的 SurgiSR4K 数据集上。

TF.rotate 的插值风险 (强烈建议优化)

# 随机旋转 0, 90, 180, 270 度
angle = random.choice([0, 90, 180, 270])
if angle != 0:
lr_img = TF.rotate(lr_img, angle)
hr_img = TF.rotate(hr_img, angle)

你目前使用的是 torchvision.transforms.functional.rotate 来进行 90、180、270 度的旋转。这是一个通用旋转函数。 在超分任务中,强烈建议不要对图像使用任何涉及插值计算的旋转函数,即使你指定的是 90 度的倍数。通用旋转底层可能会调用仿射变换矩阵,导致原始像素值被重新采样,从而引发极微小的模糊,破坏高频纹理。

# ======== # 随机旋转 90, 180, 270 度 (使用无损转置) # =============
rot_choice = random.choice([None, Image.ROTATE_90, 
Image.ROTATE_180, Image.ROTATE_270]) if rot_choice is not None: 
lr_img = lr_img.transpose(rot_choice) 
hr_img = hr_img.transpose(rot_choice)

优化方案:使用纯粹的矩阵转置来代替旋转。 PIL 提供了 transpose 方法,对于 90 度的整数倍旋转,它只是在内存中重排像素位置,完全不涉及任何数学插值,是 100% 无损且对齐的: Python

import argparse
import os
import random
import re
import numpy as np
from pathlib import Path
from PIL import Image
import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

# 导入你原有的模型和工具函数
from models import ESPCN_RDB
from utils import AverageMeter, calc_psnr


# ==========================================
# 1. 保证可复现性的种子设置
# ==========================================
def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    cudnn.deterministic = True
    cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)
    
# [新增] DataLoader 多线程子进程的随机种子初始化函数
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# ==========================================
# 2. 适配 organized 目录和 Y 通道的 Dataset
# ==========================================
class SurgiSR4KDataset(Dataset):
    def __init__(self, data_root, split='train', lr_res="1920x1080p", hr_res="3840x2160p",
                 scale_factor=2, lr_patch_size=32, is_train=True):
        self.data_root = Path(data_root)
        self.split = split
        self.lr_res = lr_res
        self.hr_res = hr_res
        self.scale_factor = scale_factor
        self.lr_patch_size = lr_patch_size
        self.is_train = is_train

        self.lr_dir = self.data_root / self.split / self.lr_res
        self.hr_dir = self.data_root / self.split / self.hr_res

        if not self.lr_dir.exists() or not self.hr_dir.exists():
            raise FileNotFoundError(f"找不到数据目录: {self.lr_dir}{self.hr_dir}")

        self.lr_image_paths = sorted(list(self.lr_dir.rglob("*.png")))
        if len(self.lr_image_paths) == 0:
            raise ValueError(f"在 {self.lr_dir} 中未找到任何图像!")

    def __len__(self):
        return len(self.lr_image_paths)

    def rgb_to_y(self, img):
        ycbcr = img.convert('YCbCr')
        y, cb, cr = ycbcr.split()
        return y

    def __getitem__(self, idx):
        lr_path = self.lr_image_paths[idx]

        rel_path = lr_path.relative_to(self.lr_dir)
        hr_rel_path_str = str(rel_path).replace(self.lr_res, self.hr_res)
        hr_path = self.hr_dir / hr_rel_path_str

        lr_img = Image.open(lr_path).convert("RGB")
        hr_img = Image.open(hr_path).convert("RGB")

        if self.is_train:
            lr_w, lr_h = lr_img.size
            lr_x = random.randint(0, lr_w - self.lr_patch_size)
            lr_y = random.randint(0, lr_h - self.lr_patch_size)

            hr_x = lr_x * self.scale_factor
            hr_y = lr_y * self.scale_factor
            hr_patch_size = self.lr_patch_size * self.scale_factor

            lr_img = lr_img.crop((lr_x, lr_y, lr_x + self.lr_patch_size, lr_y + self.lr_patch_size))
            hr_img = hr_img.crop((hr_x, hr_y, hr_x + hr_patch_size, hr_y + hr_patch_size))

            if random.random() < 0.5:
                lr_img = TF.hflip(lr_img)
                hr_img = TF.hflip(hr_img)
            if random.random() < 0.5:
                lr_img = TF.vflip(lr_img)
                hr_img = TF.vflip(hr_img)

            # 随机旋转 0, 90, 180, 270 度 (使用 PIL 的无损 transpose 矩阵转置)
            rot_choice = random.choice([0, 90, 180, 270])

            # 兼容新老版本 Pillow 的调用方式
            # Pillow 10+ 推荐使用 Image.Transpose.ROTATE_90,旧版本使用 Image.ROTATE_90
            if rot_choice == 90:
                lr_img = lr_img.transpose(Image.ROTATE_90 if hasattr(Image, 'ROTATE_90') else Image.Transpose.ROTATE_90)
                hr_img = hr_img.transpose(Image.ROTATE_90 if hasattr(Image, 'ROTATE_90') else Image.Transpose.ROTATE_90)
            elif rot_choice == 180:
                lr_img = lr_img.transpose(
                    Image.ROTATE_180 if hasattr(Image, 'ROTATE_180') else Image.Transpose.ROTATE_180)
                hr_img = hr_img.transpose(
                    Image.ROTATE_180 if hasattr(Image, 'ROTATE_180') else Image.Transpose.ROTATE_180)
            elif rot_choice == 270:
                lr_img = lr_img.transpose(
                    Image.ROTATE_270 if hasattr(Image, 'ROTATE_270') else Image.Transpose.ROTATE_270)
                hr_img = hr_img.transpose(
                    Image.ROTATE_270 if hasattr(Image, 'ROTATE_270') else Image.Transpose.ROTATE_270)


        lr_y = self.rgb_to_y(lr_img)
        hr_y = self.rgb_to_y(hr_img)

        lr_tensor = TF.to_tensor(lr_y)
        hr_tensor = TF.to_tensor(hr_y)

        return lr_tensor, hr_tensor


# ==========================================
# 3. 微调主循环
# ==========================================
def main():
    parser = argparse.ArgumentParser(description="SurgiSR4K Fine-tuning Script")
    parser.add_argument('--data_root', type=str, required=True, help='Organized 数据集的根目录路径')
    parser.add_argument('--pretrained_weights', type=str, required=True, help='预训练模型的 .pth 文件路径')
    parser.add_argument('--scale', type=int, default=2, help='超分辨率放大倍数 (SurgiSR4K 从 1920 到 3840 是 2 倍)')
    parser.add_argument('--lr_patch_size', type=int, default=128, help='输入给网络的 LR 裁剪大小 (HR对应为 128*2=256)')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch Size')
    parser.add_argument('--num_epochs', type=int, default=200, help='微调轮数')
    parser.add_argument('--lr', type=float, default=1e-6, help='微调初始学习率')
    parser.add_argument('--num_workers', type=int, default=4, help='DataLoader 线程数 (Windows 建议设为 0)')
    parser.add_argument('--seed', type=int, default=42, help='随机种子')

    # [新增] 学习率衰减和早停机制的相关参数
    parser.add_argument('--lr_patience', type=int, default=10, help='连续多少轮PSNR不提升,则学习率减半')
    parser.add_argument('--patience', type=int, default=25, help='连续多少轮PSNR不提升,则触发早停机制')
    parser.add_argument('--min_lr', type=float, default=1e-8, help='学习率衰减的下限')

    args = parser.parse_args()

    pretrained_abspath = os.path.abspath(args.pretrained_weights)
    outputs_dir = os.path.dirname(pretrained_abspath)

    pretrained_basename = os.path.basename(args.pretrained_weights)
    pretrained_name, ext = os.path.splitext(pretrained_basename)
    best_save_filename = f"{pretrained_name}_finetuned{ext}"
    latest_save_filename = f"{pretrained_name}_latest{ext}"

    # 利用正则表达式解析模型参数
    parsed_kwargs = {}
    match_gc = re.search(r'growth_channels_(\d+)', pretrained_name)
    if match_gc:
        parsed_kwargs['growth_channels'] = int(match_gc.group(1))

    match_rdb = re.search(r'RDB_(\d+)', pretrained_name)
    if match_rdb:
        parsed_kwargs['rdb_layers'] = int(match_rdb.group(1))

    match_attn = re.search(r'Attn_(pixel|weakened_pixel|none)', pretrained_name)
    if match_attn:
        parsed_kwargs['attention_type'] = match_attn.group(1)

    activation_types = ['ReLU', 'LeakyReLU', 'PReLU', 'Tanh', 'Sigmoid', 'GELU']
    for act in activation_types:
        if f"_{act}_" in pretrained_name or pretrained_name.endswith(f"_{act}"):
            parsed_kwargs['activation'] = act
            break

    print(f"==> 从文件名 [{pretrained_basename}] 中解析到以下超参数:")
    for k, v in parsed_kwargs.items():
        print(f"    - {k}: {v}")

    set_random_seed(args.seed)

    os.makedirs(outputs_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=os.path.join(outputs_dir, 'logs_finetune'))
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    print(f"==> 初始化模型 (Scale: x{args.scale}) ...")
    model = ESPCN_RDB(scale_factor=args.scale, **parsed_kwargs).to(device)

    print(f"==> 正在加载预训练权重: {args.pretrained_weights}")
    checkpoint = torch.load(args.pretrained_weights, map_location=device, weights_only=False)

    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    clean_state_dict = {}
    for k, v in state_dict.items():
        if "total_ops" in k or "total_params" in k:
            continue
        clean_state_dict[k] = v

    model.load_state_dict(clean_state_dict, strict=True)
    print(f"==> 预训练权重加载成功!微调后的模型将保存在: {outputs_dir}")

    print("==> 构建数据集...")
    train_dataset = SurgiSR4KDataset(
        data_root=args.data_root, split='train',
        lr_res="1920x1080p", hr_res="3840x2160p",
        scale_factor=args.scale, lr_patch_size=args.lr_patch_size, is_train=True
    )
    val_dataset = SurgiSR4KDataset(
        data_root=args.data_root, split='val',
        lr_res="1920x1080p", hr_res="3840x2160p",
        scale_factor=args.scale, is_train=False
    )

    # [新增] 为 DataLoader 创建一个固定种子的生成器
    g = torch.Generator()
    g.manual_seed(args.seed)

    # [修改] 把 worker_init_fn 和 generator 传给 DataLoader
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=True,
                              worker_init_fn=seed_worker, generator=g)

    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False,
                            num_workers=args.num_workers, pin_memory=True,
                            worker_init_fn=seed_worker, generator=g)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    criterion_mse = nn.MSELoss().to(device)

    # [新增] 学习率自动衰减调度器 (基于验证集指标)
    # mode='max' 表示当监控的指标(PSNR)不再增大时进行衰减
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=args.lr_patience,
        min_lr=args.min_lr, verbose=True
    )

    best_psnr = 0.0
    best_epoch = 0
    early_stop_counter = 0  # [新增] 早停计数器

    print("==> 开始微调...")
    for epoch in range(args.num_epochs):
        model.train()
        epoch_losses = AverageMeter()

        # 获取当前学习率,仅用于展示
        current_lr = optimizer.param_groups[0]['lr']

        with tqdm(total=len(train_loader.dataset)) as t:
            t.set_description(f'Epoch: {epoch}/{args.num_epochs - 1} [LR: {current_lr:.2e}]')
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                preds = model(inputs)
                loss = criterion_mse(preds, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_losses.update(loss.item(), len(inputs))
                t.set_postfix(loss=f'{epoch_losses.avg:.6f}')
                t.update(len(inputs))

        model.eval()
        eval_psnr = AverageMeter()
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                preds = model(inputs).clamp(0.0, 1.0)
                psnr = calc_psnr(preds, labels)
                eval_psnr.update(psnr.item(), len(inputs))

        print(f'Eval PSNR: {eval_psnr.avg:.2f}dB')
        writer.add_scalar('Loss/train', epoch_losses.avg, epoch)
        writer.add_scalar('PSNR/val', eval_psnr.avg, epoch)
        writer.add_scalar('LR', current_lr, epoch)

        # [新增] 记录最新状态,方便中断后恢复
        save_state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'best_psnr': max(best_psnr, eval_psnr.avg)
        }
        torch.save(save_state, os.path.join(outputs_dir, latest_save_filename))

        # [新增] 步进调度器,传入验证集平均 PSNR
        scheduler.step(eval_psnr.avg)

        # [新增] 早停逻辑判断
        if eval_psnr.avg > best_psnr:
            best_psnr = eval_psnr.avg
            best_epoch = epoch
            early_stop_counter = 0  # 如果有提升,清空早停计数器
            torch.save(save_state, os.path.join(outputs_dir, best_save_filename))
            print(f"!!! 找到更好的模型,已保存 {best_save_filename} (PSNR: {best_psnr:.2f}dB)")
        else:
            early_stop_counter += 1
            print(f"--- 性能未提升,早停计数: {early_stop_counter}/{args.patience}")

            # 触发早停
            if early_stop_counter >= args.patience:
                print(f"\n!!! 连续 {args.patience} 轮验证集 PSNR 未提升,触发早停机制。训练自动结束。")
                break

    print(f'\n==> 微调彻底完成!Best Epoch: {best_epoch}, 最终最佳 PSNR: {best_psnr:.2f}dB')


if __name__ == '__main__':
    main()