携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第10天,点击查看活动详情
首先我们要弄清一些问题,因为对于目标检测其输出要远远比分类复杂,对于边界框进行回归,我们都知道是回归中心点的坐标,以及编辑框的宽高。好,那么模型就是输出中心点坐标和预测编辑框的宽高吗? 还是输出相对于宽高这些值比率,显然都不是,
这张图我们需要弄清清楚楚,明明白白,不然随后在编码或者看源码就会有问题,首先模型输出的是什么,模型实际输出值是 这些也叫做回归参数,如何用这些值恢复到在图像上真实的中心点坐标和边界框的宽高呢。
首先我们需要理解网格概念,首先将真实图像坐标需要映射到一个网格上,
class YoloLoss(nn.Module):
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
#
self.bce = nn.BCEWithLogitsLoss()
#
self.entropy = nn.CrossEntropyLoss()
self.sigmoid = nn.Sigmoid()
self.lambda_class = 1
self.lambda_noobj = 10
self.lambda_obj = 1
self.lambda_box = 10
置信度损失函数
obj = target[...,0] == 1
noobj = target[...,0] == 0
no_object_loss = self.bce((preds[...,0:1][noobj]),(target[...,0:1][noobj]))
BCEWithLogitsLoss 损失函数将 Sigmoid 函数和 BCELoss 函数作为一个整体来使用,这样做的好处就是比 Sigmoid 和 BCELoss 分别使用,在数值更稳定,因为他们合并到一个逻辑单元,就可以利用了 log 和 exp 技巧来实现数值稳定。
通过下面一段代码来解释一下obj = target[...,0] == 1
和 preds[...,0:1][noobj]
来筛选没有目标的预测框
a = torch.tensor([[[0,1,1],[1,1,2]],[[0,2,1],[0,2,2]]])
a
这里创建一个 tensor a 维度为
tensor([[[0, 1, 1], [1, 1, 2]], [[0, 2, 1], [0, 2, 2]]])
obj = a[...,0] == 1
noobj = a[...,0] == 0
obj
这里 obj
为一个 tensor([[False, True], [False, False]])
,并且 obj
维度为 torch.Size([2, 2])
可以用作过滤来将满足条件,也就是对应位置为 True 数据保留
a[...,0:1][obj]
box_preds = torch.cat([self.sigmoid(preds[...,1:3]),torch.exp(preds[...,3:5]) * anchors],dim=-1)
有目标损失值
#Object Loss
anchors = anchors.reshape(1,3,1,1,2) # 通过 broadcasting 让 anchors 和
#
box_preds = torch.cat([self.sigmoid(preds[...,1:3]),torch.exp(preds[...,3:5]) * anchors],dim=-1)
ious = intersection_over_union(box_preds[obj],target[...,1:5][obj]).detach()
object_loss = self.bce((preds[...,0:1][obj]), (ious * target[...,0:1]))
为了计算 IoU 需要将预测边界框输出中心点和宽高的回归参数进行处理,然后和目标边界框中心点和宽高做 IoU 对于目标损失函数将 IoU 考虑进去。
计算定位损失
preds[...,1:3] = self.sigmoid(preds[...,1:3])
target[...,3:5] = torch.log(1e-16 + target[...,3:5]/anchors)
box_loss = self.mse(preds[...,1:5][obj],target[...,1:5][obj])
类别损失
#Class Loss
class_loss = self.entropy(
(preds[...,5:][obj]),(target[...,5][obj].long())
)
return (
self.lambda_box * box_loss
+ self.lambda_obj * object_loss
+ self.lambda_noobj * no_object_loss
+ self.lambda_class * class_loss
)