本篇文章为比赛第一名方案解析,该比赛涉及到CV-图像分割-医疗方向。
介绍:本次竞赛的目标是实施成功且强大的肾小球FTU检测器。参赛者面临的挑战是检测不同组织制备管道中的功能性组织单位(FTU)
网址:www.kaggle.com/competition…
方法概述
- 单一模型 Unet se_resnext101_32x4d(4 折)
- 一些来自先前分割竞赛的技术(主要来自云竞赛)
- 平衡瓦片采样用于训练(编辑:遮挡区域平衡)
- 伪标签用于公开测试数据和外部数据
- 避免边缘效应的小技巧
- 我的流水线是在旧数据集(即数据更新之前)开发的。
一、数据准备
制作 1024x104 瓦片+偏移的 1024x1024 瓦片(我通过(512,512)偏移了瓦片)
二、验证
我选择了验证数据,以确保相同的病人编号在同一个组别中
val_patient_numbers_list = [
[63921], # fold0
[68250], # fold1
[65631], # fold2
[67177], # fold3
]
以下为原始数据集:
注意:这里是对数据进行初步处理。根据上图我们可以得知,原始数据非常的杂乱,我们先将杂乱的数据大致分组,得到初步整理的数据。
三、平衡瓦片采样用于训练
首先,我根据掩码区域对瓦片数据进行分类(掩码瓦片的分类数=4)。然后,我采用以下平衡采样程序。
n_sample = trn_df['is_masked'].value_counts().min()
trn_df_0 = trn_df[trn_df['is_masked']==False].sample(n_sample, replace=True)
trn_df_1 = trn_df[trn_df['is_masked']==True].sample(n_sample, replace=True)
n_bin = int(trn_df_1['binned'].value_counts().mean())
trn_df_list = []
for bin_size in trn_df_1['binned'].unique():
trn_df_list.append(trn_df_1[trn_df_1['binned']==bin_size].sample(n_bin, replace=True))
trn_df_1 = pd.concat(trn_df_list, axis=0)
trn_df_balanced = pd.concat([trn_df_1, trn_df_0], axis=0).reset_index(drop=True)
这段代码旨在对数据集进行双重平衡处理:首先平衡类别分布,然后在其中一个类别内平衡分箱分布。以下是逐步解析:
-
计算最小类别样本数
n_sample = trn_df['is_masked'].value_counts().min()- 获取
is_masked列中样本数较少的类别的数量(True或False)。
- 获取
-
平衡类别样本
python trn_df_0 = trn_df[trn_df['is_masked']==False].sample(n_sample, replace=True) trn_df_1 = trn_df[trn_df['is_masked']==True].sample(n_sample, replace=True)- 对两个类别分别进行重采样,使每个类别的样本数均为
n_sample(允许重复抽样)。
- 对两个类别分别进行重采样,使每个类别的样本数均为
-
计算分箱平均样本数
n_bin = int(trn_df_1['binned'].value_counts().mean())- 在True类别(
trn_df_1)中,计算各分箱样本数的平均值,用于后续平衡。
- 在True类别(
-
平衡分箱分布
python trn_df_list = [] for bin_size in trn_df_1['binned'].unique(): trn_df_list.append(trn_df_1[trn_df_1['binned']==bin_size].sample(n_bin, replace=True)) trn_df_1 = pd.concat(trn_df_list, axis=0)- 对每个分箱进行重采样,使每个分箱的样本数等于平均值
n_bin,确保分箱间分布均匀。
- 对每个分箱进行重采样,使每个分箱的样本数等于平均值
-
合并平衡后的数据
trn_df_balanced = pd.concat([trn_df_1, trn_df_0], axis=0).reset_index(drop=True)- 合并处理后的True类别和已平衡的False类别,形成最终平衡数据集。
目的:
- 类别平衡:解决类别不平衡问题,防止模型偏向多数类。
- 分箱平衡:在True类别内部,确保各分箱样本数均匀,避免模型过拟合或欠拟合特定分箱。
注意事项:
- 使用
replace=True可能导致重复样本,但能有效平衡数据。 - 最终True类别的总样本数保持
n_sample(因分箱数 × 平均样本数 = 原总样本数),维持类别平衡。
示例:
假设原始数据中True类有100个样本,分4个箱(30,30,20,20),则:
n_bin = 25,处理后每个分箱25个样本,总计100个。- 合并False类的100个样本,最终数据集共200个样本,类别和分箱均平衡。
此方法适用于需同时处理类别和特征分布不平衡的场景,提升模型鲁棒性。
四、模型
U-Net SeResNext101 + CBAM + 深度列 + 深度监督
在我的情况下,更大的模型给出了更好的 CV 和 LB。
我已将 1024x1024 调整为 320x320 作为输入瓦片。
这里是一段代码片段:
class CenterBlock(nn.Module):
def __init__(self, in_channel, out_channel):
super().__init__()
self.conv = conv3x3(in_channel, out_channel).apply(init_weight)
def forward(self, inputs):
x = self.conv(inputs)
return x
class DecodeBlock(nn.Module):
def __init__(self, in_channel, out_channel, upsample):
super().__init__()
self.bn1 = nn.BatchNorm2d(in_channel).apply(init_weight)
self.upsample = nn.Sequential()
if upsample:
self.upsample.add_module('upsample',nn.Upsample(scale_factor=2, mode='nearest'))
self.conv3x3_1 = conv3x3(in_channel, in_channel).apply(init_weight)
self.bn2 = nn.BatchNorm2d(in_channel).apply(init_weight)
self.conv3x3_2 = conv3x3(in_channel, out_channel).apply(init_weight)
self.cbam = CBAM(out_channel, reduction=16)
self.conv1x1 = conv1x1(in_channel, out_channel).apply(init_weight)
def forward(self, inputs):
x = F.relu(self.bn1(inputs))
x = self.upsample(x)
x = self.conv3x3_1(x)
x = self.conv3x3_2(F.relu(self.bn2(x)))
x = self.cbam(x)
x += self.conv1x1(self.upsample(inputs)) #shortcut
return x
class UNET_SERESNEXT101(nn.Module):
def __init__(self, resolution, deepsupervision, clfhead, load_weights=True):
super().__init__()
h,w = resolution
self.deepsupervision = deepsupervision
self.clfhead = clfhead
#encoder
model_name = 'se_resnext101_32x4d'
seresnext101 = pretrainedmodels.__dict__[model_name](pretrained=None)
if load_weights:
seresnext101.load_state_dict(torch.load(f'{model_name}.pth'))
self.encoder0 = nn.Sequential(
seresnext101.layer0.conv1, #(*,3,h,w)->(*,64,h/2,w/2)
seresnext101.layer0.bn1,
seresnext101.layer0.relu1,
)
self.encoder1 = nn.Sequential(
seresnext101.layer0.pool, #->(*,64,h/4,w/4)
seresnext101.layer1 #->(*,256,h/4,w/4)
)
self.encoder2 = seresnext101.layer2 #->(*,512,h/8,w/8)
self.encoder3 = seresnext101.layer3 #->(*,1024,h/16,w/16)
self.encoder4 = seresnext101.layer4 #->(*,2048,h/32,w/32)
#center
self.center = CenterBlock(2048,512) #->(*,512,h/32,w/32)
#decoder
self.decoder4 = DecodeBlock(512+2048,64,upsample=True) #->(*,64,h/16,w/16)
self.decoder3 = DecodeBlock(64+1024,64, upsample=True) #->(*,64,h/8,w/8)
self.decoder2 = DecodeBlock(64+512,64, upsample=True) #->(*,64,h/4,w/4)
self.decoder1 = DecodeBlock(64+256,64, upsample=True) #->(*,64,h/2,w/2)
self.decoder0 = DecodeBlock(64,64, upsample=True) #->(*,64,h,w)
#upsample
self.upsample4 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)
self.upsample3 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
self.upsample1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
#deep supervision
self.deep4 = conv1x1(64,1).apply(init_weight)
self.deep3 = conv1x1(64,1).apply(init_weight)
self.deep2 = conv1x1(64,1).apply(init_weight)
self.deep1 = conv1x1(64,1).apply(init_weight)
#final conv
self.final_conv = nn.Sequential(
conv3x3(320,64).apply(init_weight),
nn.ELU(True),
conv1x1(64,1).apply(init_weight)
)
#clf head
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.clf = nn.Sequential(
nn.BatchNorm1d(2048).apply(init_weight),
nn.Linear(2048,512).apply(init_weight),
nn.ELU(True),
nn.BatchNorm1d(512).apply(init_weight),
nn.Linear(512,1).apply(init_weight)
)
def forward(self, inputs):
#encoder
x0 = self.encoder0(inputs) #->(*,64,h/2,w/2)
x1 = self.encoder1(x0) #->(*,256,h/4,w/4)
x2 = self.encoder2(x1) #->(*,512,h/8,w/8)
x3 = self.encoder3(x2) #->(*,1024,h/16,w/16)
x4 = self.encoder4(x3) #->(*,2048,h/32,w/32)
#center
y5 = self.center(x4) #->(*,320,h/32,w/32)
#decoder
y4 = self.decoder4(torch.cat([x4,y5], dim=1)) #->(*,64,h/16,w/16)
y3 = self.decoder3(torch.cat([x3,y4], dim=1)) #->(*,64,h/8,w/8)
y2 = self.decoder2(torch.cat([x2,y3], dim=1)) #->(*,64,h/4,w/4)
y1 = self.decoder1(torch.cat([x1,y2], dim=1)) #->(*,64,h/2,w/2)
y0 = self.decoder0(y1) #->(*,64,h,w)
#hypercolumns
y4 = self.upsample4(y4) #->(*,64,h,w)
y3 = self.upsample3(y3) #->(*,64,h,w)
y2 = self.upsample2(y2) #->(*,64,h,w)
y1 = self.upsample1(y1) #->(*,64,h,w)
hypercol = torch.cat([y0,y1,y2,y3,y4], dim=1)
#final conv
logits = self.final_conv(hypercol) #->(*,1,h,w)
#clf head
logits_clf = self.clf(self.avgpool(x4).squeeze(-1).squeeze(-1)) #->(*,1)
if self.clfhead:
if self.deepsupervision:
s4 = self.deep4(y4)
s3 = self.deep3(y3)
s2 = self.deep2(y2)
s1 = self.deep1(y1)
logits_deeps = [s4,s3,s2,s1]
return logits, logits_deeps, logits_clf
else:
return logits, logits_clf
else:
if self.deepsupervision:
s4 = self.deep4(y4)
s3 = self.deep3(y3)
s2 = self.deep2(y2)
s1 = self.deep1(y1)
logits_deeps = [s4,s3,s2,s1]
return logits, logits_deeps
else:
return logits
以下是代码的逐层解析,主要针对U-Net结构、SE-ResNeXt101主干网络、CBAM注意力机制、深度监督和分类头的实现:
一、模块分解
1. CenterBlock
python
class CenterBlock(nn.Module):
def __init__(self, in_channel, out_channel):
super().__init__()
self.conv = conv3x3(in_channel, out_channel).apply(init_weight)
def forward(self, inputs):
return self.conv(inputs)
- 功能:编码器与解码器间的过渡层
- 结构:单个3x3卷积(通道数2048→512)
- 作用:压缩高层特征,为解码器提供初始输入
2. DecodeBlock
python
class DecodeBlock(nn.Module):
def __init__(self, in_channel, out_channel, upsample):
# ...(详见代码)
def forward(self, inputs):
# ...(包含残差连接与CBAM)
-
核心组件:
- 上采样层:Nearest插值(默认关闭,由参数
upsample控制) - 双3x3卷积:特征提取
- CBAM注意力:通道+空间注意力增强特征
- 1x1卷积残差:跳跃连接适配维度
- 上采样层:Nearest插值(默认关闭,由参数
-
特征流动:
输入 → BN+ReLU → 上采样 → conv3x3 → BN+ReLU → conv3x3 → CBAM └→ 1x1卷积上采样 → 残差相加 ←
3. UNET_SERESNEXT101
python
class UNET_SERESNEXT101(nn.Module):
def __init__(self, resolution, deepsupervision, clfhead, load_weights=True):
# 初始化编码器、解码器、中心层等
def forward(self, inputs):
# 完整前向传播流程
二、编码器结构解析(SE-ResNeXt101)
特征金字塔构建
| 层级 | 模块组成 | 输出尺寸 | 作用 |
|---|---|---|---|
encoder0 | conv1+bn1+relu1 | (64, H/2, W/2) | 初始下采样 |
encoder1 | pool + layer1 | (256, H/4, W/4) | 第一阶段特征提取 |
encoder2 | layer2 | (512, H/8, W/8) | 第二阶段特征提取 |
encoder3 | layer3 | (1024, H/16, W/16) | 第三阶段特征提取 |
encoder4 | layer4 | (2048, H/32, W/32) | 最高层语义特征 |
三、解码器与特征融合
1. 解码流程
| 解码层 | 输入拼接 | 输出尺寸 | 上采样倍数 |
|---|---|---|---|
decoder4 | cat([x4, y5]) | (64, H/16, W/16) | 2x |
decoder3 | cat([x3, y4]) | (64, H/8, W/8) | 2x |
decoder2 | cat([x2, y3]) | (64, H/4, W/4) | 2x |
decoder1 | cat([x1, y2]) | (64, H/2, W/2) | 2x |
decoder0 | y1单独输入 | (64, H, W) | 2x |
2. Hypercolumns技术
python
y4 = upsample4(y4) → (64,H,W)
y3 = upsample3(y3) → (64,H,W)
y2 = upsample2(y2) → (64,H,W)
y1 = upsample1(y1) → (64,H,W)
hypercol = cat([y0,y1,y2,y3,y4]) → (320,H,W)
- 作用:聚合多尺度特征,保留不同层级的语义和细节信息
四、多任务输出设计
1. 分割主输出
python
self.final_conv = nn.Sequential(
conv3x3(320,64), # 融合Hypercolumns
nn.ELU(),
conv1x1(64,1) # 生成分割logits
)
2. 分类辅助头
python
self.clf = nn.Sequential(
nn.BatchNorm1d(2048),
nn.Linear(2048,512),
nn.ELU(),
nn.BatchNorm1d(512),
nn.Linear(512,1) # 生成分类logits
)
- 输入:编码器最高层特征的全局池化(
x4) - 目的:多任务学习增强特征表达
3. 深度监督
python
self.deep4 = conv1x1(64,1) # 监督decoder4输出
self.deep3 = conv1x1(64,1) # 监督decoder3输出
self.deep2 = conv1x1(64,1) # 监督decoder2输出
self.deep1 = conv1x1(64,1) # 监督decoder1输出
- 作用:通过中间层监督加速收敛,提升梯度传播效果
五、前向传播流程图解
输入图像
│
▼
[encoder0]→[encoder1]→[encoder2]→[encoder3]→[encoder4]
│ │
▼ ▼
[center] → decoder4 → deep4
│ │
▼ ▼
decoder3 → deep3
│ │
▼ ▼
decoder2 → deep2
│ │
▼ ▼
decoder1 → deep1
│
▼
decoder0 → hypercolumns → final_conv → 分割输出
▲
│
分类输出 ← clf_head ← avgpool(x4)
六、关键技术亮点
-
主干网络选择
- SE-ResNeXt101提供强特征提取能力,SE模块增强通道注意力
-
CBAM注意力机制
- 在解码器中引入通道+空间双维度注意力,突出重要特征
-
深度监督策略
- 对中间解码层输出监督,缓解梯度消失问题
-
多尺度特征融合
- Hypercolumns技术整合不同层级特征,提升细节恢复能力
-
多任务学习框架
- 分割与分类联合训练,共享编码器特征,提升模型泛化性
七、潜在改进方向
-
替换主干网络
- 尝试EfficientNetV2或Swin Transformer等新型主干
-
改进注意力机制
- 使用Triplet Attention或SimAM替代CBAM
-
动态上采样
- 将双线性插值替换为可学习的转置卷积
-
损失函数优化
- 对分割和分类任务采用加权损失或不确定性加权
-
实例归一化
- 在解码器中尝试GroupNorm或InstanceNorm替代BatchNorm
五、损失
我使用了 bce loss + lovasz-hinge loss,在此基础上,我还使用了带有 bce loss + lovasz-hinge loss(仅针对非空掩码)的深度监督,乘以 0.1。对于分类头,我使用了 bce loss
六、外部数据和伪标签
我生成了(EDIT)训练数据、公开测试数据、hubmap-portal(portal.hubmapconsortium.org/search?enti…
; and dataset_a_dib (data.mendeley.com/datasets/k7…). 对于数据 d488c759a,我在我的最终模型中使用了 Carno Zhao 的伪标签(感谢 @carnozhao!),另一个则使用了我的模型的伪标签。我检查了我的私有分数,发现这些模型的性能并没有太大差异。但我的第一名模型是使用 Carno Zhao 的伪标签的那个。我认为提升我的分数的不是手工标注,而是间接的集成效应,因为我的模型是单一的,多样性应该会做出贡献。
七、推理技巧
我发现避免边缘效应可以持续提升 CV 和 LB。我的技巧是只使用瓦片的中心部分进行预测(我注意到@shujun717 的第三名解决方案使用了同样的想法)。我发现使用较小的部分进行预测可以更好地提升 CV 和 LB,但会耗费更多时间。所以我最终使用从原始 1024x1024 中提取的 512x512 中心部分进行预测。推理需要 x4 的时间,但使用分类头预测有助于节省推理时间。
训练代码:321Leo123/kaggle-hubmap: Kaggle HuBMAP 1st Place Solution