自己写—YOLOv3(3)—损失函数

1,600 阅读2分钟

携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第10天,点击查看活动详情

首先我们要弄清一些问题,因为对于目标检测其输出要远远比分类复杂,对于边界框进行回归,我们都知道是回归中心点的坐标,以及编辑框的宽高。好,那么模型就是输出中心点坐标和预测编辑框的宽高吗? 还是输出相对于宽高这些值比率,显然都不是,

006.png

这张图我们需要弄清清楚楚,明明白白,不然随后在编码或者看源码就会有问题,首先模型输出的是什么,模型实际输出值是 tx,ty,tw,tht_x,t_y,t_w,t_h 这些也叫做回归参数,如何用这些值恢复到在图像上真实的中心点坐标和边界框的宽高呢。

首先我们需要理解网格概念,首先将真实图像坐标需要映射到一个网格上,

class YoloLoss(nn.Module):
    def __init__(self):
        super().__init__()
        
self.mse = nn.MSELoss()
#
self.bce = nn.BCEWithLogitsLoss()
#
self.entropy = nn.CrossEntropyLoss()
self.sigmoid = nn.Sigmoid()
self.lambda_class = 1
self.lambda_noobj = 10
self.lambda_obj = 1
self.lambda_box = 10

置信度损失函数

obj = target[...,0] == 1
noobj = target[...,0] == 0
no_object_loss = self.bce((preds[...,0:1][noobj]),(target[...,0:1][noobj]))

BCEWithLogitsLoss 损失函数将 Sigmoid 函数和 BCELoss 函数作为一个整体来使用,这样做的好处就是比 Sigmoid 和 BCELoss 分别使用,在数值更稳定,因为他们合并到一个逻辑单元,就可以利用了 log 和 exp 技巧来实现数值稳定。

l(x,y)=L={l1,,lN}Tln=wn[ynlogσ(xn)+(1yn)log(1σ(xn))]l(x,y) = L = \{l_1,\cdots,l_N \}^T\\ l_n = - w_n \left[ y_n \log\sigma(x_n) + (1-y_n)\log(1 - \sigma(x_n)) \right]

通过下面一段代码来解释一下obj = target[...,0] == 1preds[...,0:1][noobj] 来筛选没有目标的预测框

a = torch.tensor([[[0,1,1],[1,1,2]],[[0,2,1],[0,2,2]]])
a

这里创建一个 tensor a 维度为 2×2×32\times 2 \times 3

tensor([[[0, 1, 1], [1, 1, 2]], [[0, 2, 1], [0, 2, 2]]])
obj = a[...,0] == 1
noobj = a[...,0] == 0
obj

这里 obj 为一个 tensor([[False, True], [False, False]]) ,并且 obj 维度为 torch.Size([2, 2]) 可以用作过滤来将满足条件,也就是对应位置为 True 数据保留

a[...,0:1][obj]
box_preds = torch.cat([self.sigmoid(preds[...,1:3]),torch.exp(preds[...,3:5]) * anchors],dim=-1)

有目标损失值

#Object Loss
anchors = anchors.reshape(1,3,1,1,2) # 通过 broadcasting 让 anchors 和
# 
box_preds = torch.cat([self.sigmoid(preds[...,1:3]),torch.exp(preds[...,3:5]) * anchors],dim=-1)
ious = intersection_over_union(box_preds[obj],target[...,1:5][obj]).detach()

object_loss = self.bce((preds[...,0:1][obj]), (ious * target[...,0:1]))

为了计算 IoU 需要将预测边界框输出中心点和宽高的回归参数进行处理,然后和目标边界框中心点和宽高做 IoU 对于目标损失函数将 IoU 考虑进去。

计算定位损失


preds[...,1:3] = self.sigmoid(preds[...,1:3])
target[...,3:5] = torch.log(1e-16 + target[...,3:5]/anchors)

box_loss = self.mse(preds[...,1:5][obj],target[...,1:5][obj])

类别损失

#Class Loss
class_loss = self.entropy(
    (preds[...,5:][obj]),(target[...,5][obj].long())
)
return (
    self.lambda_box * box_loss
    + self.lambda_obj * object_loss
    + self.lambda_noobj * no_object_loss
    + self.lambda_class * class_loss
)