代码地址: github.com/xingyizhou/CenterTrack
网络结构部分
以dla34网络为例,由于大部分代码和centernet相同,这里主要列出不同部分。主体在dla.py
DLA类中,除了原始提取x特征的网络结构,增加了self.pre_img_layer和self.pre_hm_layer对前一帧img和hm提取特征。
class DLA(nn.Module):
    """DLA backbone extended for CenterTrack with two optional input stems
    that embed the previous frame's image and the previous frame's heatmap.

    NOTE(review): this excerpt omits the construction of ``self.base_layer``
    and the ``level0``..``level5`` stages referenced in ``forward`` — they
    are presumably built in the omitted part of ``__init__``; confirm
    against the full dla.py.
    """

    def __init__(self, levels, channels, num_classes=1000,
                 block=BasicBlock, residual_root=False, linear_root=False,
                 opt=None):
        # Must run before any sub-module is assigned to ``self`` —
        # nn.Module raises otherwise. The original excerpt omitted it.
        super().__init__()
        # Guard against the declared default ``opt=None`` so construction
        # without an options object does not crash on attribute access.
        if opt is not None and opt.pre_img:
            # 7x7 conv stem embedding the previous-frame RGB image (3 ch)
            # into the same channel count as the current-frame stem.
            self.pre_img_layer = nn.Sequential(
                nn.Conv2d(3, channels[0], kernel_size=7, stride=1,
                          padding=3, bias=False),
                nn.BatchNorm2d(channels[0], momentum=BN_MOMENTUM),
                nn.ReLU(inplace=True))
        if opt is not None and opt.pre_hm:
            # 7x7 conv stem embedding the previous-frame heatmap (1 ch).
            self.pre_hm_layer = nn.Sequential(
                nn.Conv2d(1, channels[0], kernel_size=7, stride=1,
                          padding=3, bias=False),
                nn.BatchNorm2d(channels[0], momentum=BN_MOMENTUM),
                nn.ReLU(inplace=True))

    def forward(self, x, pre_img=None, pre_hm=None):
        """Run the backbone on ``x``; when previous-frame inputs are given,
        their stem features are added element-wise to the current-frame
        stem output before the six DLA levels. Returns the list of all six
        per-level feature maps."""
        y = []
        x = self.base_layer(x)
        # New layers: fuse previous-frame cues by addition, then feed the
        # fused features into the remaining network.
        if pre_img is not None:
            x = x + self.pre_img_layer(pre_img)
        if pre_hm is not None:
            x = x + self.pre_hm_layer(pre_hm)
        for i in range(6):
            x = getattr(self, 'level{}'.format(i))(x)
            y.append(x)
        return y
看BaseModel类里的forward, 与centernet基本一样,imgpre2feats内经过DLAbase+DLAUp+IDAUp之后,接各个head。
def forward(self, x, pre_img=None, pre_hm=None):
    """Extract backbone features (optionally fused with previous-frame
    inputs) and apply every prediction head at each output stack.

    Returns a list with one dict per stack, mapping head name -> output.
    """
    # Only take the two-frame path when previous-frame data is supplied.
    if pre_img is None and pre_hm is None:
        feats = self.img2feats(x)
    else:
        feats = self.imgpre2feats(x, pre_img, pre_hm)
    outputs = []
    for stage in range(self.num_stacks):
        # Apply each head module to this stage's feature map.
        stage_out = {head: self.__getattr__(head)(feats[stage])
                     for head in self.heads}
        outputs.append(stage_out)
    return outputs
def imgpre2feats(self, x, pre_img=None, pre_hm=None):
    """Two-frame feature pipeline: DLA base (aware of pre_img/pre_hm)
    -> DLAUp -> IDAUp; returns a one-element list holding only the final
    aggregated feature map."""
    levels = self.dla_up(self.base(x, pre_img, pre_hm))
    # Clone the maps IDAUp consumes so the in-place aggregation does not
    # mutate the DLAUp outputs.
    fused = [levels[idx].clone()
             for idx in range(self.last_level - self.first_level)]
    self.ida_up(fused, 0, len(fused))
    # The heads read only the last (fully aggregated) map.
    return [fused[-1]]
head部分包括{'hm': 1, 'reg': 2, 'wh': 2, 'tracking': 2, 'ltrb_amodal': 4},结构均为conv_relu_conv。不同head的输入均为y[-1] (n*64*136*240),已经融合了不同层的信息。各head输出维度如下:
In [1]: out[0]['hm'].shape
Out[1]: torch.Size([10, 1, 136, 240])
In [2]: out[0]['reg'].shape
Out[2]: torch.Size([10, 2, 136, 240])
In [3]: out[0]['wh'].shape
Out[3]: torch.Size([10, 2, 136, 240])
In [4]: out[0]['tracking'].shape
Out[4]: torch.Size([10, 2, 136, 240])
In [5]: out[0]['ltrb_amodal'].shape
Out[5]: torch.Size([10, 4, 136, 240])