CenterTrack代码解读——测试模块部分

795 阅读6分钟

代码地址: https://github.com/xingyizhou/CenterTrack

论文地址: https://arxiv.org/pdf/2004.01177 (Tracking Objects as Points)

测试模块部分

首先浏览下prefetch_test函数,核心在detector.run中,后面会细讲。

def prefetch_test(opt):
  """Run detection/tracking over the test split and evaluate the results.

  Images are read and pre-processed by ``PrefetchDataset`` inside a
  ``DataLoader`` (batch_size=1); each item is passed to ``Detector.run``.
  When ``opt.tracking`` is set, the tracker is reset at the first frame of
  every video. Results are stored per integer image id, optionally dumped to
  ``<save_dir>/save_results_<test_dataset><dataset_version>.json``, and
  evaluated with ``dataset.run_eval``.

  Args:
    opt: parsed options; re-bound to the dataset-updated options.
  """
  Dataset = dataset_factory[opt.test_dataset]
  opt = opts().update_dataset_info_and_set_heads(opt, Dataset)
  print(opt)
  Logger(opt)

  split = 'val' if not opt.trainval else 'test'

  # Build the evaluation dataset and the detector/tracker wrapper.
  dataset = Dataset(opt, split)
  detector = Detector(opt)

  # NOTE(review): load_results is never filled in this function, so the
  # per-video pre_dets lookup below always falls back to empty
  # initialization, and opt.use_loaded_results would raise KeyError.
  # Populate this dict (keyed by str(img_id)) before relying on either.
  load_results = {}

  data_loader = torch.utils.data.DataLoader(
    PrefetchDataset(opt, dataset, detector.pre_process),
    batch_size=1, shuffle=False, num_workers=1, pin_memory=True)

  results = {}
  num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
  bar = Bar('{}'.format(opt.exp_id), max=num_iters)
  time_stats = ['tot', 'load', 'pre', 'net', 'dec', 'post', 'merge', 'track']
  avg_time_stats = {t: AverageMeter() for t in time_stats}

  if opt.use_loaded_results:
    # Reuse previously saved results and skip inference entirely.
    for img_id in data_loader.dataset.images:
      results[img_id] = load_results['{}'.format(img_id)]
    num_iters = 0

  for ind, (img_id, pre_processed_images) in enumerate(data_loader):
    if ind >= num_iters:
      break
    # The DataLoader collates the id into a 1-element tensor; decode it once.
    img_id_int = int(img_id.numpy().astype(np.int32)[0])
    img_key = '{}'.format(img_id_int)

    # At the first frame of a video, seed the previous-frame detections from
    # load_results when available (otherwise start empty) and reset the
    # tracker state (id counter, tracks, cached previous images).
    if opt.tracking and ('is_first_frame' in pre_processed_images):
      if img_key in load_results:
        pre_processed_images['meta']['pre_dets'] = load_results[img_key]
      else:
        print()
        print('No pre_dets for', img_id_int,
          '. Use empty initialization.')
        pre_processed_images['meta']['pre_dets'] = []
      detector.reset_tracking()
      print('Start tracking video', int(pre_processed_images['video_id']))

    # Core step: network forward, decoding, post-processing and tracking.
    ret = detector.run(pre_processed_images)
    results[img_id_int] = ret['results']

    Bar.suffix = '[{0}/{1}]|Tot: {total:} |ETA: {eta:} '.format(
                   ind, num_iters, total=bar.elapsed_td, eta=bar.eta_td)
    for t in avg_time_stats:
      avg_time_stats[t].update(ret[t])
      Bar.suffix = Bar.suffix + '|{} {tm.val:.3f}s ({tm.avg:.3f}s) '.format(
        t, tm = avg_time_stats[t])
    if opt.print_iter > 0:
      if ind % opt.print_iter == 0:
        print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
    else:
      bar.next()
  bar.finish()
  if opt.save_results:
    save_path = opt.save_dir + '/save_results_{}{}.json'.format(
      opt.test_dataset, opt.dataset_version)
    print('saving results to', save_path)
    # Build the payload before opening the file so a conversion error does
    # not leave an empty file behind; the context manager guarantees the
    # handle is flushed and closed.
    payload = _to_list(copy.deepcopy(results))
    with open(save_path, 'w') as f:
      json.dump(payload, f)
  # Compute the evaluation metrics.
  dataset.run_eval(results, opt.save_dir)

