CenterNet代码解读——损失函数部分

1,518 阅读2分钟

代码地址: github.com/xingyizhou/…

论文地址:arxiv.org/pdf/1904.07…

损失函数部分

首先是在CtdetTrainer(继承BaseTrainer)的init里初始化loss, 在ExdetLoss类获取loss。

class ExdetLoss(torch.nn.Module):
  def __init__(self, opt):
    super(ExdetLoss, self).__init__()
    self.crit = torch.nn.MSELoss() if opt.mse_loss else FocalLoss()
    self.crit_reg = RegL1Loss()
    self.opt = opt
    self.parts = ['t', 'l', 'b', 'r', 'c']

  def forward(self, outputs, batch):
    opt = self.opt
    hm_loss, reg_loss = 0, 0
    for s in range(opt.num_stacks):
      output = outputs[s]
      for p in self.parts:
        tag = 'hm_{}'.format(p)
        output[tag] = _sigmoid(output[tag])
        hm_loss += self.crit(output[tag], batch[tag]) / opt.num_stacks
        if p != 'c' and opt.reg_offset and opt.off_weight > 0:
          reg_loss += self.crit_reg(output['reg_{}'.format(p)], 
                                    batch['reg_mask'],
                                    batch['ind_{}'.format(p)],
                                    batch['reg_{}'.format(p)]) / opt.num_stacks
    loss = opt.hm_weight * hm_loss + opt.off_weight * reg_loss
    loss_stats = {'loss': loss, 'off_loss': reg_loss, 'hm_loss': hm_loss}
    return loss, loss_stats

假设输入图像大小为HWCH*W*C, 经过model之后,分别接三个head获取对应的特征。再加不同的损失函数进行约束:

1. hm head & focal loss

  • 预测中心点hm的损失函数,主要约束hm上中心点的概率为1,其他点的概率为0。测试时根据hm的概率得到各个中心点。

  • 其中hm head输出为H4W480\frac{H}{4}*\frac{W}{4}*80, 其中80为类别数。Yx,y,cY'_{x,y,c}代表对于类别c,中心点在(x,y)上出现的概率。这里基于focal loss进行改进:

    • GT值 Yx,y,cY_{x,y,c}是用高斯核将中心点分布到heatmap上。下图是源码生成的GT
  • Yx,y,c=1Y_{x,y,c}=1, 跟focal类似,对于easy的中心点,(1Yx,y,c)α(1-Y'_{x,y,c})^{\alpha}使loss值减少,反之,hard样本的loss增大,实现难例挖掘。

  • Yx,y,c1Y_{x,y,c}\not=1,即除了中心点之外的其它位置。LkL_k对中心点附近的临近点,训练权重进行了调整。利用(1Yx,y,c)β(1-Y'_{x,y,c})^{\beta},与中心点越靠近,Y'_{x,y,c}越接近于1,loss值越小。反之则越大。

    • 假设Yx,y,c=0.9Y_{x,y,c}=0.9,接近中心点。如果Yx,y,cY'_{x,y,c}值接近1,显然是不对的,理论上应该为0,所以用(Yx,y,c)α(Y'_{x,y,c})^\alpha惩罚;但这个位置接近中心点,值接近1也情有可原,所以用(1Yx,y,c)β(1-Y'_{x,y,c})^\beta来安慰下,减轻loss比重。
    • 假设Yx,y,c=0.1Y_{x,y,c}=0.1,远离中心点。如果Yx,y,cY'_{x,y,c}值接近1,同上需要用(Yx,y,c)α(Y'_{x,y,c})^\alpha惩罚;如果Yx,y,cY'_{x,y,c}值接近0,则是合理的,用(Yx,y,c)α(Y'_{x,y,c})^\alpha安慰。对于(1Yx,y,c)β(1-Y'_{x,y,c})^\beta,加大了距离中心较远点的损失比重。
  • 综上,(Yx,y,c)α(Y'_{x,y,c})^\alpha(1Yx,y,c)α(1-Y'_{x,y,c})^\alpha用于限制easy样本导致的gradient被dominant的问题;(1Yx,y,c)β(1-Y'_{x,y,c})^\beta用来缓解近中心点负样本惩罚力度的问题。

    def _neg_loss(pred, gt):
      ''' Modified focal loss. Exactly the same as CornerNet.
          Runs faster and costs a little bit more memory
        Arguments:
          pred (batch x c x h x w)
          gt_regr (batch x c x h x w)
      '''
      pos_inds = gt.eq(1).float()  # heatmap为1的部分是正样本
      neg_inds = gt.lt(1).float()  # 其他部分为负样本
    
      neg_weights = torch.pow(1 - gt, 4)
    
      loss = 0
    
      pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
      neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
    
      num_pos  = pos_inds.float().sum()
      pos_loss = pos_loss.sum()
      neg_loss = neg_loss.sum()
    
      if num_pos == 0:
        loss = loss - neg_loss
      else:
        loss = loss - (pos_loss + neg_loss) / num_pos
      return loss
    

2. reg head & l1 loss

  • 中心点的偏置损失。由于输入img到预测hm进行了R=4的下采样。这样hm预测的坐标映射到原图会有精度误差。因此对于每一个中心点预测精度误差带来的偏置。
  • 其中reg head输出为H4W42\frac{H}{4}*\frac{W}{4}*2,预测的偏差为$O'_{p'}。偏差的GT为[float(GT坐标下采样R倍)-int(GT坐标下采样R倍)]。 用L1 loss约束:
def _reg_loss(regr, gt_regr, mask):
  ''' L1 regression loss
    Arguments:
      regr (batch x max_objects x dim)
      gt_regr (batch x max_objects x dim)
      mask (batch x max_objects)
  '''
  num = mask.float().sum()
  mask = mask.unsqueeze(2).expand_as(gt_regr).float()

  regr = regr * mask
  gt_regr = gt_regr * mask
    
  regr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, size_average=False)
  regr_loss = regr_loss / (num + 1e-4)
  return regr_loss

3. wh head & l1 loss

  • 预测检测框wh的损失。wh head输出为H4W42\frac{H}{4}*\frac{W}{4}*2。用L1 loss: