代码地址: https://github.com/xingyizhou/CenterTrack
损失函数部分
在trainer.py的GenericLoss类里初始化和计算loss。
In [1]: outputs[0]['hm'].shape
Out[1]: torch.Size([10, 1, 136, 240])
In [2]: outputs[0]['reg'].shape
Out[2]: torch.Size([10, 2, 136, 240])
In [3]: outputs[0]['wh'].shape
Out[3]: torch.Size([10, 2, 136, 240])
In [4]: outputs[0]['tracking'].shape
Out[4]: torch.Size([10, 2, 136, 240])
In [5]: outputs[0]['ltrb_amodal'].shape
Out[5]: torch.Size([10, 4, 136, 240])
与 CenterNet 类似，hm（heatmap）分支使用 FastFocalLoss 计算损失；它与原始 focal loss 实现等价，但速度更快、显存占用更少。
def _only_neg_loss(pred, gt):
    '''
    Negative-sample term of the focal loss, summed over every element.

    Arguments:
      pred: predicted heatmap values
      gt: ground-truth heatmap values, same shape as pred
    Returns:
      Scalar tensor: sum of log(1 - pred) * pred^2 * (1 - gt)^4.
    '''
    # (1 - gt)^4 shrinks towards zero as gt approaches 1, so elements
    # with gt == 1 contribute nothing to this term.
    neg_weights = (1 - gt).pow(4)
    per_element = torch.log(1 - pred) * pred.pow(2) * neg_weights
    return torch.sum(per_element)
class FastFocalLoss(nn.Module):
    '''
    Reimplemented focal loss, exactly the same as the CornerNet version.
    Faster and costs much less memory.

    The negative term is evaluated over the whole heatmap, while the
    positive term is evaluated only at the gathered peak locations.
    '''

    def __init__(self, opt=None):
        # `opt` is accepted but not used by this class.
        super().__init__()
        self.only_neg_loss = _only_neg_loss

    def forward(self, out, target, ind, mask, cat):
        '''
        Arguments:
          out, target: B x C x H x W
          ind, mask: B x M
          cat (category id for peaks): B x M
        Returns:
          Scalar loss: -(pos + neg) / num_pos, or -neg when the mask
          selects no positive peak.
        '''
        negative_term = self.only_neg_loss(out, target)

        # Predicted values at every peak location, all channels: B x M x C
        peak_feats = _tranpose_and_gather_feat(out, ind)
        # Keep only the channel given by each peak's category: B x M x 1.
        # Note from the original write-up: with a single class (pedestrians
        # only) this is the same tensor as peak_feats.
        peak_scores = torch.gather(peak_feats, 2, cat.unsqueeze(2))

        num_pos = mask.sum()
        # The mask zeroes out the padded (non-object) slots.
        positive_term = (
            torch.log(peak_scores)
            * torch.pow(1 - peak_scores, 2)
            * mask.unsqueeze(2)
        ).sum()

        if num_pos == 0:
            # No positives: skip the normalisation to avoid dividing by zero.
            return - negative_term
        return - (positive_term + negative_term) / num_pos
def _tranpose_and_gather_feat(feat, ind):
    '''
    Flatten the spatial dimensions of a feature map and pick the rows
    addressed by `ind`.

    Arguments:
      feat: 4-D feature map (e.g. n x 1 x 136 x 240 for the hm head)
      ind: flattened spatial indices per sample (e.g. n x 256)
    Returns:
      Whatever `_gather_feat` returns for the flattened map and `ind`.
    '''
    batch, channels = feat.size(0), feat.size(1)
    # Move channels last, then merge the two spatial axes into one:
    # e.g. n x 1 x 136 x 240 -> n x 136 x 240 x 1 -> n x 32640 x 1
    flat = feat.permute(0, 2, 3, 1).contiguous().view(batch, -1, channels)
    return _gather_feat(flat, ind)
def _gather_feat(feat, ind):
    '''
    Select, for every sample, the rows of `feat` listed in `ind`.

    Arguments:
      feat: 3-D tensor; dim 1 is the axis being indexed (e.g. n x 32640 x 1)
      ind: 2-D index tensor (e.g. n x 256)
    Returns:
      Tensor with the first two sizes of `ind` and the last size of `feat`
      (e.g. n x 256 x 1): the feature values at the indexed positions.
    '''
    channels = feat.size(2)
    # Repeat each index across the channel axis so gather returns whole rows.
    row_index = ind.unsqueeze(-1).expand(-1, -1, channels)
    # Note from the original write-up: `ind` is padded with many zeros, so
    # the caller has to mask the result when computing the loss.
    return torch.gather(feat, 1, row_index)
其余的 reg/wh/tracking/ltrb_amodal 等 head 都采用带 mask 加权的 L1 loss（RegWeightedL1Loss）。
class RegWeightedL1Loss(nn.Module):
    '''
    Mask-weighted L1 loss evaluated only at the locations given by `ind`.
    '''

    def __init__(self):
        super().__init__()

    def forward(self, output, mask, ind, target):
        '''
        Arguments:
          output: 4-D head output, gathered at `ind` before comparison
          mask: weights multiplied into both prediction and target
                (presumably same shape as target -- confirm with caller)
          ind: flattened spatial indices of the locations to compare
          target: regression targets for the gathered locations
        Returns:
          Sum of absolute errors divided by (mask.sum() + 1e-4).
        '''
        gathered = _tranpose_and_gather_feat(output, ind)
        abs_err_sum = F.l1_loss(gathered * mask, target * mask, reduction='sum')
        # The small constant keeps the division finite when the mask is empty.
        return abs_err_sum / (mask.sum() + 1e-4)