代码地址: github.com/xingyizhou/…
损失函数部分
首先是在CtdetTrainer(继承BaseTrainer)的init里初始化loss, 在ExdetLoss类获取loss。
class ExdetLoss(torch.nn.Module):
def __init__(self, opt):
super(ExdetLoss, self).__init__()
self.crit = torch.nn.MSELoss() if opt.mse_loss else FocalLoss()
self.crit_reg = RegL1Loss()
self.opt = opt
self.parts = ['t', 'l', 'b', 'r', 'c']
def forward(self, outputs, batch):
opt = self.opt
hm_loss, reg_loss = 0, 0
for s in range(opt.num_stacks):
output = outputs[s]
for p in self.parts:
tag = 'hm_{}'.format(p)
output[tag] = _sigmoid(output[tag])
hm_loss += self.crit(output[tag], batch[tag]) / opt.num_stacks
if p != 'c' and opt.reg_offset and opt.off_weight > 0:
reg_loss += self.crit_reg(output['reg_{}'.format(p)],
batch['reg_mask'],
batch['ind_{}'.format(p)],
batch['reg_{}'.format(p)]) / opt.num_stacks
loss = opt.hm_weight * hm_loss + opt.off_weight * reg_loss
loss_stats = {'loss': loss, 'off_loss': reg_loss, 'hm_loss': hm_loss}
return loss, loss_stats
假设输入图像大小为, 经过model之后,分别接三个head获取对应的特征。再加不同的损失函数进行约束:
1. hm head & focal loss
-
预测中心点hm的损失函数,主要约束hm上中心点的概率为1,其他点的概率为0。测试时根据hm的概率得到各个中心点。
-
其中hm head输出为, 其中80为类别数。代表对于类别c,中心点在(x,y)上出现的概率。这里基于focal loss进行改进:
- GT值 是用高斯核将中心点分布到heatmap上。下图是源码生成的GT
- GT值 是用高斯核将中心点分布到heatmap上。下图是源码生成的GT
-
当, 跟focal类似,对于easy的中心点,使loss值减少,反之,hard样本的loss增大,实现难例挖掘。
-
当,即除了中心点之外的其它位置。对中心点附近的临近点,训练权重进行了调整。利用,与中心点越靠近,Y'_{x,y,c}越接近于1,loss值越小。反之则越大。
- 假设,接近中心点。如果值接近1,显然是不对的,理论上应该为0,所以用惩罚;但这个位置接近中心点,值接近1也情有可原,所以用来安慰下,减轻loss比重。
- 假设,远离中心点。如果值接近1,同上需要用惩罚;如果值接近0,则是合理的,用安慰。对于,加大了距离中心较远点的损失比重。
-
综上,和用于限制easy样本导致的gradient被dominant的问题;用来缓解近中心点负样本惩罚力度的问题。
def _neg_loss(pred, gt): ''' Modified focal loss. Exactly the same as CornerNet. Runs faster and costs a little bit more memory Arguments: pred (batch x c x h x w) gt_regr (batch x c x h x w) ''' pos_inds = gt.eq(1).float() # heatmap为1的部分是正样本 neg_inds = gt.lt(1).float() # 其他部分为负样本 neg_weights = torch.pow(1 - gt, 4) loss = 0 pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds num_pos = pos_inds.float().sum() pos_loss = pos_loss.sum() neg_loss = neg_loss.sum() if num_pos == 0: loss = loss - neg_loss else: loss = loss - (pos_loss + neg_loss) / num_pos return loss
2. reg head & l1 loss
- 中心点的偏置损失。由于输入img到预测hm进行了R=4的下采样。这样hm预测的坐标映射到原图会有精度误差。因此对于每一个中心点预测精度误差带来的偏置。
- 其中reg head输出为,预测的偏差为$O'_{p'}。偏差的GT为[float(GT坐标下采样R倍)-int(GT坐标下采样R倍)]。 用L1 loss约束:
def _reg_loss(regr, gt_regr, mask):
''' L1 regression loss
Arguments:
regr (batch x max_objects x dim)
gt_regr (batch x max_objects x dim)
mask (batch x max_objects)
'''
num = mask.float().sum()
mask = mask.unsqueeze(2).expand_as(gt_regr).float()
regr = regr * mask
gt_regr = gt_regr * mask
regr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, size_average=False)
regr_loss = regr_loss / (num + 1e-4)
return regr_loss
3. wh head & l1 loss
- 预测检测框wh的损失。wh head输出为。用L1 loss: