CenterTrack Code Walkthrough: Network Architecture


Code: github.com/xingyizhou/…

Paper: arxiv.org/pdf/2004.01…

Network Architecture

Taking the dla34 backbone as an example. Since most of the code is identical to CenterNet, this post only covers the differences. The main logic lives in dla.py.

In the DLA class, besides the original network that extracts features from x, two new branches are added: self.pre_img_layer and self.pre_hm_layer, which extract features from the previous frame's image and heatmap, respectively.

class DLA(nn.Module):
    def __init__(self, levels, channels, num_classes=1000,
                 block=BasicBlock, residual_root=False, linear_root=False,
                 opt=None):
        super(DLA, self).__init__()
        # ... (base_layer and level0..level5 built as in CenterNet) ...
        if opt.pre_img:
            self.pre_img_layer = nn.Sequential(
                nn.Conv2d(3, channels[0], kernel_size=7, stride=1,
                          padding=3, bias=False),
                nn.BatchNorm2d(channels[0], momentum=BN_MOMENTUM),
                nn.ReLU(inplace=True))
        if opt.pre_hm:
            self.pre_hm_layer = nn.Sequential(
                nn.Conv2d(1, channels[0], kernel_size=7, stride=1,
                          padding=3, bias=False),
                nn.BatchNorm2d(channels[0], momentum=BN_MOMENTUM),
                nn.ReLU(inplace=True))
            
    def forward(self, x, pre_img=None, pre_hm=None):
        y = []
        x = self.base_layer(x)
       
        # New layers: their outputs are added to x, and the fused result
        # is fed into the rest of the network
        if pre_img is not None:
            x = x + self.pre_img_layer(pre_img)
        if pre_hm is not None:
            x = x + self.pre_hm_layer(pre_hm)
        
        for i in range(6):
            x = getattr(self, 'level{}'.format(i))(x)
            y.append(x)
        return y
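The fusion step above can be sketched in isolation. This is a minimal sketch, assuming channels[0] = 16 (the first entry of DLA-34's channel list) and a dummy 128×128 input; BN_MOMENTUM is left at PyTorch's default here. The key point is that a 7×7 convolution with stride 1 and padding 3 preserves H and W, so the three feature maps can be summed element-wise:

```python
import torch
import torch.nn as nn

# pre_img_layer maps the previous RGB frame (3 ch) and pre_hm_layer the
# previous-frame heatmap (1 ch) to the same channel count as base_layer.
pre_img_layer = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=7, stride=1, padding=3, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(inplace=True))
pre_hm_layer = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=7, stride=1, padding=3, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(inplace=True))

x = torch.randn(2, 16, 128, 128)       # base_layer output for the current frame
pre_img = torch.randn(2, 3, 128, 128)  # previous frame
pre_hm = torch.randn(2, 1, 128, 128)   # rendered heatmap of previous detections

# kernel_size=7, stride=1, padding=3 keeps the spatial size, so the sum is valid
fused = x + pre_img_layer(pre_img) + pre_hm_layer(pre_hm)
print(fused.shape)  # torch.Size([2, 16, 128, 128])
```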

Now look at forward in the BaseModel class; it is essentially the same as CenterNet. Inside imgpre2feats, the features pass through the DLA base + DLAUp + IDAUp before being fed to each head.

    def forward(self, x, pre_img=None, pre_hm=None):
      if (pre_hm is not None) or (pre_img is not None):
        feats = self.imgpre2feats(x, pre_img, pre_hm)
      else:
        feats = self.img2feats(x)
      out = []
      
      # run each head
      for s in range(self.num_stacks):
        z = {}
        for head in self.heads:
            z[head] = self.__getattr__(head)(feats[s])
        out.append(z)
      return out

    def imgpre2feats(self, x, pre_img=None, pre_hm=None):
        x = self.base(x, pre_img, pre_hm)
        x = self.dla_up(x)

        y = []
        for i in range(self.last_level - self.first_level):
            y.append(x[i].clone())
        self.ida_up(y, 0, len(y))

        return [y[-1]]
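A quick bit of stride bookkeeping helps connect imgpre2feats to the shapes below. This is a sketch assuming the CenterNet-family defaults for DLA-34, first_level = 2 (down_ratio 4) and last_level = 5, and an input resolution of 544×960 inferred from the 136×240 maps shown below:

```python
# y[-1] is at stride 2**first_level relative to the input image
first_level, last_level = 2, 5
out_stride = 2 ** first_level                  # 4
in_h, in_w = 544, 960                          # assumed padded input resolution
print(in_h // out_stride, in_w // out_stride)  # 136 240
print(last_level - first_level)                # 3 feature maps enter ida_up
```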

The heads are {'hm': 1, 'reg': 2, 'wh': 2, 'tracking': 2, 'ltrb_amodal': 4}, each structured as conv → ReLU → conv. Every head takes the same input y[-1] (n×64×136×240), which already fuses information across levels. The output shapes per head are:

In [1]: out[0]['hm'].shape
Out[1]: torch.Size([10, 1, 136, 240])

In [2]: out[0]['reg'].shape
Out[2]: torch.Size([10, 2, 136, 240])

In [3]: out[0]['wh'].shape
Out[3]: torch.Size([10, 2, 136, 240])

In [4]: out[0]['tracking'].shape
Out[4]: torch.Size([10, 2, 136, 240])

In [5]: out[0]['ltrb_amodal'].shape
Out[5]: torch.Size([10, 4, 136, 240])
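The conv_relu_conv head structure can be sketched as follows. This is a minimal sketch, not the repo's exact code: make_head is a hypothetical helper, and the 256-channel hidden width (head_conv) is the value commonly used with DLA-34 in the CenterNet family. Each head maps the shared 64-channel feature map to its own output channels while preserving the 136×240 resolution:

```python
import torch
import torch.nn as nn

heads = {'hm': 1, 'reg': 2, 'wh': 2, 'tracking': 2, 'ltrb_amodal': 4}

def make_head(out_channels, in_channels=64, head_conv=256):
    # 3x3 conv (padding 1) -> ReLU -> 1x1 conv; both convs keep H and W
    return nn.Sequential(
        nn.Conv2d(in_channels, head_conv, kernel_size=3, padding=1, bias=True),
        nn.ReLU(inplace=True),
        nn.Conv2d(head_conv, out_channels, kernel_size=1, bias=True))

feat = torch.randn(1, 64, 136, 240)  # stand-in for y[-1]
out = {name: make_head(c)(feat) for name, c in heads.items()}
for name, t in out.items():
    print(name, tuple(t.shape))
```

This reproduces the channel counts in the shape dump above (with batch size 1 instead of 10).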