[NMS Series] Soft-NMS


Paper: arxiv.org/pdf/1704.04…

1: Background


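Classic NMS suppresses boxes with a hard threshold. Once the highest-scoring box M is selected, any other box whose IoU with M exceeds the threshold has its score set directly to 0. This is simple but blunt, and it can cause missed detections. When two objects of the same class overlap heavily, the lower-scoring box may be a correct detection of the second object, yet it is discarded.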
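
To make the missed-detection problem concrete, here is a minimal sketch of classic greedy NMS. It is not the author's code. Boxes are assumed to be (x1, y1, x2, y2) in continuous coordinates, and the helper names `iou` and `hard_nms` are hypothetical. The first two boxes are assumed to be two different, heavily overlapping objects.

```python
import numpy as np


def iou(box, boxes):
    """IoU between one box and an array of boxes, all as (x1, y1, x2, y2), continuous coords."""
    xx1 = np.maximum(box[0], boxes[:, 0])
    yy1 = np.maximum(box[1], boxes[:, 1])
    xx2 = np.minimum(box[2], boxes[:, 2])
    yy2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(xx2 - xx1, 0, None) * np.clip(yy2 - yy1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)


def hard_nms(boxes, scores, iou_threshold=0.5):
    """Classic greedy NMS: returns indices of kept boxes, highest score first."""
    order = np.argsort(-scores)
    keep = []
    while order.size > 0:
        best = order[0]
        keep.append(int(best))
        rest = order[1:]
        # Any box overlapping the chosen one above the threshold is dropped (score -> 0)
        overlaps = iou(boxes[best], boxes[rest])
        order = rest[overlaps <= iou_threshold]
    return keep


# Two different objects that overlap heavily (IoU = 90/110), plus one separate object
boxes = np.array([[0, 0, 10, 10], [1, 0, 11, 10], [50, 50, 60, 60]], dtype=float)
scores = np.array([0.95, 0.80, 0.70])
print(hard_nms(boxes, scores, iou_threshold=0.5))
```

The second box overlaps the first with IoU of about 0.82, so it is removed outright, even though in this setup it is a separate object.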


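To address this, Soft-NMS does not set a box's score to 0 when its IoU exceeds the threshold. Instead it penalizes the score as a function of the IoU. Let M be the currently selected highest-scoring box, b_i a remaining box with score s_i, and N_t the IoU threshold. The linear penalty is:

$$
s_i =
\begin{cases}
s_i, & \mathrm{iou}(M, b_i) < N_t \\
s_i \, \bigl(1 - \mathrm{iou}(M, b_i)\bigr), & \mathrm{iou}(M, b_i) \ge N_t
\end{cases}
$$

This piecewise function is still somewhat abrupt, because it jumps at IoU = N_t. The authors therefore also propose a Gaussian penalty, applied to every remaining box, with parameter σ:

$$
s_i = s_i \, e^{-\mathrm{iou}(M, b_i)^2 / \sigma}
$$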
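
The three rescoring rules can be compared side by side as weights that multiply s_i. This is a minimal sketch with hypothetical helper names. It assumes N_t = 0.7 and σ = 0.5, the defaults used by the source code in section 2.

```python
import math


def hard_weight(iou, nt=0.7):
    """Classic NMS: the score becomes 0 at or above N_t."""
    return 0.0 if iou >= nt else 1.0


def linear_weight(iou, nt=0.7):
    """Linear Soft-NMS: unchanged below N_t, scaled by (1 - iou) at or above it."""
    return 1.0 - iou if iou >= nt else 1.0


def gaussian_weight(iou, sigma=0.5):
    """Gaussian Soft-NMS: smooth decay exp(-iou^2 / sigma) for every remaining box."""
    return math.exp(-(iou ** 2) / sigma)


for iou in (0.0, 0.5, 0.69, 0.71, 0.9):
    print(f"iou={iou:.2f}  hard={hard_weight(iou):.3f}  "
          f"linear={linear_weight(iou):.3f}  gaussian={gaussian_weight(iou):.3f}")
```

Between IoU 0.69 and 0.71 the linear weight drops from 1.0 to 0.29. That discontinuity is what the Gaussian form avoids: its weight decreases smoothly and only slightly over the same interval.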

2: Source code

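import numpy as np


def nms(dets, sigma=0.5, nms_threshold=0.7, soft_threshold=0.1, method=1):
    """
    dets: array of shape [n, 5], one row per box: score, x1, y1, x2, y2
    sigma: Gaussian parameter (method=2)
    nms_threshold: IoU threshold (method=1 and classic NMS)
    soft_threshold: boxes whose decayed score drops below this are discarded
    method: 1 = linear Soft-NMS, 2 = Gaussian Soft-NMS, anything else = classic NMS
    Returns the kept rows, with their decayed scores in column 0.
    """
    if dets.shape[0] == 0:
        return dets[[], :]
    # Work on a float copy so decaying scores does not modify the caller's array
    dets = dets.astype(np.float64)
    # scores is a view into dets, so writing scores[j] also updates dets
    scores = dets[:, 0]
    x1 = dets[:, 1]
    y1 = dets[:, 2]
    x2 = dets[:, 3]
    y2 = dets[:, 4]
    # Area of each of the n boxes (+1: pixel-inclusive coordinates)
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # Box indices sorted by descending score. Note: computed once and not
    # re-sorted after scores decay, a simplification of the paper's loop,
    # which re-selects the highest remaining score at every step
    order = scores.argsort()[::-1]
    # Total number of boxes
    ndets = dets.shape[0]
    # suppressed[k] == 1 marks box k as discarded
    suppressed = np.zeros(ndets, dtype=np.int64)
    # Visit the boxes in score order
    for _i in range(ndets):
        # Index (into the original dets) of the box ranked _i-th by score
        i = order[_i]
        # A discarded box no longer suppresses anything
        if suppressed[i] == 1:
            continue
        # Coordinates of box i
        ix1 = x1[i]
        iy1 = y1[i]
        ix2 = x2[i]
        iy2 = y2[i]
        # Area of box i
        iarea = areas[i]
        # Compute the IoU with every lower-ranked box
        for _j in range(_i + 1, ndets):
            j = order[_j]
            if suppressed[j] == 1:
                continue
            xx1 = max(ix1, x1[j])
            yy1 = max(iy1, y1[j])
            xx2 = min(ix2, x2[j])
            yy2 = min(iy2, y2[j])
            w = max(0.0, xx2 - xx1 + 1)
            h = max(0.0, yy2 - yy1 + 1)
            inter = w * h
            # IoU of boxes i and j
            iou = inter / (iarea + areas[j] - inter)
            if method == 1:  # linear decay
                weight = 1 - iou if iou > nms_threshold else 1
            elif method == 2:  # Gaussian decay
                weight = np.exp(-(iou * iou) / sigma)
            else:  # classic NMS
                weight = 0 if iou > nms_threshold else 1

            # Decay box j's score by the IoU-based weight and store it back
            scores[j] = weight * scores[j]

            # Discard box j if its score has become too low
            if scores[j] < soft_threshold:
                suppressed[j] = 1
    # Indices of the kept boxes
    keep = np.where(suppressed == 0)[0]
    # Keep only those rows
    return dets[keep, :]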
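
The code above computes the score order once and never re-sorts it. The paper's greedy loop instead picks the highest remaining score again after every decay step. The following is a minimal vectorized sketch of that loop, not the author's implementation. The function name `soft_nms`, the string values for `method`, and the defaults `nt=0.3` and `score_threshold=0.001` are assumptions for illustration. Boxes use continuous (x1, y1, x2, y2) coordinates without the +1.

```python
import numpy as np


def soft_nms(boxes, scores, method="gaussian", sigma=0.5, nt=0.3, score_threshold=0.001):
    """Greedy Soft-NMS sketch.

    boxes: float array [n, 4] as (x1, y1, x2, y2); scores: float array [n].
    Returns (kept indices in selection order, array of decayed scores).
    """
    scores = np.asarray(scores, dtype=np.float64).copy()
    remaining = np.arange(len(scores))
    keep = []
    while remaining.size > 0:
        # Re-select the highest *current* score: earlier decays may reorder boxes
        best_pos = int(np.argmax(scores[remaining]))
        best = remaining[best_pos]
        keep.append(int(best))
        remaining = np.delete(remaining, best_pos)
        if remaining.size == 0:
            break
        b, others = boxes[best], boxes[remaining]
        xx1 = np.maximum(b[0], others[:, 0])
        yy1 = np.maximum(b[1], others[:, 1])
        xx2 = np.minimum(b[2], others[:, 2])
        yy2 = np.minimum(b[3], others[:, 3])
        inter = np.clip(xx2 - xx1, 0, None) * np.clip(yy2 - yy1, 0, None)
        area_b = (b[2] - b[0]) * (b[3] - b[1])
        area_o = (others[:, 2] - others[:, 0]) * (others[:, 3] - others[:, 1])
        iou = inter / (area_b + area_o - inter)
        if method == "linear":
            weight = np.where(iou >= nt, 1.0 - iou, 1.0)
        elif method == "gaussian":
            weight = np.exp(-(iou ** 2) / sigma)
        else:  # classic NMS
            weight = np.where(iou >= nt, 0.0, 1.0)
        # Decay the scores of all boxes still in play
        scores[remaining] *= weight
        # Drop boxes whose decayed score is negligible
        remaining = remaining[scores[remaining] >= score_threshold]
    return keep, scores


# Usage: two heavily overlapping boxes (IoU = 90/110) and one separate box
boxes = np.array([[0, 0, 10, 10], [1, 0, 11, 10], [50, 50, 60, 60]], dtype=float)
scores = np.array([0.95, 0.80, 0.70])
keep, new_scores = soft_nms(boxes, scores, method="gaussian", sigma=0.5)
print(keep, new_scores[keep].round(3))
```

With Gaussian decay, the overlapping second box survives with a reduced score (about 0.21) instead of being removed. Because the loop re-sorts, the separate box (0.70) is selected ahead of it. A fixed initial order would not do this.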

3: References:

zhuanlan.zhihu.com/p/566802565