Kaggle 经典比赛 HuBMAP - Hacking the Kidney 高分方案解析

198 阅读11分钟

本篇文章为比赛第一名方案解析,该比赛涉及到CV-图像分割-医疗方向。

介绍:本次竞赛的目标是实施成功且强大的肾小球FTU检测器。参赛者面临的挑战是检测不同组织制备管道中的功能性组织单位(FTU)

网址:www.kaggle.com/competition…

方法概述

  • 单一模型 Unet se_resnext101_32x4d(4 折)
  • 一些来自先前分割竞赛的技术(主要来自云竞赛)
  • 平衡瓦片采样用于训练(编辑:遮挡区域平衡)
  • 伪标签用于公开测试数据和外部数据
  • 避免边缘效应的小技巧
  • 我的流水线是在旧数据集(即数据更新之前)开发的。

一、数据准备

制作 1024x104 瓦片+偏移的 1024x1024 瓦片(我通过(512,512)偏移了瓦片)

二、验证

我选择了验证数据,以确保相同的病人编号在同一个组别中

val_patient_numbers_list = [
    [63921], # fold0
    [68250], # fold1
    [65631], # fold2
    [67177], # fold3
 ]

以下为原始数据集: image.png

注意:这里是对数据进行初步处理。根据上图我们可以得知,原始数据非常的杂乱,我们先将杂乱的数据大致分组,得到初步整理的数据。

三、平衡瓦片采样用于训练

首先,我根据掩码区域对瓦片数据进行分类(掩码瓦片的分类数=4)。然后,我采用以下平衡采样程序。

n_sample = trn_df['is_masked'].value_counts().min()
trn_df_0 = trn_df[trn_df['is_masked']==False].sample(n_sample, replace=True)
trn_df_1 = trn_df[trn_df['is_masked']==True].sample(n_sample, replace=True)
n_bin = int(trn_df_1['binned'].value_counts().mean())
trn_df_list = []
for bin_size in trn_df_1['binned'].unique():
    trn_df_list.append(trn_df_1[trn_df_1['binned']==bin_size].sample(n_bin, replace=True))
trn_df_1 = pd.concat(trn_df_list, axis=0)
trn_df_balanced = pd.concat([trn_df_1, trn_df_0], axis=0).reset_index(drop=True)

这段代码旨在对数据集进行双重平衡处理:首先平衡类别分布,然后在其中一个类别内平衡分箱分布。以下是逐步解析:

  1. 计算最小类别样本数
    n_sample = trn_df['is_masked'].value_counts().min()

    • 获取is_masked列中样本数较少的类别的数量(True或False)。
  2. 平衡类别样本

    python
    trn_df_0 = trn_df[trn_df['is_masked']==False].sample(n_sample, replace=True)
    trn_df_1 = trn_df[trn_df['is_masked']==True].sample(n_sample, replace=True)
    
    • 对两个类别分别进行重采样,使每个类别的样本数均为n_sample(允许重复抽样)。
  3. 计算分箱平均样本数
    n_bin = int(trn_df_1['binned'].value_counts().mean())

    • 在True类别(trn_df_1)中,计算各分箱样本数的平均值,用于后续平衡。
  4. 平衡分箱分布

    python
    trn_df_list = []
    for bin_size in trn_df_1['binned'].unique():
        trn_df_list.append(trn_df_1[trn_df_1['binned']==bin_size].sample(n_bin, replace=True))
    trn_df_1 = pd.concat(trn_df_list, axis=0)
    
    • 对每个分箱进行重采样,使每个分箱的样本数等于平均值n_bin,确保分箱间分布均匀。
  5. 合并平衡后的数据
    trn_df_balanced = pd.concat([trn_df_1, trn_df_0], axis=0).reset_index(drop=True)

    • 合并处理后的True类别和已平衡的False类别,形成最终平衡数据集。

目的

  • 类别平衡:解决类别不平衡问题,防止模型偏向多数类。
  • 分箱平衡:在True类别内部,确保各分箱样本数均匀,避免模型过拟合或欠拟合特定分箱。

注意事项

  • 使用replace=True可能导致重复样本,但能有效平衡数据。
  • 最终True类别的总样本数保持n_sample(因分箱数 × 平均样本数 = 原总样本数),维持类别平衡。

示例
假设原始数据中True类有100个样本,分4个箱(30,30,20,20),则:

  • n_bin = 25,处理后每个分箱25个样本,总计100个。
  • 合并False类的100个样本,最终数据集共200个样本,类别和分箱均平衡。