在prefetch_test函数中,首先初始化了dataset, 即PrefetchDataset(opt, dataset, detector.pre_process)部分。

因为PrefetchDataset重新定义了`__getitem__`,原来在GenericDataset的`__getitem__`中对输入数据的处理改由detector.pre_process完成(主要是crop、resize和normalize)。

先来看PrefetchDataset:

class PrefetchDataset(torch.utils.data.Dataset):
  """Test-time dataset that loads one image and pre-processes it per scale.

  Wraps a project dataset object (reading its ``images``, ``coco.loadImgs``,
  ``img_dir`` and ``get_default_calib``) and applies ``pre_process_func`` so
  that image decoding and resizing run inside DataLoader workers.
  """
  def __init__(self, opt, dataset, pre_process_func):
    self.images = dataset.images
    self.load_image_func = dataset.coco.loadImgs
    self.img_dir = dataset.img_dir
    self.pre_process_func = pre_process_func
    self.get_default_calib = dataset.get_default_calib
    self.opt = opt

  def __getitem__(self, index):
    """Return ``(img_id, ret)`` for the image at ``index``.

    ``ret`` holds 'images' and 'meta' (dicts keyed by test scale, filled
    from ``pre_process_func``) and 'image' (the raw image from cv2). For an
    image whose ``frame_id`` is 1 it also holds 'is_first_frame' and
    'video_id'.

    Raises:
      FileNotFoundError: if the image file cannot be read.
    """
    img_id = self.images[index]
    img_info = self.load_image_func(ids=[img_id])[0]
    img_path = os.path.join(self.img_dir, img_info['file_name'])
    image = cv2.imread(img_path)
    if image is None:
      # cv2.imread signals failure by returning None rather than raising.
      raise FileNotFoundError('Could not read image: {}'.format(img_path))
    # Calibration does not depend on the scale, so compute it once.
    calib = img_info['calib'] if 'calib' in img_info \
      else self.get_default_calib(image.shape[1], image.shape[0])
    images, meta = {}, {}
    # Use the options given to __init__, not a module-level ``opt`` global
    # that only exists when this file is run as a script.
    for scale in self.opt.test_scales:
      input_meta = {'calib': calib}
      images[scale], meta[scale] = self.pre_process_func(
        image, scale, input_meta)
    ret = {'images': images, 'image': image, 'meta': meta}
    if 'frame_id' in img_info and img_info['frame_id'] == 1:
      ret['is_first_frame'] = 1
      ret['video_id'] = img_info['video_id']
    return img_id, ret

  def __len__(self):
    return len(self.images)

主要是对输入数据进行了pre_process,并且支持多scale。以index=0为例,看下输出。如果不是第一帧,ret则不包括'is_first_frame'和'video_id'。

In [1]: ret.keys()
Out[1]: dict_keys(['images', 'image', 'meta', 'is_first_frame', 'video_id'])

# 多个scale处理后的image
In [2]: ret['images'].keys()
Out[2]: dict_keys([1.0])
In [3]: ret['images'][1.0].shape
Out[3]: torch.Size([1, 3, 544, 960])

# 原始image
In [4]: ret['image'].shape
Out[4]: (1080, 1920, 3)

