前言
在医疗影像的超分辨率(Super-Resolution, SR)研究中,高质量的开源数据一直是稀缺资源。近期由香港中文大学、直觉外科公司等机构联合推出的 SurgiSR4K 数据集打破了这一僵局。作为首个针对机器人辅助微创手术的原生 4K (3840×2160) 视频数据集,它不仅提供了完美的亚像素级对齐,还涵盖了反光、遮挡、烟雾等极具挑战性的真实手术场景。
然而,面对 4K 分辨率的庞大张量,传统的 PyTorch 数据加载方式会瞬间导致显存溢出(OOM)。本文将带你从零构建一个专为超分辨率任务定制的、支持**联合随机裁剪(Joint Random Crop)**的高效 DataLoader,帮你平滑开启 4K 医疗影像的算法训练。
核心知识点
在编写超分辨率的 DataLoader 时,我们必须跨越以下三个技术门槛:
-
嵌套路径与标签匹配
SurgiSR4K 采用了按“分辨率”和“视频ID/器械复杂度”嵌套的文件夹结构(如
480x270p/vid_001_480x270p_1tool/)。我们需要在代码中动态解析相对路径,将输入图(LR)精准映射到 4K 标签图(HR)。 -
联合随机裁剪 (Joint Random Crop) —— 突破显存瓶颈
整张 4K 图片无法直接喂入网络。我们需要在 LR 图像上随机切取小块(如
64×64),并严格根据放大倍率(Scale Factor)在 HR 图像上切取对应的大块(如 8 倍超分下的512×512)。这不仅解决了显存问题,还是一种极佳的数据扩增手段。 -
联合几何变换 (Joint Geometric Transformations)
超分任务中的数据增强必须是“绑定”的。LR 做了怎样的水平/垂直翻转或旋转,HR 必须做一模一样的动作,否则会导致模型学出重影。
步骤与核心代码
为了实现上述功能,我们不使用易造成随机状态不一致的常规 transforms,而是引入 torchvision.transforms.functional 进行底层的精细控制。
以下是完整且可直接用于工程化训练的 dataset.py 核心代码:
Python
import random
from pathlib import Path
from PIL import Image
import torchvision.transforms.functional as TF
import torch
from torch.utils.data import Dataset, DataLoader
class SurgiSR4KDataset(Dataset):
"""
支持联合随机裁剪与同步数据增强的 SurgiSR4K Dataset
"""
def __init__(self, data_root, lr_res="480x270p", hr_res="3840x2160p",
scale_factor=8, lr_patch_size=64, is_train=True):
self.data_root = Path(data_root)
self.lr_res = lr_res
self.hr_res = hr_res
self.scale_factor = scale_factor
self.lr_patch_size = lr_patch_size
self.is_train = is_train
self.lr_dir = self.data_root / self.lr_res
self.hr_dir = self.data_root / self.hr_res
# 使用 rglob 递归搜索所有子文件夹中的 png 图片
self.lr_image_paths = sorted(list(self.lr_dir.rglob("*.png")))
if not self.lr_image_paths:
raise ValueError(f"在 {self.lr_dir} 及其子文件夹中未找到图像!")
def __len__(self):
return len(self.lr_image_paths)
def __getitem__(self, idx):
# 1. 获取 LR 路径并推导 HR 绝对路径
lr_path = self.lr_image_paths[idx]
rel_path = lr_path.relative_to(self.lr_dir) # 提取包含子文件夹的相对路径
hr_rel_path_str = str(rel_path).replace(self.lr_res, self.hr_res)
hr_path = self.hr_dir / hr_rel_path_str
# 2. 读取图像
lr_img = Image.open(lr_path).convert("RGB")
hr_img = Image.open(hr_path).convert("RGB")
if self.is_train:
# =====================================
# 3. 联合随机裁剪 (Joint Random Crop)
# =====================================
lr_w, lr_h = lr_img.size
# 在 LR 图上随机生成不越界的左上角坐标
lr_x = random.randint(0, lr_w - self.lr_patch_size)
lr_y = random.randint(0, lr_h - self.lr_patch_size)
# 严格映射到 HR 图像的坐标和尺寸
hr_x = lr_x * self.scale_factor
hr_y = lr_y * self.scale_factor
hr_patch_size = self.lr_patch_size * self.scale_factor
# 同步执行裁剪
lr_img = lr_img.crop((lr_x, lr_y, lr_x + self.lr_patch_size, lr_y + self.lr_patch_size))
hr_img = hr_img.crop((hr_x, hr_y, hr_x + hr_patch_size, hr_y + hr_patch_size))
# =====================================
# 4. 联合数据增强 (Joint Augmentation)
# =====================================
# 随机水平翻转
if random.random() < 0.5:
lr_img = TF.hflip(lr_img)
hr_img = TF.hflip(hr_img)
# 随机垂直翻转
if random.random() < 0.5:
lr_img = TF.vflip(lr_img)
hr_img = TF.vflip(hr_img)
# 随机旋转 (0, 90, 180, 270 度)
angle = random.choice([0, 90, 180, 270])
if angle != 0:
lr_img = TF.rotate(lr_img, angle)
hr_img = TF.rotate(hr_img, angle)
# 5. 转换为 Tensor 并归一化到 [0, 1]
lr_tensor = TF.to_tensor(lr_img)
hr_tensor = TF.to_tensor(hr_img)
return {"lr": lr_tensor, "hr": hr_tensor}
# 辅助包装函数
def create_surgisr4k_dataloader(data_root, batch_size=4, num_workers=4):
dataset = SurgiSR4KDataset(
data_root=data_root,
lr_res="480x270p", hr_res="3840x2160p",
scale_factor=8, lr_patch_size=64, is_train=True
)
return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
注意事项与避坑指南
- 禁用破坏性增强:在图像分类中常用的
Color Jitter(色彩抖动)、Random Gaussian Blur(高斯模糊)等增强手段,严禁在 SR 的 HR 标签上使用。SR 模型的任务是学习精确的像素映射,改变标签颜色或模糊度会破坏 Ground Truth。 - 多进程读取陷阱:如果在 Windows 系统下调试代码遇到 DataLoader 报错,请将
num_workers临时设为0。在 Linux 服务器上进行正式训练时,建议将其设为4或8并配合pin_memory=True,以最大化 GPU 吞吐量。 - 路径转义问题:Windows 系统下填写
data_root路径时,务必在字符串前加r(如r"D:\datasets..."),防止\n或\t被错误转义。
总结
构建一个稳健的数据流水线是训练深度学习模型的第一步。针对 SurgiSR4K 这样的原生 4K 数据集,普通的 DataLoader 是无法胜任的。通过递归路径匹配、联合随机裁剪以及底层函数级的数据增强同步,我们成功解决了显存溢出与标签对齐的难题。
将上述代码保存到你的工程中,调整好 Batch Size,你的超分网络现在就可以畅快地吸入这些高质量的 4K 医疗影像数据了。
你的下一个超分辨率模型准备用哪种网络架构(SwinIR / HAT / RCAN)呢?欢迎在评论区交流!