此方法适用于需同时处理类别和特征分布不平衡的场景,提升模型鲁棒性。

四、模型

U-Net SeResNext101 + CBAM + 深度列 + 深度监督
在我的情况下,更大的模型给出了更好的 CV 和 LB。
我已将 1024x1024 调整为 320x320 作为输入瓦片。
这里是一段代码片段:

class CenterBlock(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv = conv3x3(in_channel, out_channel).apply(init_weight)

    def forward(self, inputs):
        x = self.conv(inputs)
        return x

class DecodeBlock(nn.Module):
    def __init__(self, in_channel, out_channel, upsample):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_channel).apply(init_weight)
        self.upsample = nn.Sequential()
        if upsample:
            self.upsample.add_module('upsample',nn.Upsample(scale_factor=2, mode='nearest'))
        self.conv3x3_1 = conv3x3(in_channel, in_channel).apply(init_weight)
        self.bn2 = nn.BatchNorm2d(in_channel).apply(init_weight)
        self.conv3x3_2 = conv3x3(in_channel, out_channel).apply(init_weight)
        self.cbam = CBAM(out_channel, reduction=16)
        self.conv1x1   = conv1x1(in_channel, out_channel).apply(init_weight)

    def forward(self, inputs):
        x  = F.relu(self.bn1(inputs))
        x  = self.upsample(x)
        x  = self.conv3x3_1(x)
        x  = self.conv3x3_2(F.relu(self.bn2(x)))
        x  = self.cbam(x)
        x += self.conv1x1(self.upsample(inputs)) #shortcut
        return x

class UNET_SERESNEXT101(nn.Module):
    def __init__(self, resolution, deepsupervision, clfhead, load_weights=True):
        super().__init__()
        h,w = resolution
        self.deepsupervision = deepsupervision
        self.clfhead = clfhead

        #encoder
        model_name = 'se_resnext101_32x4d'
        seresnext101 = pretrainedmodels.__dict__[model_name](pretrained=None)
        if load_weights:
            seresnext101.load_state_dict(torch.load(f'{model_name}.pth'))

        self.encoder0 = nn.Sequential(
            seresnext101.layer0.conv1, #(*,3,h,w)->(*,64,h/2,w/2)
            seresnext101.layer0.bn1,
            seresnext101.layer0.relu1,
        )
        self.encoder1 = nn.Sequential(
            seresnext101.layer0.pool, #->(*,64,h/4,w/4)
            seresnext101.layer1 #->(*,256,h/4,w/4)
        )
        self.encoder2 = seresnext101.layer2 #->(*,512,h/8,w/8)
        self.encoder3 = seresnext101.layer3 #->(*,1024,h/16,w/16)
        self.encoder4 = seresnext101.layer4 #->(*,2048,h/32,w/32)

        #center
        self.center  = CenterBlock(2048,512) #->(*,512,h/32,w/32)

        #decoder
        self.decoder4 = DecodeBlock(512+2048,64,upsample=True) #->(*,64,h/16,w/16)
        self.decoder3 = DecodeBlock(64+1024,64, upsample=True) #->(*,64,h/8,w/8)
        self.decoder2 = DecodeBlock(64+512,64,  upsample=True) #->(*,64,h/4,w/4) 
        self.decoder1 = DecodeBlock(64+256,64,  upsample=True) #->(*,64,h/2,w/2) 
        self.decoder0 = DecodeBlock(64,64, upsample=True) #->(*,64,h,w) 

        #upsample
        self.upsample4 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)
        self.upsample3 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
        self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.upsample1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        #deep supervision
        self.deep4 = conv1x1(64,1).apply(init_weight)
        self.deep3 = conv1x1(64,1).apply(init_weight)
        self.deep2 = conv1x1(64,1).apply(init_weight)
        self.deep1 = conv1x1(64,1).apply(init_weight)

        #final conv
        self.final_conv = nn.Sequential(
            conv3x3(320,64).apply(init_weight),
            nn.ELU(True),
            conv1x1(64,1).apply(init_weight)
        )

        #clf head
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.clf = nn.Sequential(
            nn.BatchNorm1d(2048).apply(init_weight),
            nn.Linear(2048,512).apply(init_weight),
            nn.ELU(True),
            nn.BatchNorm1d(512).apply(init_weight),
            nn.Linear(512,1).apply(init_weight)
        )

    def forward(self, inputs):
        #encoder
        x0 = self.encoder0(inputs) #->(*,64,h/2,w/2)
        x1 = self.encoder1(x0) #->(*,256,h/4,w/4)
        x2 = self.encoder2(x1) #->(*,512,h/8,w/8)
        x3 = self.encoder3(x2) #->(*,1024,h/16,w/16)
        x4 = self.encoder4(x3) #->(*,2048,h/32,w/32)

        #center
        y5 = self.center(x4) #->(*,320,h/32,w/32)

        #decoder
        y4 = self.decoder4(torch.cat([x4,y5], dim=1)) #->(*,64,h/16,w/16)
        y3 = self.decoder3(torch.cat([x3,y4], dim=1)) #->(*,64,h/8,w/8)
        y2 = self.decoder2(torch.cat([x2,y3], dim=1)) #->(*,64,h/4,w/4)
        y1 = self.decoder1(torch.cat([x1,y2], dim=1)) #->(*,64,h/2,w/2) 
        y0 = self.decoder0(y1) #->(*,64,h,w)

        #hypercolumns
        y4 = self.upsample4(y4) #->(*,64,h,w)
        y3 = self.upsample3(y3) #->(*,64,h,w)
        y2 = self.upsample2(y2) #->(*,64,h,w)
        y1 = self.upsample1(y1) #->(*,64,h,w)
        hypercol = torch.cat([y0,y1,y2,y3,y4], dim=1)

        #final conv
        logits = self.final_conv(hypercol) #->(*,1,h,w)

        #clf head
        logits_clf = self.clf(self.avgpool(x4).squeeze(-1).squeeze(-1)) #->(*,1)

        if self.clfhead:
            if self.deepsupervision:
                s4 = self.deep4(y4)
                s3 = self.deep3(y3)
                s2 = self.deep2(y2)
                s1 = self.deep1(y1)
                logits_deeps = [s4,s3,s2,s1]
                return logits, logits_deeps, logits_clf
            else:
                return logits, logits_clf
        else:
            if self.deepsupervision:
                s4 = self.deep4(y4)
                s3 = self.deep3(y3)
                s2 = self.deep2(y2)
                s1 = self.deep1(y1)
                logits_deeps = [s4,s3,s2,s1]
                return logits, logits_deeps
            else:
                return logits