# 一些数据处理的参数
In [5]: ret['meta']
Out[5]: 
{1.0: {'calib': array([[1.2e+03, 0.0e+00, 9.6e+02, 0.0e+00],
         [0.0e+00, 1.2e+03, 5.4e+02, 0.0e+00],
         [0.0e+00, 0.0e+00, 1.0e+00, 0.0e+00]], dtype=float32),
  'c': array([960., 540.], dtype=float32),
  's': 1920.0,
  'height': 1080,
  'width': 1920,
  'out_height': 136,
  'out_width': 240,
  'inp_height': 544,
  'inp_width': 960,
  'trans_input': array([[ 0.5, -0. ,  0. ],
         [ 0. ,  0.5,  2. ]]),
  'trans_output': array([[ 0.125, -0.   ,  0.   ],
         [ 0.   ,  0.125,  0.5  ]])}}
         
In [6]: ret['is_first_frame']
Out[6]: 1

In [7]: ret['video_id']
Out[7]: 1

在prefetch_test函数遍历dataloader时,并不是每张图片都会重置tracker:只有在opt.tracking为True且遇到某个视频的第一帧(即ret中包含'is_first_frame')时,才会先调用detector.reset_tracking()对tracker进行初始化。

# detector.py
  def reset_tracking(self):
    """Clear the tracker and forget the cached previous-frame images."""
    self.tracker.reset()
    # With pre_images back to None, the next run() re-initializes tracking.
    self.pre_images = self.pre_image_ori = None
# tracker.py
  def reset(self):
    """Start a fresh tracking session: no tracks and the id counter at zero."""
    self.id_count, self.tracks = 0, []

整个detector.run(pre_processed_images)得到ret,下面拆解来看。

# detector.py
  def run(self, image_or_path_or_tensor, meta={}):
    load_time, pre_time, net_time, dec_time, post_time = 0, 0, 0, 0, 0
    merge_time, track_time, tot_time, display_time = 0, 0, 0, 0
    self.debugger.clear()
    start_time = time.time()

    # 获取原始图片 1080*1920*3
    image = image_or_path_or_tensor['image'][0].numpy() 
    pre_processed_images = image_or_path_or_tensor
    pre_processed = True
    
    loaded_time = time.time()
    load_time += (loaded_time - start_time)
    
    detections = []

    # 支持multi-scale测试
    for scale in self.opt.test_scales:
      scale_start_time = time.time()

      # prefetch testing
      # 获取处理后图片 1*3*544*960
      images = pre_processed_images['images'][scale][0]
      meta = pre_processed_images['meta'][scale]
      meta = {k: v.numpy()[0] for k, v in meta.items()}
      
      # 只对于首帧图片为true, 初始化后为[]
      if 'pre_dets' in pre_processed_images['meta']:
        meta['pre_dets'] = pre_processed_images['meta']['pre_dets']
      # false
      if 'cur_dets' in pre_processed_images['meta']:
        meta['cur_dets'] = pre_processed_images['meta']['cur_dets']
      
      images = images.to(self.opt.device, non_blocking=self.opt.non_block_test)

      # initializing tracker
      pre_hms, pre_inds = None, None
      if self.opt.tracking:
      
        # 初始化第一帧
        if self.pre_images is None:
          print('Initialize tracking!')
          self.pre_images = images
          self.tracker.init_track(
            meta['pre_dets'] if 'pre_dets' in meta else [])

这一步只在第一帧执行。init_track会为每个得分高于new_thresh的检测分配新的tracking_id(同时累加总数id_count),在缺少中心点时由bbox计算ct,并把这些目标加入self.tracks。但在没有load_results的情况下,第一帧传入的pre_dets为[],因此这里不会产生任何track,这个函数在该设置下实际上不起作用。

# tracker.py
  def init_track(self, results):
    """Seed the tracker from an initial list of detections.

    Every detection whose score exceeds ``opt.new_thresh`` receives a fresh
    ``tracking_id`` (the counter ``id_count`` is advanced first), gets
    ``active`` and ``age`` set to 1, and gets a center ``ct`` derived from
    its ``bbox`` when it has none. Qualifying dicts are mutated in place and
    appended to ``self.tracks``; the others are left untouched.
    """
    for det in results:
      # Written as a negated ">" so the skip condition matches the original
      # comparison exactly (including NaN scores, which are skipped).
      if not det['score'] > self.opt.new_thresh:
        continue
      self.id_count += 1
      det['active'] = 1
      det['age'] = 1
      det['tracking_id'] = self.id_count
      if 'ct' not in det:
        box = det['bbox']
        det['ct'] = [(box[0] + box[2]) / 2, (box[1] + box[3]) / 2]
      self.tracks.append(det)

