CenterNet代码解读——数据处理部分

711 阅读1分钟

代码地址: github.com/xingyizhou/…

论文地址:arxiv.org/pdf/1904.07…

数据处理部分

以COCO数据集为例,首先通过get_dataset获取Dataset的信息(继承COCO和CTDetDataset类),再根据update_dataset_info_and_set_heads设置对应的head。

ctdet对应的head有三个:hm(heatmap),wh(边框的宽高),reg(下采样导致的位置偏移)。

数据集的初始化在COCO类中完成,这里主要从CTDetDataset的`__getitem__`方法入手:

  def __getitem__(self, index):
    """Build one ctdet sample (network input plus targets) for `index`.

    The image is loaded, cropped/scaled (and, for the training split only,
    randomly jittered and flipped) to the network input size, then every
    ground-truth box is mapped onto the down-sampled output grid and
    encoded as a centre heatmap, a width/height target and a centre offset.

    Args:
      index: position of the image in `self.images`.

    Returns:
      dict with keys 'input' (C x H x W float image), 'hm' (per-class centre
      heatmap), 'reg_mask' (1 for slots holding a real object), 'ind'
      (flattened heatmap index of each centre) and 'wh' (box width/height).

    Raises:
      FileNotFoundError: if the image file cannot be read.
    """
    # Look up the image and its annotations.
    img_id = self.images[index]
    file_name = self.coco.loadImgs(ids=[img_id])[0]['file_name']
    img_path = os.path.join(self.img_dir, file_name)
    ann_ids = self.coco.getAnnIds(imgIds=[img_id])
    anns = self.coco.loadAnns(ids=ann_ids)
    num_objs = min(len(anns), self.max_objs)
    img = cv2.imread(img_path)
    if img is None:
      # cv2.imread signals failure by returning None instead of raising.
      raise FileNotFoundError('Cannot read image: {}'.format(img_path))
    height, width = img.shape[0], img.shape[1]

    # Crop centre c (x, y) and scale s (longest image side).
    c = np.array([width / 2., height / 2.], dtype=np.float32)
    s = max(height, width) * 1.0

    # Size of the network input (512 x 512 in the article's setting).
    input_h, input_w = self.opt.input_h, self.opt.input_w

    flipped = False
    if self.split == 'train':
      # Random scale in [0.6, 1.4) and random centre kept away from the
      # image border.
      s = s * np.random.choice(np.arange(0.6, 1.4, 0.1))
      w_border = self._get_border(128, width)
      h_border = self._get_border(128, height)
      c[0] = np.random.randint(low=w_border, high=width - w_border)
      c[1] = np.random.randint(low=h_border, high=height - h_border)
      # Horizontal flip is an augmentation, so it is applied only while
      # training; other splits must stay deterministic.
      if np.random.random() < self.opt.flip:
        flipped = True
        img = img[:, ::-1, :]
        c[0] = width - c[0] - 1

    # Affine-warp the image to the input size (no rotation: this is
    # effectively crop + resize).
    trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
    inp = cv2.warpAffine(img, trans_input,
                         (input_w, input_h),
                         flags=cv2.INTER_LINEAR)
    inp = (inp.astype(np.float32) / 255.)
    if self.split == 'train' and not self.opt.no_color_aug:
      # NOTE(review): result is not re-assigned, so color_aug presumably
      # modifies `inp` in place -- confirm against its definition.
      color_aug(self._data_rng, inp, self._eig_val, self._eig_vec)
    inp = (inp - self.mean) / self.std
    inp = inp.transpose(2, 0, 1)  # H x W x C -> C x H x W

    # Spatial size of the heatmap (128 x 128 in the article's setting).
    output_h = input_h // self.opt.down_ratio
    output_w = input_w // self.opt.down_ratio
    num_classes = self.num_classes
    trans_output = get_affine_transform(c, s, 0, [output_w, output_h])

    # Heatmap of GT box centres: num_classes x output_h x output_w.
    hm = np.zeros((num_classes, output_h, output_w), dtype=np.float32)
    # GT box width/height: max_objs x 2.
    wh = np.zeros((self.max_objs, 2), dtype=np.float32)
    # Sub-pixel centre offset lost when truncating to the output grid:
    # max_objs x 2.
    reg = np.zeros((self.max_objs, 2), dtype=np.float32)
    # Index of each centre in the flattened heatmap: max_objs.
    ind = np.zeros((self.max_objs), dtype=np.int64)
    # 1 where the slot holds a real object: max_objs.
    reg_mask = np.zeros((self.max_objs), dtype=np.uint8)
    # Class-specific width/height: max_objs x (num_classes * 2).
    cat_spec_wh = np.zeros((self.max_objs, num_classes * 2), dtype=np.float32)
    # Mask matching cat_spec_wh: max_objs x (num_classes * 2).
    cat_spec_mask = np.zeros((self.max_objs, num_classes * 2), dtype=np.uint8)

    draw_gaussian = draw_msra_gaussian if self.opt.mse_loss else \
                    draw_umich_gaussian

    gt_det = []
    for k in range(num_objs):
      ann = anns[k]
      bbox = self._coco_box_to_bbox(ann['bbox'])
      cls_id = int(self.cat_ids[ann['category_id']])
      if flipped:
        # Mirror the x coordinates; swapping keeps x1 <= x2.
        bbox[[0, 2]] = width - bbox[[2, 0]] - 1
      # Apply the same affine transform to the box, onto the output grid.
      bbox[:2] = affine_transform(bbox[:2], trans_output)
      bbox[2:] = affine_transform(bbox[2:], trans_output)
      bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, output_w - 1)
      bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, output_h - 1)
      h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]

      # Boxes that collapse after clipping contribute no target.
      if h > 0 and w > 0:
        # Minimum Gaussian radius, obtained by solving quadratic equations
        # (see reference [1] of the article).
        radius = gaussian_radius((math.ceil(h), math.ceil(w)))
        radius = max(0, int(radius))
        radius = self.opt.hm_gauss if self.opt.mse_loss else radius

        # Box centre in float coordinates.
        ct = np.array(
          [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
        # Box centre truncated to an integer grid cell.
        ct_int = ct.astype(np.int32)

        # Splat a Gaussian around the centre on this class's heatmap.
        draw_gaussian(hm[cls_id], ct_int, radius)
        wh[k] = 1. * w, 1. * h
        ind[k] = ct_int[1] * output_w + ct_int[0]
        reg[k] = ct - ct_int
        reg_mask[k] = 1
        cat_spec_wh[k, cls_id * 2: cls_id * 2 + 2] = wh[k]
        cat_spec_mask[k, cls_id * 2: cls_id * 2 + 2] = 1

        # Box rebuilt from centre and size; per the article this is more
        # precise than using bbox directly.
        gt_det.append([ct[0] - w / 2, ct[1] - h / 2,
                       ct[0] + w / 2, ct[1] + h / 2, 1, cls_id])

    # NOTE(review): reg, cat_spec_wh, cat_spec_mask and gt_det are computed
    # but not part of `ret` in this excerpt; presumably the full
    # implementation attaches them depending on options -- confirm.
    ret = {'input': inp, 'hm': hm, 'reg_mask': reg_mask, 'ind': ind, 'wh': wh}
    return ret


参考文献:

[1] cloud.tencent.com/developer/a…