之前有师弟问我关于pytorch二分类和多分类训练测试代码的写法,这里就总结一下要点。下面为方便叙述,会将变量与形状同步给出,格式为 x(n,c,h,w)。
首先捋一下过程,我们省略模型构建步骤,剩下的步骤主要包括训练过程的loss计算和梯度更新,还有测试过程的精度计算。
- loss计算与梯度更新
给定输入为x(n,3,h,w),输出为predict(n,k,h,w),k为类别数。
- k=1
一般采用sigmoid将模型输出y映射到(0,1)范围内。
- k>=2
一般有两种实现思路。第一种:采用sigmoid,不保证概率和为1。通常用于多标签任务。第二种:采用softmax,这种保证概率和为1。
二分类任务通常可以用k=1和k=2两种方式训练。k=1采用sigmoid。k>=2采用softmax。因为分割任务没有多标签需求,因此k>=2一般都采用softmax,从而保证所有类别概率和为1,避免多个类别概率值超过阈值(这种情况常见于多标签分类)。
下面给出一个常见的多类分割模型训练过程核心代码范例。
optimizer.zero_grad()#每个epoch开始先将梯度置零
mask_true_onehot = torch.nn.functional.one_hot(targets.squeeze(1).long(), num_classes=4).permute(0, 3, 1, 2).float() #(n,c,h,w)
loss = criterion(out.contiguous(), mask_true_onehot)
loss.backward()#反向传播计算梯度
optimizer.step()#根据梯度参数更新
scheduler.step()#学习率调整
当k=1时,损失函数最常用的为BCE和DiceLoss。
BCE可以直接用torch.nn.BCELoss(),但需要先实例化
bce_loss_func=torch.nn.BCELoss()
bce_loss=bce_loss_func(predict, target)
其中,predict(n,1,h,w),target(n,1,h,w)。或者predict(n,h*w),target(n,h*w)。
DiceLoss没有pytorch官方实现,需要自己实现。这里采用VM-UNET(github.com/JCruan519/V…
class DiceLoss(nn.Module):
def __init__(self):
super(DiceLoss, self).__init__()
def forward(self, pred, target):
smooth = 1
size = pred.size(0)
pred_ = pred.view(size, -1)
target_ = target.view(size, -1)
intersection = pred_ * target_
dice_score = (2 * intersection.sum(1) + smooth)/(pred_.sum(1) + target_.sum(1) + smooth)
dice_loss = 1 - dice_score.sum()/size
return dice_loss
其中,pred(n,1,h,w),target(n,1,h,w)
这里简单说一下,dice的计算公式为:2 * |X ∩ Y| / (|X| + |Y|)。其中包括交集操作,在代码中实际上对应的是 intersection = pred_ * target_。这里为什么可以直接相乘作为交集?是因为target_∈{0,1},perd_如果二值化则pred_∈[0,1],pred_如果不进行二值化则值域为(0,1)。intersection.sum(1)就可以代表交集的总和。
这里有一个易错点,如果target_不止包含0和1,例如多分类中包括0,1,2等,这种情况下要使用dice就必须对标签进行独热编码,否则计算出来的结果会非常奇怪,因为类标签在dice计算中应该是相同权重的,如果不进行onehot,则类别2的权重就是类别1的2倍,这显然是不合理的。
k>=2时,模型经过softmax后,输出predict(n,k,h,w)。这时常用的loss函数为CE和多分类Dice。
CE函数有pytorch官方实现,为torch.nn.CrossEntropyLoss()。输入为predict(n,c,h,w),target(h,c,h,w),这里target为onehot版本。或者predict(n,c,h,w),target(n,h,w)。这里target值为标签号,例如1,2,3。另外需要注意,target数据类型为torch.long。
多类dice依旧没有官方实现,我们在上述二分类dice loss基础上实现多分类dice loss。
class DiceLoss(nn.Module):
def __init__(self):
super(DiceLoss, self).__init__()
def forward(self, pred, target):
smooth = 1
num_classes = pred.size(1)
size = pred.size(0)
dice_loss = 0
for i in range(num_classes):
pred_ = pred[:, i].view(size, -1)
target_ = target[:, i].view(size, -1)
intersection = pred_ * target_
dice_score = (2 * intersection.sum(1) + smooth) / (pred_.sum(1) + target_.sum(1) + smooth)
dice_loss += 1 - dice_score.sum() / size
return dice_loss / num_classes
其中,inputs(n,c,h,w),target(n,c,h,w)。实际上就是先将target做独热编码,然后对每一类计算dice loss,最后求平均。
- 测试阶段精度计算
分为两种,第一种是基于混淆矩阵计算得出,例如OA、precision、recall等,这部分可以将测试集所有结果累加得到总的混淆矩阵再计算,也即微平均(micro average)。
Micro Precision= ∑TP/(∑TP+∑FP)
numpy和pytorch都有函数计算混淆矩阵。基于numpy的sklearn.metrics.confusion_matrix。
import numpy as np
from sklearn.metrics import confusion_matrix
y_true = np.array([0, 1, 2, 2, 0, 1])
y_pred = np.array([0, 2, 1, 2, 0, 1])
# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)
print(cm)
基于pytorch的可以安装torchmetrics
import torch
from torchmetrics import ConfusionMatrix
# 假设有一些真实标签和预测标签
y_true = torch.tensor([0, 1, 2, 2, 0, 1])
y_pred = torch.tensor([0, 2, 1, 2, 0, 1])
# 定义混淆矩阵计算器
num_classes = 3
confmat = ConfusionMatrix(num_classes=num_classes)
# 计算混淆矩阵
cm = confmat(y_pred, y_true)
print(cm)
第二种不基于混淆矩阵计算,如果想要直接计算测试集总的精度则需要保存所有的预测值,这在分割任务中会占用巨大的内存,也是不必要的。因此一般会分别将每张图针对每个类别计算一个结果,然后每个类别的所有结果进行平均。从而得到c(类别数)个精度值。这时候,再将所有类别进行平均,从而得到一个新的整体度量,也即宏平均(macro average)。常用的mIoU实际上就是一种宏平均指标,主要是平均了类别精度。
关于二分类和多分类分割的相关问题主要就这么多,最后再写几个常犯的错误,看看其中有没有聪明的你。
1.模型定义的最后用sigmoid,训练代码中又加了sigmoid
2.float类型tensor转换成int,结果全0,关键还不报错
3.训练与测试代码中的必要预处理步骤不同,例如训练时像素值除以255,测试时忘记
4.测试时数据增强忘记关掉,例如平移、翻转、高斯噪声
本文使用 文章同步助手 同步