代码地址: github.com/xingyizhou/…
测试模块部分
核心代码在decode.py中。得到heatmap之后,由ctdet_decode函数获取最终检测框坐标和类别。
首先遍历hm检测当前pixel的值是否大于周围的八个近邻点,采用的方式是一个3x3的MaxPool,类似于anchor-based检测中nms的效果。返回的结果是筛选后的极大值点,其余不符合的位置值归为0。
def _nms(heat, kernel=3):
pad = (kernel - 1) // 2
hmax = nn.functional.max_pool2d(
heat, (kernel, kernel), stride=1, padding=pad)
keep = (hmax == heat).float()
return heat * keep
根据上一步筛选后的结果,经过top_k函数得到分数最高的100个中心点
def _topk(scores, K=40):
batch, cat, height, width = scores.size()
# 获取H*W上得分最高的K个点
topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K) # batch*cat*K
topk_inds = topk_inds % (height * width)
# K个点的x/y坐标
topk_ys = (topk_inds / width).int().float() # batch*cat*K
topk_xs = (topk_inds % width).int().float() # batch*cat*K
# 获取cat*K上得分最高的K个点
topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K) # batch*K
# K个点的类别
topk_clses = (topk_ind / K).int()
# 由topk_ind索引topk_inds/topk_ys/topk_xs对应位置的值
topk_inds = _gather_feat(
topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)
return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
def _gather_feat(feat, ind, mask=None):
dim = feat.size(2)
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
feat = feat.gather(1, ind)
if mask is not None:
mask = mask.unsqueeze(2).expand_as(feat)
feat = feat[mask]
feat = feat.view(-1, dim)
return feat
_topk返回了得分最高的100个框对应的score,inds(索引), clses(类别), xs/ys(中心点坐标)。利用inds索引出对应的reg和wh。再利用xs/ys计算最终bbox。整合scores, clses得到detections。
def ctdet_decode(heat, wh, reg=None, cat_spec_wh=False, K=100):
batch, cat, height, width = heat.size() # 已归一化
# heat = torch.sigmoid(heat)
# perform nms on heatmaps
heat = _nms(heat)
scores, inds, clses, ys, xs = _topk(heat, K=K)
if reg is not None:
reg = _transpose_and_gather_feat(reg, inds)
reg = reg.view(batch, K, 2)
xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
else:
xs = xs.view(batch, K, 1) + 0.5
ys = ys.view(batch, K, 1) + 0.5
wh = _transpose_and_gather_feat(wh, inds)
if cat_spec_wh:
wh = wh.view(batch, K, cat, 2)
clses_ind = clses.view(batch, K, 1, 1).expand(batch, K, 1, 2).long()
wh = wh.gather(2, clses_ind).view(batch, K, 2)
else:
wh = wh.view(batch, K, 2)
clses = clses.view(batch, K, 1).float()
scores = scores.view(batch, K, 1)
bboxes = torch.cat([xs - wh[..., 0:1] / 2,
ys - wh[..., 1:2] / 2,
xs + wh[..., 0:1] / 2,
ys + wh[..., 1:2] / 2], dim=2)
detections = torch.cat([bboxes, scores, clses], dim=2)
return detections