CenterTrack代码解读——损失函数部分

871 阅读1分钟

代码地址:github.com/xingyizhou/CenterTrack

论文地址:arxiv.org/pdf/2004.01177

损失函数部分

在trainer.py的GenericLoss类里初始化和计算loss。

In [1]: outputs[0]['hm'].shape
Out[1]: torch.Size([10, 1, 136, 240])

In [2]: outputs[0]['reg'].shape
Out[2]: torch.Size([10, 2, 136, 240])

In [3]: outputs[0]['wh'].shape
Out[3]: torch.Size([10, 2, 136, 240])

In [4]: outputs[0]['tracking'].shape
Out[4]: torch.Size([10, 2, 136, 240])

In [5]: outputs[0]['ltrb_amodal'].shape
Out[5]: torch.Size([10, 4, 136, 240])

类似于 CenterNet,hm(heatmap)分支调用了 FastFocalLoss。它与 CornerNet 版本的 focal loss 数值上一致,但实现更快、占用显存更少。

def _only_neg_loss(pred, gt):
  gt = torch.pow(1 - gt, 4)
  neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * gt
  return neg_loss.sum()

class FastFocalLoss(nn.Module):
  '''
  Reimplemented focal loss, exactly the same as the CornerNet version.
  Faster and costs much less memory.
  '''
  def __init__(self, opt=None):
    super(FastFocalLoss, self).__init__()
    # The dense negative term is delegated to the module-level helper.
    self.only_neg_loss = _only_neg_loss

  def forward(self, out, target, ind, mask, cat):
    '''
    Arguments:
      out, target: B x C x H x W
      ind, mask: B x M
      cat (category id for peaks): B x M
    '''
    # Negative part: evaluated densely over every heatmap pixel.
    neg_loss = self.only_neg_loss(out, target)
    # Gather the predicted scores at the M annotated peak locations.
    peak_scores = _tranpose_and_gather_feat(out, ind)  # B x M x C
    # Pick each peak's own-category score (B x M x 1); with a single
    # class (e.g. pedestrians) this equals peak_scores itself.
    pos_pred = peak_scores.gather(2, cat.unsqueeze(2))
    pos_term = torch.log(pos_pred) * torch.pow(1 - pos_pred, 2)
    # mask zeroes padded slots so only real objects contribute.
    pos_loss = (pos_term * mask.unsqueeze(2)).sum()
    num_pos = mask.sum()
    # Guard: an image with no annotated objects keeps only the neg term.
    if num_pos == 0:
      return - neg_loss
    return - (pos_loss + neg_loss) / num_pos
    

def _tranpose_and_gather_feat(feat, ind):
  """Move channels last, flatten the spatial grid, then index by `ind`.

  feat: B x C x H x W feature map; ind: B x M flat spatial indices
  (row * W + col).  Returns B x M x C feature vectors at those cells.
  """
  channels_last = feat.permute(0, 2, 3, 1).contiguous()   # B x H x W x C
  flat = channels_last.view(
      channels_last.size(0), -1, channels_last.size(3))   # B x (H*W) x C
  return _gather_feat(flat, ind)
  
def _gather_feat(feat, ind):
  dim = feat.size(2) # 1
  ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) # n*256*1
  feat = feat.gather(1, ind)  # n*256*1 得到gt中心点在hm上对应位置的值(由于ind中还有大量0,计算loss时需要mask出gt)
  return feat

剩余的 reg/wh/tracking/ltrb_amodal head都采用l1 loss

class RegWeightedL1Loss(nn.Module):
  """Masked L1 loss shared by the reg / wh / tracking / ltrb_amodal heads."""

  def __init__(self):
    super(RegWeightedL1Loss, self).__init__()

  def forward(self, output, mask, ind, target):
    """output: B x C x H x W head output; ind: B x M flat peak indices;
    mask zeroes padded slots; target holds the ground-truth values."""
    pred = _tranpose_and_gather_feat(output, ind)
    # Sum absolute error over valid entries, then normalize by their
    # count (epsilon guards against an all-zero mask).
    total = F.l1_loss(pred * mask, target * mask, reduction='sum')
    return total / (mask.sum() + 1e-4)