CenterNet代码解读——测试模块部分

656 阅读2分钟

代码地址: github.com/xingyizhou/…

论文地址:arxiv.org/pdf/1904.07…

测试模块部分

核心代码在decode.py中。得到heatmap之后,由ctdet_decode函数获取最终检测框坐标和类别。

首先遍历hm检测当前pixel的值是否大于周围的八个近邻点,采用的方式是一个3x3的MaxPool,类似于anchor-based检测中nms的效果。返回的结果是筛选后的极大值点,其余不符合的位置值归为0。

def _nms(heat, kernel=3):
    pad = (kernel - 1) // 2
    hmax = nn.functional.max_pool2d(
        heat, (kernel, kernel), stride=1, padding=pad)
    keep = (hmax == heat).float()
    return heat * keep

根据上一步筛选后的结果,经过top_k函数得到分数最高的100个中心点

def _topk(scores, K=40):
    batch, cat, height, width = scores.size()
    
    # 获取H*W上得分最高的K个点
    topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K) # batch*cat*K
    topk_inds = topk_inds % (height * width)
    
    # K个点的x/y坐标
    topk_ys   = (topk_inds / width).int().float() # batch*cat*K
    topk_xs   = (topk_inds % width).int().float() # batch*cat*K
      
    # 获取cat*K上得分最高的K个点
    topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K) # batch*K
    
    # K个点的类别
    topk_clses = (topk_ind / K).int()
    
    # 由topk_ind索引topk_inds/topk_ys/topk_xs对应位置的值
    topk_inds = _gather_feat(
        topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)

    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
def _gather_feat(feat, ind, mask=None):
    dim  = feat.size(2)
    ind  = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
    feat = feat.gather(1, ind)
    if mask is not None:
        mask = mask.unsqueeze(2).expand_as(feat)
        feat = feat[mask]
        feat = feat.view(-1, dim)
    return feat

_topk返回了得分最高的100个框对应的score,inds(索引), clses(类别), xs/ys(中心点坐标)。利用inds索引出对应的reg和wh。再利用xs/ys计算最终bbox。整合scores, clses得到detections。

def ctdet_decode(heat, wh, reg=None, cat_spec_wh=False, K=100):
    batch, cat, height, width = heat.size() # 已归一化

    # heat = torch.sigmoid(heat)
    # perform nms on heatmaps
    heat = _nms(heat)
      
    scores, inds, clses, ys, xs = _topk(heat, K=K)
    if reg is not None:
      reg = _transpose_and_gather_feat(reg, inds)
      reg = reg.view(batch, K, 2)
      xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
      ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
    else:
      xs = xs.view(batch, K, 1) + 0.5
      ys = ys.view(batch, K, 1) + 0.5
    wh = _transpose_and_gather_feat(wh, inds)
    if cat_spec_wh:
      wh = wh.view(batch, K, cat, 2)
      clses_ind = clses.view(batch, K, 1, 1).expand(batch, K, 1, 2).long()
      wh = wh.gather(2, clses_ind).view(batch, K, 2)
    else:
      wh = wh.view(batch, K, 2)
    clses  = clses.view(batch, K, 1).float()
    scores = scores.view(batch, K, 1)
    bboxes = torch.cat([xs - wh[..., 0:1] / 2, 
                        ys - wh[..., 1:2] / 2,
                        xs + wh[..., 0:1] / 2, 
                        ys + wh[..., 1:2] / 2], dim=2)
    detections = torch.cat([bboxes, scores, clses], dim=2)
      
    return detections