医疗内窥镜超分辨率:如何从零构建HyperKvasir数据集的超分Dataloader?

3 阅读8分钟

为什么学这个

在进行图像超分辨率(Super-Resolution, SR)算法加速策略(如 ESPCN、SRLUT)的研究时,我发现常规的自然图像数据集(如 DIV2K、Set5)很难真实反映医疗场景的复杂性。为了验证算法在真实临床环境中的有效性,我将目光转向了医学图像处理领域的经典数据集 —— HyperKvasir

目前学术界公开的胃肠道内窥镜数据集屈指可数,而发布于 2020 年的 HyperKvasir 绝对是其中的“巨无霸”与试金石。它的完整版包含了惊人的 110,079 张图像374 段临床视频

对于本次的超分任务,我并没有囫囵吞枣地下载所有文件,而是精准剥离了其中经过专业医生严格筛选的 10,662 张高质量有标签图像 (Labeled Images) 。这批数据被精细划分为 23 个真实的临床类别,包括解剖标志(如 Z 线、幽门)、病理发现(如息肉、溃疡)以及医疗操作记录。

然而,医学图像 SR 与自然图像 SR 有着天壤之别。内窥镜图像普遍存在无效黑边、UI 角标、强反光点以及复杂的物理退化(如镜头失焦、暗光噪声) 。如果直接套用传统的 Resize 降采样方案,不仅会导致模型学到错误的边缘特征,还极易引发子像素级的错位(Pixel Shift),导致 PSNR/SSIM 评估失效。

因此,我花时间从零重构了一个专为内窥镜超分设计的 PyTorch DataLoader 流水线。本文将复盘这一过程中的核心技术点。

核心内容与步骤

构建这个数据管道,我主要分为以下几个核心步骤:

1. 精准提取高质量 Ground Truth (HR)

HyperKvasir 数据集极其庞大(包含视频、掩膜和海量无标签数据)。为了最大化算力效率,我只剥离了 labeled-images 目录下的一万多张高质量分类图像作为 HR(高分辨率)基准真值,摒弃了带有严重帧间压缩伪影的视频帧。

2. 内窥镜特有的预处理:动态黑边裁切

内窥镜图像四周通常有圆形的黑色遮罩或绿色的仪器信息。我通过短边比例计算(如 min_dim * 0.85)进行中心裁剪,确保网络只学习纯净的肠胃道黏膜纹理,避免模型将算力浪费在“如何超分纯黑色”上。

3. 联合裁剪与几何变换 (Joint Transformations) —— 核心逻辑

在生成 HR-LR 图像对时,采用了**“先变换,后退化”**的逆向流水线:

  • 先在大图上随机切出 HR Patch。

  • 对 HR Patch 进行随机翻转和 90 度旋转。

  • 最后,将这个已经固定下来的 HR Patch 送入退化模型生成 LR Patch。

    这种“联合操作”从物理逻辑上 100% 保证了 HR 和 LR 在空间坐标系下的绝对像素级对齐。

4. 贴近临床的物理退化模型 (Degradation Model)

为了让 LR 图像更符合真实的内窥镜物理退化,我没有使用单一的双三次插值,而是串联了多重退化逻辑:

光学模糊 (模拟组织液/失焦) -> 双三次降采样 -> 泊松噪声 (模拟暗光光子散粒) -> JPEG 压缩 (模拟医院 PACS 系统)

遇到的问题与解决方法

问题一:如何确保没有任何黑边被喂给网络?

  • 表现:简单的固定比例中心裁剪可能无法适应所有尺寸的原图,容易残留黑边。
  • 解决:编写了一个基于阈值统计的全量扫描脚本 scan_entire_dataset。将图像转为张量后,利用 tensor < 0.05 统计每个 Patch 的极暗像素占比。如果占比超过 2%,则记录该 Batch 并使用 torchvision.utils.make_grid 可视化最差的样本。通过这种静默排查,我能精准调优裁剪系数,确保数据 100% 纯净。

问题二:多进程 DataLoader 导致的随机种子失效

  • 表现:开启 num_workers > 0 后,发现不同子进程可能继承相同的随机状态,导致产出完全一致的噪声和裁剪坐标。

  • 解决:除了在主程序锁定全局种子外,显式编写了 worker_init_fn 函数:

    Python

    def seed_worker(worker_id):
        worker_seed = torch.initial_seed() % 2 ** 32
        np.random.seed(worker_seed)
        random.seed(worker_seed)
    

    将其绑定到 DataLoader 中,彻底保证了实验的绝对可复现性,这对于后续对比不同超分加速策略的微小指标差异至关重要。

问题三:Windows 环境下的多进程报错

  • 表现:Windows 的 spawn 机制会导致主脚本重复导入,引发递归报错。
  • 解决:将跨平台逻辑封装,在 Windows 下适当降低 num_workers,并严格使用 if __name__ == '__main__': 保护训练主入口。

收获与总结

“数据决定了模型的上限,而网络结构只是在逼近这个上限。”

通过这次实践,我深刻体会到在超分辨率领域,**对齐(Alignment)与退化(Degradation)**是比网络层数更关键的命题。一个粗糙的下采样函数足以毁掉一个精巧的加速算法。

这是一个非常好的补充!在技术博客的末尾放上“开箱即用”的完整源码,能极大提升文章的干货属性和读者收藏率。

我已经为您将代码中的本地物理路径(D:...)替换为了通用的占位符,以完全规避您的隐私信息。您可以直接将以下内容作为 “附录” 复制到博客的末尾:


附录:完整的高级数据管道源码(支持全量黑边扫描与多进程安全)

以下是本文探讨的完整 PyTorch DatasetDataLoader 实现代码。您可以直接保存为 dataset.py 并在本地运行测试。代码内置了全局随机数种子锁定机制以及静默黑边扫描工具。

Python

import os
import glob
import cv2
import numpy as np
import random
import time
import torch
import platform
import matplotlib.pyplot as plt
import torchvision.utils as vutils
from torch.utils.data import Dataset, DataLoader

# ==========================================
# 1. 全局随机种子控制 (多进程安全)
# ==========================================
def set_random_seed(seed=42):
    """固定所有相关的随机数种子,确保实验绝对可复现"""
    print(f"🌱 正在锁定全局随机数种子: {seed}")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed) 
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def seed_worker(worker_id):
    """确保 DataLoader 的每个子进程拥有独立的衍生种子"""
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# ==========================================
# 2. 内窥镜专属 Dataset 定义
# ==========================================
class HyperKvasirAdvancedSRDataset(Dataset):
    def __init__(self, root_dir, scale_factor=4, hr_patch_size=192, phase='train'):
        super().__init__()
        self.scale_factor = scale_factor
        self.hr_patch_size = hr_patch_size
        self.phase = phase

        self.image_paths = glob.glob(os.path.join(root_dir, '**', '*.[jp][pn]g'), recursive=True)
        if len(self.image_paths) == 0:
            raise RuntimeError(f"未找到图像,请检查路径: {root_dir}")

    def __len__(self):
        return len(self.image_paths)

    def _rgb2y(self, img_rgb):
        """BT.601 RGB 转 Y 亮度通道"""
        img_f = img_rgb.astype(np.float32) / 255.0
        y = np.dot(img_f, [0.299, 0.587, 0.114])
        return np.expand_dims(y, axis=2)

    def _add_poisson_noise(self, img):
        """模拟暗光光子/泊松噪声"""
        img_f = img.astype(np.float32) / 255.0
        scale = random.uniform(50, 150)
        noisy = np.random.poisson(img_f * scale) / scale
        return np.clip(noisy * 255.0, 0, 255).astype(np.uint8)

    def _add_jpeg_compression(self, img):
        """模拟 PACS 系统 JPEG 压缩伪影"""
        quality = random.randint(60, 95)
        _, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality])
        return cv2.imdecode(encimg, 1)

    def __getitem__(self, idx):
        # 1. 读取并转为 RGB
        img_path = self.image_paths[idx]
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]

        # 2. 预处理:粗略切除黑边及 UI
        min_dim = min(h, w)
        crop_size = int(min_dim * 0.85) # 发现黑边可适当调小此系数,如 0.75
        start_x = (w - crop_size) // 2
        start_y = (h - crop_size) // 2
        img = img[start_y:start_y + crop_size, start_x:start_x + crop_size]
        h, w = img.shape[:2]

        # 3. 联合裁剪与几何变换 (提取纯净 HR)
        if self.phase == 'train':
            rx = random.randint(0, w - self.hr_patch_size)
            ry = random.randint(0, h - self.hr_patch_size)
            hr_img = img[ry:ry + self.hr_patch_size, rx:rx + self.hr_patch_size]

            if random.random() < 0.5: hr_img = cv2.flip(hr_img, 1)
            if random.random() < 0.5: hr_img = cv2.flip(hr_img, 0)
            rot = random.randint(0, 3)
            if rot != 0: hr_img = np.rot90(hr_img, rot)
        else:
            cx, cy = (w - self.hr_patch_size) // 2, (h - self.hr_patch_size) // 2
            hr_img = img[cy:cy + self.hr_patch_size, cx:cx + self.hr_patch_size]

        # 4. 物理退化流水线 (生成绝对对齐的 LR)
        lr_img = hr_img.copy()
        if self.phase == 'train':
            sigma = random.uniform(0.1, 1.5)
            lr_img = cv2.GaussianBlur(lr_img, (0, 0), sigmaX=sigma)

        lr_h, lr_w = self.hr_patch_size // self.scale_factor, self.hr_patch_size // self.scale_factor
        lr_img = cv2.resize(lr_img, (lr_w, lr_h), interpolation=cv2.INTER_CUBIC)

        if self.phase == 'train':
            if random.random() < 0.5:
                lr_img = self._add_poisson_noise(lr_img)
            else:
                noise = np.random.normal(0, random.uniform(0, 10), lr_img.shape)
                lr_img = np.clip(lr_img.astype(np.float32) + noise, 0, 255).astype(np.uint8)
            if random.random() < 0.3:
                lr_img = self._add_jpeg_compression(lr_img)

        # 5. 提取亮度通道并转为 Tensor
        hr_y = self._rgb2y(hr_img)
        lr_y = self._rgb2y(lr_img)

        hr_tensor = torch.from_numpy(hr_y.transpose((2, 0, 1))).float()
        lr_tensor = torch.from_numpy(lr_y.transpose((2, 0, 1))).float()

        return lr_tensor, hr_tensor

# ==========================================
# 3. 跨平台 DataLoader 封装
# ==========================================
def create_HyperKvasir_dataloader(data_dir, batch_size=16, is_training=True, seed=42):
    current_os = platform.system()
    # 自适应分配 worker,规避 Windows 平台的 spawn 递归问题
    num_workers = 0 if current_os == 'Windows' and not is_training else (4 if current_os == 'Windows' else 8)

    dataset = HyperKvasirAdvancedSRDataset(
        root_dir=data_dir, scale_factor=4, hr_patch_size=192, phase='train' if is_training else 'val'
    )

    g = torch.Generator()
    g.manual_seed(seed)

    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=is_training,
        num_workers=num_workers, pin_memory=True, drop_last=is_training,
        persistent_workers=True if num_workers > 0 else False,
        worker_init_fn=seed_worker, generator=g  
    )
    return dataloader

# ==========================================
# 4. 全量黑边扫描与可视化诊断工具
# ==========================================
def scan_entire_dataset(dataloader, black_threshold=0.05, alarm_ratio=0.02):
    print("\n🚀 开始全量黑边扫描...")
    start_time = time.time()
    total_batches, warning_count = len(dataloader), 0
    worst_ratio, worst_batch_tensor, worst_batch_idx = 0.0, None, -1
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for batch_idx, (lr_tensor, hr_tensor) in enumerate(dataloader):
        hr_tensor_device = hr_tensor.to(device, non_blocking=True)
        is_black = hr_tensor_device < black_threshold
        max_ratio_in_batch = is_black.float().mean(dim=(1, 2, 3)).max().item()
        
        if max_ratio_in_batch > alarm_ratio:
            warning_count += 1
            if max_ratio_in_batch > worst_ratio:
                worst_ratio = max_ratio_in_batch
                worst_batch_tensor = hr_tensor.clone()
                worst_batch_idx = batch_idx
                
        if (batch_idx + 1) % 50 == 0 or (batch_idx + 1) == total_batches:
            print(f"[{batch_idx + 1}/{total_batches}] 扫描中... 发现疑似黑边 Batch 数: {warning_count}")

    print(f"\n✅ 扫描完成!耗时: {time.time() - start_time:.2f} 秒")
    if warning_count > 0 and worst_batch_tensor is not None:
        print(f"🔥 正在可视化最严重的 Batch (最大黑边占比达: {worst_ratio * 100:.2f}%)")
        grid_img = vutils.make_grid(worst_batch_tensor, nrow=4, padding=2, normalize=True)
        plt.figure(figsize=(10, 10))
        plt.imshow(grid_img.numpy().transpose((1, 2, 0)), cmap='gray')
        plt.axis('off')
        plt.title(f"Worst Batch (Idx: {worst_batch_idx})")
        plt.show()
    else:
        print("🎉 完美!所有图像均已通过黑边检测!")

# ==========================================
# 测试入口点
# ==========================================
if __name__ == '__main__':
    SEED = 42
    set_random_seed(SEED)

    # 请替换为您本地实际的数据集路径
    data_dir = r"/path/to/your/hyperkvasir_labeled_images" 

    print("初始化扫描专用 DataLoader (关闭打乱和增强)...")
    scan_loader = create_HyperKvasir_dataloader(data_dir, batch_size=64, is_training=False, seed=SEED)
    
    # 运行物理边界扫描
    scan_entire_dataset(scan_loader)