以下是代码的逐层解析,主要针对U-Net结构、SE-ResNeXt101主干网络、CBAM注意力机制、深度监督和分类头的实现:


一、模块分解
1. CenterBlock
python
class CenterBlock(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv = conv3x3(in_channel, out_channel).apply(init_weight)
    
    def forward(self, inputs):
        return self.conv(inputs)
  • 功能:编码器与解码器间的过渡层
  • 结构:单个3x3卷积(通道数2048→512)
  • 作用:压缩高层特征,为解码器提供初始输入
2. DecodeBlock
python
class DecodeBlock(nn.Module):
    def __init__(self, in_channel, out_channel, upsample):
        # ...(详见代码)
    def forward(self, inputs):
        # ...(包含残差连接与CBAM)
  • 核心组件

    • 上采样层:Nearest插值(默认关闭,由参数upsample控制)
    • 双3x3卷积:特征提取
    • CBAM注意力:通道+空间注意力增强特征
    • 1x1卷积残差:跳跃连接适配维度
  • 特征流动

    输入 → BN+ReLU → 上采样 → conv3x3 → BN+ReLU → conv3x3 → CBAM 
    └→ 1x1卷积上采样 → 残差相加 ←
    
3. UNET_SERESNEXT101
python
class UNET_SERESNEXT101(nn.Module):
    def __init__(self, resolution, deepsupervision, clfhead, load_weights=True):
        # 初始化编码器、解码器、中心层等
    def forward(self, inputs):
        # 完整前向传播流程

二、编码器结构解析(SE-ResNeXt101)
特征金字塔构建
层级模块组成输出尺寸作用
encoder0conv1+bn1+relu1(64, H/2, W/2)初始下采样
encoder1pool + layer1(256, H/4, W/4)第一阶段特征提取
encoder2layer2(512, H/8, W/8)第二阶段特征提取
encoder3layer3(1024, H/16, W/16)第三阶段特征提取
encoder4layer4(2048, H/32, W/32)最高层语义特征

三、解码器与特征融合
1. 解码流程
解码层输入拼接输出尺寸上采样倍数
decoder4cat([x4, y5])(64, H/16, W/16)2x
decoder3cat([x3, y4])(64, H/8, W/8)2x
decoder2cat([x2, y3])(64, H/4, W/4)2x
decoder1cat([x1, y2])(64, H/2, W/2)2x
decoder0y1单独输入(64, H, W)2x
2. Hypercolumns技术
python
y4 = upsample4(y4) → (64,H,W)  
y3 = upsample3(y3) → (64,H,W)
y2 = upsample2(y2) → (64,H,W)
y1 = upsample1(y1) → (64,H,W)
hypercol = cat([y0,y1,y2,y3,y4]) → (320,H,W)
  • 作用:聚合多尺度特征,保留不同层级的语义和细节信息

四、多任务输出设计
1. 分割主输出
python
self.final_conv = nn.Sequential(
    conv3x3(320,64),  # 融合Hypercolumns
    nn.ELU(),
    conv1x1(64,1)     # 生成分割logits
)
2. 分类辅助头
python
self.clf = nn.Sequential(
    nn.BatchNorm1d(2048),
    nn.Linear(2048,512),
    nn.ELU(),
    nn.BatchNorm1d(512),
    nn.Linear(512,1)  # 生成分类logits
)
  • 输入:编码器最高层特征的全局池化(x4
  • 目的:多任务学习增强特征表达
3. 深度监督
python
self.deep4 = conv1x1(64,1)  # 监督decoder4输出
self.deep3 = conv1x1(64,1)  # 监督decoder3输出
self.deep2 = conv1x1(64,1)  # 监督decoder2输出
self.deep1 = conv1x1(64,1)  # 监督decoder1输出
  • 作用:通过中间层监督加速收敛,提升梯度传播效果

五、前向传播流程图解
输入图像
    │
    ▼
[encoder0][encoder1][encoder2][encoder3][encoder4]
                                    │          │
                                    ▼          ▼
                                 [center] → decoder4 → deep4
                                           │    │
                                           ▼    ▼
                                        decoder3 → deep3
                                           │    │
                                           ▼    ▼
                                        decoder2 → deep2
                                           │    │
                                           ▼    ▼
                                        decoder1 → deep1
                                           │    
                                           ▼    
                                        decoder0 → hypercolumns → final_conv → 分割输出
                                           ▲
                                           │
分类输出 ← clf_head ← avgpool(x4)

六、关键技术亮点
  1. 主干网络选择

    • SE-ResNeXt101提供强特征提取能力,SE模块增强通道注意力
  2. CBAM注意力机制

    • 在解码器中引入通道+空间双维度注意力,突出重要特征
  3. 深度监督策略

    • 对中间解码层输出监督,缓解梯度消失问题
  4. 多尺度特征融合

    • Hypercolumns技术整合不同层级特征,提升细节恢复能力
  5. 多任务学习框架

    • 分割与分类联合训练,共享编码器特征,提升模型泛化性

七、潜在改进方向
  1. 替换主干网络

    • 尝试EfficientNetV2或Swin Transformer等新型主干
  2. 改进注意力机制

    • 使用Triplet Attention或SimAM替代CBAM
  3. 动态上采样

    • 将双线性插值替换为可学习的转置卷积
  4. 损失函数优化

    • 对分割和分类任务采用加权损失或不确定性加权
  5. 实例归一化

    • 在解码器中尝试GroupNorm或InstanceNorm替代BatchNorm

五、损失

我使用了 bce loss + lovasz-hinge loss,在此基础上,我还使用了带有 bce loss + lovasz-hinge loss(仅针对非空掩码)的深度监督,乘以 0.1。对于分类头,我使用了 bce loss

六、外部数据和伪标签

我生成了(EDIT)训练数据、公开测试数据、hubmap-portal(portal.hubmapconsortium.org/search?enti…
; and dataset_a_dib (data.mendeley.com/datasets/k7…). 对于数据 d488c759a,我在我的最终模型中使用了 Carno Zhao 的伪标签(感谢 @carnozhao!),另一个则使用了我的模型的伪标签。我检查了我的私有分数,发现这些模型的性能并没有太大差异。但我的第一名模型是使用 Carno Zhao 的伪标签的那个。我认为提升我的分数的不是手工标注,而是间接的集成效应,因为我的模型是单一的,多样性应该会做出贡献。

七、推理技巧

我发现避免边缘效应可以持续提升 CV 和 LB。我的技巧是只使用瓦片的中心部分进行预测(我注意到@shujun717 的第三名解决方案使用了同样的想法)。我发现使用较小的部分进行预测可以更好地提升 CV 和 LB,但会耗费更多时间。所以我最终使用从原始 1024x1024 中提取的 512x512 中心部分进行预测。推理需要 x4 的时间,但使用分类头预测有助于节省推理时间。

训练代码:321Leo123/kaggle-hubmap: Kaggle HuBMAP 1st Place Solution