COCO评估：输出指定类别的AP，或输出每个类别的AP结果（使用COCO API打印所有类别的AP值）

268 阅读2分钟

if __name__ == '__main__':
    # Script entry point: parse CLI options, load the ground truth and
    # evaluate the prediction file on (a prefix of) its images.
    ap = argparse.ArgumentParser()
    ap.add_argument('--gt-json', type=str, default='instances.json',
                    help='coco val2017 annotations json files')
    ap.add_argument('--pred-json', type=str, default='coco_results.json',
                    help='pred coco val2017 annotations json files')
    ap.add_argument('--max-images', type=int, default=1000,
                    help='evaluate only the first N image ids of the GT file')
    args = ap.parse_args()
    print(args)

    pred_json_path = args.pred_json

    # Only the first MAX_IMAGES ground-truth image ids are evaluated
    # (default 1000, as in the original snippet).
    MAX_IMAGES = args.max_images
    coco_gt = COCO(args.gt_json)
    image_ids = coco_gt.getImgIds()[:MAX_IMAGES]

    eval(coco_gt, image_ids, pred_json_path)

以上代码评估后的输出如下：


![](https://p9-xtjj-sign.byteimg.com/tos-cn-i-73owjymdk6/5dd935659ecb41e883094dc251f21e72~tplv-73owjymdk6-jj-mark-v1:0:0:0:0:5o6Y6YeR5oqA5pyv56S-5Yy6IEAg55So5oi3MDgwNDUxMTkwMTI=:q75.awebp?rk3s=f64ab15b&x-expires=1770904096&x-signature=E%2F8kecLfdT2%2FjdaSNm6KM0%2F5Sto%3D)


 


修改后的代码如下。只需修改eval函数，即可指定只输出某一类别的AP值：



def eval(coco_gt, image_ids, pred_json_path, cat_ids=(2,)):
    """Run COCO bbox evaluation, optionally restricted to some categories.

    Note: this function shadows the builtin ``eval``; the name is kept so
    existing callers keep working.

    :param coco_gt: ground-truth ``COCO`` object.
    :param image_ids: image ids to evaluate on.
    :param pred_json_path: path to the detection results file, loaded with
        ``coco_gt.loadRes``.
    :param cat_ids: category ids to evaluate (the ``id`` values of the
        annotation file's categories, not list positions). Defaults to
        ``(2,)`` as in the original snippet; pass ``None`` to leave the
        category selection untouched, i.e. evaluate all categories.
    :return: the ``COCOeval`` object after ``summarize()``, so callers can
        inspect the results.
    """
    # Load the detections into the COCO evaluation tool.
    coco_pred = coco_gt.loadRes(pred_json_path)

    # Run COCO evaluation.
    print('BBox')
    coco_eval = COCOeval(coco_gt, coco_pred, 'bbox')
    coco_eval.params.imgIds = image_ids
    if cat_ids is not None:
        # Restrict evaluation to the requested categories; copy into a list
        # so the (immutable) default is never shared or mutated.
        coco_eval.params.catIds = list(cat_ids)
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    return coco_eval

修改后的输出如下（例如只输出类别ID为2的类别。注意：catIds中填写的是标注文件categories里的id，而不是类别的顺序下标）：


![](https://p9-xtjj-sign.byteimg.com/tos-cn-i-73owjymdk6/e484b3a2b9a5418b8f654240de201e1e~tplv-73owjymdk6-jj-mark-v1:0:0:0:0:5o6Y6YeR5oqA5pyv56S-5Yy6IEAg55So5oi3MDgwNDUxMTkwMTI=:q75.awebp?rk3s=f64ab15b&x-expires=1770904096&x-signature=u4Y%2Fwbd6m197%2BaA%2FRKxt9In6CMo%3D)


 


二、输出每个类别的AP结果（需要修改pycocotools中的coco.py和cocoeval.py）


首先修改coco.py中COCO类的初始化函数__init__，在第84行之后添加打印类别名称的代码（具体行号可能因pycocotools版本而异）：



def __init__(self, annotation_file=None):
    """
    Constructor of Microsoft COCO helper class for reading and visualizing annotations.

    When ``annotation_file`` is given, the JSON is loaded, the category names
    (ordered by category id) are printed, and ``createIndex`` is called.

    :param annotation_file (str): location of annotation file; ``None`` creates
        an empty object with no dataset loaded.
    :return:
    """
    # load dataset
    self.dataset, self.anns, self.cats, self.imgs = dict(), dict(), dict(), dict()
    self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
    if annotation_file is not None:
        print('loading annotations into memory...')
        tic = time.time()
        # JSON is UTF-8 by spec; being explicit avoids locale-dependent decode
        # errors (e.g. non-ASCII category names on Windows).
        with open(annotation_file, 'r', encoding='utf-8') as f:
            dataset = json.load(f)
        assert isinstance(dataset, dict), 'annotation file format {} not supported'.format(type(dataset))
        print('Done (t={:0.2f}s)'.format(time.time() - tic))
        # Added: print category names ordered by category id. .get() keeps
        # files without a "categories" key from crashing here.
        categories = sorted(dataset.get("categories", []), key=lambda c: c["id"])
        print("category names: {}".format([c["name"] for c in categories]))
        self.dataset = dataset
        self.createIndex()

然后修改cocoeval.py中的summarize函数，在第456行之后添加代码（具体行号可能因pycocotools版本而异）：



def summarize(self): ''' Compute and display summary metrics for evaluation results. Note this functin can only be applied on the default parameter setting ''' def _summarize( ap=1, iouThr=None, areaRng='all', maxDets=100 ): p = self.params iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}' titleStr = 'Average Precision' if ap == 1 else 'Average Recall' typeStr = '(AP)' if ap==1 else '(AR)' iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1])
if iouThr is None else '{:0.2f}'.format(iouThr)

        aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
        mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
        if ap == 1:
            # dimension of precision: [TxRxKxAxM]
            s = self.eval['precision']
            # IoU
            if iouThr is not None:
                t = np.where(iouThr == p.iouThrs)[0]
                s = s[t]
            s = s[:,:,:,aind,mind]
        else:
            # dimension of recall: [TxKxAxM]
            s = self.eval['recall']
            if iouThr is not None:
                t = np.where(iouThr == p.iouThrs)[0]
                s = s[t]
            s = s[:,:,aind,mind]
        if len(s[s>-1])==0:
            mean_s = -1
        else:
            mean_s = np.mean(s[s>-1])
        #print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s))
        category_dimension = 1 + int(ap)
        if s.shape[category_dimension] > 1:

            iStr += ", per category = {}"
            mean_axis = (0,)