回到detector.run。_get_additional_inputs函数根据上一帧的跟踪结果(self.tracker.tracks)渲染出输入heatmap(pre_hm),同时返回这些目标在输出特征图上的索引pre_inds。

        if self.opt.pre_hm:
          # render input heatmap from tracker status
          # 这个版本没有用pre_inds.
          # pre_inds是用来学习前一帧图片到当前帧图片的偏差
          pre_hms, pre_inds = self._get_additional_inputs(
            self.tracker.tracks, meta, with_hm=not self.opt.zero_pre_hm)

self.tracker.tracks代表上一帧的跟踪结果,看一下列表中的某个元素:

 {'score': 0.4982151,
  'class': 1,
  'ct': array([ 48., 636.], dtype=float32),
  'tracking': array([2.617714 , 4.9918213], dtype=float32),
  'bbox': array([-158.66681,  402.7378 ,  106.68776,  882.9254 ], dtype=float32),
  'tracking_id': 19,
  'age': 1,
  'active': 1}

_get_additional_inputs函数主要是获取由前一帧目标渲染出的heatmap(形状为1×1×544×960)以及inds,与训练时的数据处理部分相同。

# detector.py
  def _get_additional_inputs(self, dets, meta, with_hm=True):
    '''
    Render input heatmap from previous trackings.

    Tracks in ``dets`` are skipped when their score is below
    ``opt.pre_thresh`` or their 'active' flag is 0. Each remaining box is
    mapped into input space (where a Gaussian peak is drawn on the heatmap)
    and into output space (where the flattened index of its center is
    recorded). Boxes with non-positive width or height after the input-space
    transform are ignored.

    Args:
      dets: list of track dicts; reads 'score', 'active' and 'bbox'.
      meta: per-scale meta dict; reads 'trans_input', 'trans_output',
        'inp_width', 'inp_height', 'out_width', 'out_height'.
      with_hm: when False, no Gaussians are drawn.

    Returns:
      (input_hm, output_inds). When with_hm is True, input_hm is a float32
      tensor on opt.device of shape (1, 1, inp_height, inp_width), or
      (2, 1, inp_height, inp_width) with opt.flip_test. When with_hm is
      False it is returned as the all-zero numpy array of shape
      (1, inp_height, inp_width). NOTE(review): confirm callers accept a
      numpy array in that case.
      output_inds is an int64 tensor of shape (1, K) holding
      ``y * out_width + x`` of each kept track's output-space center.
    '''
    trans_input, trans_output = meta['trans_input'], meta['trans_output']
    inp_width, inp_height = meta['inp_width'], meta['inp_height']
    out_width, out_height = meta['out_width'], meta['out_height']
    input_hm = np.zeros((1, inp_height, inp_width), dtype=np.float32)

    output_inds = []
    for det in dets:
      # Only confident, currently-visible tracks contribute to the prior.
      if det['score'] < self.opt.pre_thresh or det['active'] == 0:
        continue
      # _trans_bbox presumably applies the affine transform and clips to the
      # given width/height — TODO confirm against its definition.
      bbox = self._trans_bbox(det['bbox'], trans_input, inp_width, inp_height)
      bbox_out = self._trans_bbox(
        det['bbox'], trans_output, out_width, out_height)
      h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
      if (h > 0 and w > 0):
        # Gaussian radius from the box size, as in training-time rendering.
        radius = gaussian_radius((math.ceil(h), math.ceil(w)))
        radius = max(0, int(radius))
        ct = np.array(
          [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
        ct_int = ct.astype(np.int32)
        if with_hm:
          draw_umich_gaussian(input_hm[0], ct_int, radius)
        # Integer output-space center, flattened into a row-major index.
        ct_out = np.array(
          [(bbox_out[0] + bbox_out[2]) / 2, 
           (bbox_out[1] + bbox_out[3]) / 2], dtype=np.int32)
        output_inds.append(ct_out[1] * out_width + ct_out[0])
    if with_hm:
      # Add a batch dimension: (1, H, W) -> (1, 1, H, W).
      input_hm = input_hm[np.newaxis]
      if self.opt.flip_test:
        # Flip test: append a horizontally flipped copy along the batch axis.
        input_hm = np.concatenate((input_hm, input_hm[:, :, :, ::-1]), axis=0)
      input_hm = torch.from_numpy(input_hm).to(self.opt.device)
    output_inds = np.array(output_inds, np.int64).reshape(1, -1)
    output_inds = torch.from_numpy(output_inds).to(self.opt.device)
    return input_hm, output_inds

回到detector.run。得到pre_hm之后,self.process主要是根据pre_hm,以及imgs,前向提取特征得到output。并经过decode解码,处理完之后得到当前帧的检测结果。

      pre_process_time = time.time()
      pre_time += pre_process_time - scale_start_time
      
      # run the network
      # output: the output feature maps, only used for visualizing
      # dets: output tensors after extracting peaks
      output, dets, forward_time = self.process(
        images, self.pre_images, pre_hms, pre_inds, return_time=True)

这里的decode和centernet一致,就不放code了。

  def process(self, images, pre_images=None, pre_hms=None,
    pre_inds=None, return_time=False):
    """Run the network on one pre-processed input and decode its peaks.

    Args:
      images: input image tensor, as prepared by the caller.
      pre_images: previous-frame image tensor, or None.
      pre_hms: rendered previous-frame heatmap, or None.
      pre_inds: output-map indices of previous tracks, or None; stored in
        the output dict under 'pre_inds' before decoding.
      return_time: if True, also return the timestamp taken right after the
        forward pass (and optional flip merge).

    Returns:
      ``(output, dets)``, or ``(output, dets, forward_time)`` when
      ``return_time`` is True. ``dets`` maps each key produced by
      ``generic_decode`` to a numpy array.
    """
    def sync():
      # Wait for queued GPU work so the timestamps are meaningful.
      # torch.cuda.synchronize() raises on CPU-only machines, so skip it there.
      if torch.cuda.is_available():
        torch.cuda.synchronize()

    with torch.no_grad():
      sync()
      # The model returns an indexable sequence of outputs; use the last one.
      output = self.model(images, pre_images, pre_hms)[-1]
      output = self._sigmoid_output(output)
      output.update({'pre_inds': pre_inds})
      if self.opt.flip_test:
        output = self._flip_output(output)
      sync()
      forward_time = time.time()

      dets = generic_decode(output, K=self.opt.K, opt=self.opt)
      sync()
      for k in dets:
        dets[k] = dets[k].detach().cpu().numpy()
    if return_time:
      return output, dets, forward_time
    return output, dets

看一下process得到了哪些结果。除了pre_cts对应上一帧的目标(数量与上一帧的目标数一致)外,其余输出都是top-K(K=opt.K,这里为100)的结果。

In [2]: dets.keys()  
Out[2]: dict_keys(['scores', 'clses', 'xs', 'ys', 'cts', 'bboxes', 'tracking', 'bboxes_amodal', 'pre_cts'])

回到detector.run。self.post_process主要是把检测结果转换到原图坐标系,并过滤掉得分较低的bbox。

      net_time += forward_time - pre_process_time
      decode_time = time.time()
      dec_time += decode_time - forward_time
      
      # 把crop&4倍降采样的坐标系转换为输入图片对应坐标系,同时用score>opt.out_thresh进行筛选
      result = self.post_process(dets, meta, scale)

这里看下列表result内元素

In [10]: result[0]
Out[10]: 
{'score': 0.97548306,
 'class': 1,
 'ct': array([1104.,  708.], dtype=float32),
 'tracking': array([1.4091797, 2.8463135], dtype=float32),
 'bbox': array([1021.6944,  403.1587, 1191.8154, 1020.0564], dtype=float32)}

回到detector.run。得到了不同scale下的检测结果,进行融合。并进入self.tracker.step进行匹配。

      post_process_time = time.time()
      post_time += post_process_time - decode_time

      detections.append(result)

    # for循环结束,融合multi-scale的检测结果
    results = self.merge_outputs(detections)
    torch.cuda.synchronize()
    end_time = time.time()
    merge_time += end_time - post_process_time
    
    if self.opt.tracking:
      public_det = None
      # add tracking id to results
      results = self.tracker.step(results, public_det)
      self.pre_images = images

这里self.tracker.step主要是对当前帧结果和上一帧结果进行匹配,并且更新上一帧的信息为当前帧,下次迭代使用。

  def step(self, results, public_det=None):
    """Associate current-frame detections with the existing tracks.

    Each detection's ``ct`` is shifted by its ``tracking`` offset and
    compared with the stored track centers by squared Euclidean distance.
    A pair is ruled out when that distance exceeds either box's area or the
    classes differ. The pairs are then matched greedily, or with Hungarian
    assignment when ``opt.hungarian`` is set.

    After matching:
      - Matched detections inherit the track's ``tracking_id``.
      - Unmatched detections scoring above ``opt.new_thresh`` start new
        tracks.
      - Unmatched tracks with ``age < opt.max_age`` are kept, marked
        inactive, with their age incremented.

    ``self.tracks`` is replaced by the returned list.

    Args:
      results: list of detection dicts with at least 'ct', 'tracking',
        'bbox', 'class' and 'score'. The dicts are mutated in place.
      public_det: not used by this version.

    Returns:
      The new list of tracks (the same list now stored in ``self.tracks``).
    """
    # Number of current-frame detections
    N = len(results)
    # Number of tracks carried over from the previous frame
    M = len(self.tracks)

    # Current centers shifted by the predicted 'tracking' offset (presumably
    # an estimate of each detection's previous-frame position — confirm
    # against the decoder)
    dets = np.array(
      [det['ct'] + det['tracking'] for det in results], np.float32) # N x 2
      
    # Box areas of the previous tracks
    track_size = np.array([((track['bbox'][2] - track['bbox'][0]) * \
      (track['bbox'][3] - track['bbox'][1])) \
      for track in self.tracks], np.float32) # M
    
    # Classes of the previous tracks
    track_cat = np.array([track['class'] for track in self.tracks], np.int32) # M
    # Box areas of the current detections
    item_size = np.array([((item['bbox'][2] - item['bbox'][0]) * \
      (item['bbox'][3] - item['bbox'][1])) \
      for item in results], np.float32) # N
      
    # Classes of the current detections
    item_cat = np.array([item['class'] for item in results], np.int32) # N
    
    # Stored centers of the previous tracks
    tracks = np.array(
      [pre_det['ct'] for pre_det in self.tracks], np.float32) # M x 2
      
    # Squared distance between every detection and every track, N x M
    dist = (((tracks.reshape(1, -1, 2) - \
              dets.reshape(-1, 1, 2)) ** 2).sum(axis=2)) # N x M

    # Rule out pairs that are too far apart (squared distance larger than
    # either box's area) or of different classes by adding 1e18 to them
    invalid = ((dist > track_size.reshape(1, M)) + \
      (dist > item_size.reshape(N, 1)) + \
      (item_cat.reshape(N, 1) != track_cat.reshape(1, M))) > 0
    dist = dist + invalid * 1e18
    
    if self.opt.hungarian: # False in the configuration analyzed here
      # NOTE(review): item_score is computed but never used.
      item_score = np.array([item['score'] for item in results], np.float32) # N
      dist[dist > 1e18] = 1e18
      matched_indices = linear_assignment(dist)
    else:
      # Matched (detection index, track index) pairs; the [:, 0] / [:, 1]
      # indexing below implies an n x 2 array
      matched_indices = greedy_assignment(copy.deepcopy(dist))
      
    # Detections / tracks that do not appear in any matched pair
    unmatched_dets = [d for d in range(dets.shape[0]) \
      if not (d in matched_indices[:, 0])]
    unmatched_tracks = [d for d in range(tracks.shape[0]) \
      if not (d in matched_indices[:, 1])]
    
    if self.opt.hungarian: # False in the configuration analyzed here
      # Hungarian assignment pairs everything it can; pairs whose cost is
      # only reachable through the invalid penalty are treated as unmatched
      matches = []
      for m in matched_indices:
        if dist[m[0], m[1]] > 1e16:
          unmatched_dets.append(m[0])
          unmatched_tracks.append(m[1])
        else:
          matches.append(m)
      matches = np.array(matches).reshape(-1, 2)
    else:
      matches = matched_indices

    ret = []
    for m in matches:
      # The current detection dict becomes the track for the next frame
      track = results[m[0]]
      # Matched detections inherit the existing tracking_id
      track['tracking_id'] = self.tracks[m[1]]['tracking_id']
      track['age'] = 1
      track['active'] = self.tracks[m[1]]['active'] + 1
      ret.append(track)

    # Private detection: create tracks for all un-matched detections
    # Unmatched detections are treated as objects newly entering the frame
    # and start a new track if their score exceeds new_thresh
    for i in unmatched_dets:
      track = results[i]
      if track['score'] > self.opt.new_thresh:
        self.id_count += 1
        track['tracking_id'] = self.id_count
        track['age'] = 1
        track['active'] =  1
        ret.append(track)
    
    # Unmatched tracks are kept (inactive, age + 1) until they reach max_age.
    # v is a zero velocity, so bbox and ct keep their values but become lists.
    for i in unmatched_tracks:
      track = self.tracks[i]
      if track['age'] < self.opt.max_age:
        track['age'] += 1
        track['active'] = 0
        bbox = track['bbox']
        ct = track['ct']
        v = [0, 0]
        track['bbox'] = [
          bbox[0] + v[0], bbox[1] + v[1],
          bbox[2] + v[0], bbox[3] + v[1]]
        track['ct'] = [ct[0] + v[0], ct[1] + v[1]]
        ret.append(track)
        
    # The returned list becomes the track set used for the next frame
    self.tracks = ret
    return ret

回到detector.run,tracking的整个流程结束,返回ret。


    tracking_time = time.time()
    track_time += tracking_time - end_time
    tot_time += tracking_time - start_time
    self.cnt += 1
    show_results_time = time.time()
    display_time += show_results_time - end_time
    
    # return results and run time
    ret = {'results': results, 'tot': tot_time, 'load': load_time,
            'pre': pre_time, 'net': net_time, 'dec': dec_time,
            'post': post_time, 'merge': merge_time, 'track': track_time,
            'display': display_time}
    return ret
In [7]: ret
Out[7]:
{'results': 
[
      {'score': 0.9737419,
       'class': 1,
       'ct': array([1704.,  652.], dtype=float32),
       'tracking': array([4.6469727, 4.75177  ], dtype=float32),
       'bbox': array([1631.5349 ,  422.77515, 1785.3937 ,  891.9469 ], dtype=float32),
       'tracking_id': 1,
       'age': 1,
       'active': 1},

         ...

       {'score': 0.4982151,
       'class': 1,
       'ct': array([ 48., 636.], dtype=float32),
       'tracking': array([2.617714 , 4.9918213], dtype=float32),
       'bbox': array([-158.66681,  402.7378 ,  106.68776,  882.9254 ], dtype=float32),
       'tracking_id': 19,
       'age': 1,
       'active': 1}
 ],
 'tot': 0.7928645610809326,
 'load': 3.910064697265625e-05,
 'pre': 0.002444744110107422,
 'net': 0.7664916515350342,
 'dec': 0.021967411041259766,
 'post': 0.0013659000396728516,
 'merge': 7.224082946777344e-05,
 'track': 0.0004775524139404297,
 'display': 0.00048422813415527344}