TL;DR: 训练 Loss可以近似,但评估指标要务必TL;DR: 训练 Loss可以近似,但评估指标要务必
最近在搭建自己的训练框架时发现:Dice loss 和 Dice评估指标 虽然都是基于Dice系数计算,但是使用的时候由于身处计算环境和目的不一样,会有一些区别。 区别主要在于:
- 几乎所有的
Loss都是以目标最小化为目的,所以DiceLoss最后会有一个 最小化的步骤Dice Loss从原理上讲,求dice需要先将类别概率转换为类别标签,然后求交求并。但由于标签转换时会涉及argmax操作,而argmax操作会导致计算后的loss失去梯度信息(argmax不可导)。所以为了不丢失梯度信息,DiceLoss使用的是原始的预测概率分布。 个人认为DiceLoss求得是一个近似值,其实也可以理解,作为引导训练的指标,只要方向对即可,不需要很精确。Dice系数,作为一个评估模型训练性能的指标,需要务必精确,因此,在计算Dice系数时,需要先将预测结果的原始预测概率先通过argmax转换为预测标签,然后再求交求并求结果。
Ref:
discuss.pytorch.org/t/torch-arg…
Dice Loss
Dice Loss是一种常见的用于图像分割任务的, 基于Dice系数的度量集合相似度的损失函数,特别是对于类别不平衡的数据时。
计算公式如下:
Dice Loss = 1- \frac {2 \times \vert X \cap Y \vert} {\vert X \vert + \vert Y \vert}
假如预测标签维度为[batch, 4, D, W, H], 真实标签维度为 [batch, D, W, H] ,其计算代码如下:
class DiceLoss(nn.Module):
"""Dice loss for image segmentation"""
def __init__(self, smooth=1e-6):
"""
初始化函数:
:param smooth: 平滑变量,防止分母为0
"""
super(DiceLoss, self).__init__()
self.smooth = smooth
def forward(self, y_pred, y_true):
"""
前向反馈
:param y_pred: 预测值 [batch, 4, D, W, H]
:param y_true: 真实值 [batch, D, W, H]
"""
# y_pred = torch.argmax(y_pred, dim=1).to(dtype=torch.int64) # ⚠️argmax会使得loss失去梯度信息
# y_pred = F.one_hot(y_pred, num_classes=4).permute(0, 4, 1, 2, 3).float()
y_true = F.one_hot(y_true, num_classes=4).permute(0, 4, 1, 2, 3).float()
# 计算Dice 系数
intersection = (y_pred * y_true).sum(dim=(2, 3, 4))
union = y_pred.sum(dim=(2, 3, 4)) + y_true.sum(dim=(2, 3, 4))
Diceloss = (2 * intersection + self.smooth) / (union + self.smooth)
return 1 - Diceloss.mean()
Dice系数评估指标
Dice 系数是两个集合相似度的一种衡量指标,范围从0到1。
计算公式:
Dice = \frac {2 \times \vert X \cap Y \vert} {\vert X \vert + \vert Y \vert}
Dice系数越高吗,表示标签结果与真实标签越接近。
假如预测标签维度为[batch, 4, D, W, H], 真实标签维度为 [batch, D, W, H] ,其计算代码如下:
def dice_coefficient(self, y_pred, y_true):
"""
计算Dice 系数
:param y_pred: 预测标签
:param y_true: 真实标签
:return: Dice 系数
"""
# 预处理
y_pred = torch.argmax(y_pred, dim=1).to(dtype=torch.int64) # 降维,选出概率最大的类索引值,即获取预测类别标签
y_pred = F.one_hot(y_pred, num_classes=4).permute(0, 4, 1, 2, 3).float() # one-hot
y_true = F.one_hot(y_true, num_classes=4).permute(0, 4, 1, 2, 3).float() # ont-hot
# 计算Dice 系数
intersection = (y_pred * y_true).sum(dim=(2, 3, 4))
union = y_pred.sum(dim=(2, 3, 4)) + y_true.sum(dim=(2, 3, 4))
dice = (2*(intersection + self.smooth)) / (union + self.smooth)
return dice.mean(dim=0)