源码 URL：github.com/michuanhaoh…（原文中的链接已被截断）
HA-CNN 网络模型的代码及注释如下：
from __future__ import absolute_import
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
__all__ = ['HACNN']
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU.

    Args (same meaning as the corresponding ``torch.nn.Conv2d`` arguments):
        in_c (int): number of input channels.
        out_c (int): number of output channels.
        k (int or tuple): kernel size.
        s (int or tuple): stride. Default: 1.
        p (int or tuple): padding. Default: 0.
    """

    def __init__(self, in_c, out_c, k, s=1, p=0):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=k,
            stride=s,
            padding=p,
        )
        self.bn = nn.BatchNorm2d(num_features=out_c)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return F.relu(out)
class InceptionA(nn.Module):
    """Inception-A block that keeps the spatial size.

    Three branches of 1x1 -> 3x3 ConvBlocks plus one avg-pool -> 1x1 branch,
    each producing ``out_channels // 4`` channels; the branch outputs are
    concatenated along the channel axis.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels AFTER concatenation.
    """

    def __init__(self, in_channels, out_channels):
        super(InceptionA, self).__init__()
        # The four branches are concatenated, so each gets a quarter of the channels.
        mid = out_channels // 4

        def conv_pair():
            # 1x1 reduction followed by a padded 3x3 conv (size preserving).
            return nn.Sequential(
                ConvBlock(in_channels, mid, 1),
                ConvBlock(mid, mid, 3, p=1),
            )

        self.stream1 = conv_pair()
        self.stream2 = conv_pair()
        self.stream3 = conv_pair()
        self.stream4 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1),
            ConvBlock(in_channels, mid, 1),
        )

    def forward(self, x):
        branches = (self.stream1, self.stream2, self.stream3, self.stream4)
        return torch.cat([branch(x) for branch in branches], dim=1)
class InceptionB(nn.Module):
    """Inception-B (reduction) block that downsamples by a stride of 2.

    Branches:
        1. 1x1 -> 3x3 stride 2
        2. 1x1 -> 3x3 -> 3x3 stride 2
        3. 3x3 max-pool stride 2 -> 1x1 with twice the branch width
    Each branch width is ``out_channels // 4``; outputs are concatenated
    along the channel axis.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels AFTER concatenation.
    """

    def __init__(self, in_channels, out_channels):
        super(InceptionB, self).__init__()
        quarter = out_channels // 4
        self.stream1 = nn.Sequential(
            ConvBlock(in_channels, quarter, 1),
            ConvBlock(quarter, quarter, 3, s=2, p=1),
        )
        self.stream2 = nn.Sequential(
            ConvBlock(in_channels, quarter, 1),
            ConvBlock(quarter, quarter, 3, p=1),
            ConvBlock(quarter, quarter, 3, s=2, p=1),
        )
        self.stream3 = nn.Sequential(
            nn.MaxPool2d(3, stride=2, padding=1),
            ConvBlock(in_channels, 2 * quarter, 1),
        )

    def forward(self, x):
        outputs = []
        for stream in (self.stream1, self.stream2, self.stream3):
            outputs.append(stream(x))
        return torch.cat(outputs, dim=1)
class SpatialAttn(nn.Module):
    """Spatial Attention (Sec. 3.1.I.1).

    Produces a single-channel spatial map with the same height and width as
    the input: cross-channel mean -> 3x3 stride-2 ConvBlock -> bilinear resize
    back to the input resolution -> 1x1 scaling ConvBlock.
    """

    def __init__(self):
        super(SpatialAttn, self).__init__()
        self.conv1 = ConvBlock(1, 1, 3, s=2, p=1)
        self.conv2 = ConvBlock(1, 1, 1)

    def forward(self, x):
        # Remember the input resolution so the map can be resized back to it.
        height, width = x.size(2), x.size(3)
        # global cross-channel averaging: (b, c, h, w) -> (b, 1, h, w)
        x = x.mean(1, keepdim=True)
        # 3-by-3 conv, stride 2: -> (b, 1, ceil(h/2), ceil(w/2))
        x = self.conv1(x)
        # Bilinear resizing back to (h, w). Doubling the downsampled size
        # instead would give h+1 / w+1 for odd inputs and break the
        # element-wise product with the input features.
        x = F.interpolate(x, (height, width), mode='bilinear', align_corners=True)
        # scaling conv
        x = self.conv2(x)
        return x
class ChannelAttn(nn.Module):
    """Channel Attention (Sec. 3.1.I.2).

    Squeeze-and-excitation style: global average pooling over the spatial
    extent, then a 1x1 ConvBlock bottleneck to ``in_channels // reduction_rate``
    channels and a 1x1 ConvBlock back to ``in_channels``.
    Output shape: (batch, in_channels, 1, 1).

    Args:
        in_channels (int): number of input channels; must be divisible by
            ``reduction_rate`` (checked with ``assert``).
        reduction_rate (int): bottleneck reduction factor. Default: 16.
    """

    def __init__(self, in_channels, reduction_rate=16):
        super(ChannelAttn, self).__init__()
        assert in_channels % reduction_rate == 0
        hidden = in_channels // reduction_rate
        self.conv1 = ConvBlock(in_channels, hidden, 1)
        self.conv2 = ConvBlock(hidden, in_channels, 1)

    def forward(self, x):
        # squeeze: pooling window covers the whole (height, width) -> (b, c, 1, 1)
        pooled = F.avg_pool2d(x, x.size()[2:])
        # excitation: bottleneck followed by expansion
        return self.conv2(self.conv1(pooled))
class SoftAttn(nn.Module):
    """Soft Attention (Sec. 3.1.I): spatial attention combined with channel attention.

    The spatial map (b, 1, h', w') and the channel map (b, c, 1, 1) are
    multiplied with broadcasting, fused by a 1x1 ConvBlock and passed through
    a sigmoid. The result matches the input shape when SpatialAttn returns the
    input's (h, w) — NOTE(review): confirm for odd input sizes.

    Args:
        in_channels (int): number of input channels.
    """

    def __init__(self, in_channels):
        super(SoftAttn, self).__init__()
        self.spatial_attn = SpatialAttn()
        self.channel_attn = ChannelAttn(in_channels)
        self.conv = ConvBlock(in_channels, in_channels, 1)

    def forward(self, x):
        joint = self.spatial_attn(x) * self.channel_attn(x)
        return torch.sigmoid(self.conv(joint))
class HardAttn(nn.Module):
    """Hard Attention (Sec. 3.1.II).

    Predicts a 2-d translation for each of 4 regions from the globally
    average-pooled input. Output: (batch, 4, 2), values in (-1, 1) via tanh.

    Args:
        in_channels (int): number of input channels.
    """

    def __init__(self, in_channels):
        super(HardAttn, self).__init__()
        self.fc = nn.Linear(in_channels, 4 * 2)  # 4 regions x 2 translation params
        self.init_params()

    def init_params(self):
        # Zero the weights so the initial prediction depends only on the bias:
        # pairs (0, -0.75), (0, -0.25), (0, 0.25), (0, 0.75) before tanh.
        # Parameters are modified under no_grad instead of through the
        # discouraged ``.data`` attribute.
        with torch.no_grad():
            self.fc.weight.zero_()
            self.fc.bias.copy_(torch.tensor(
                [0, -0.75, 0, -0.25, 0, 0.25, 0, 0.75], dtype=torch.float))

    def forward(self, x):
        # squeeze operation (global average pooling): (b, c, h, w) -> (b, c)
        x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), x.size(1))
        # predict transformation parameters, bounded to (-1, 1)
        theta = torch.tanh(self.fc(x))
        # split the 8 outputs into (batch, 4 regions, 2 translation components)
        theta = theta.view(-1, 4, 2)
        return theta
class HarmAttn(nn.Module):
    """Harmonious Attention (Sec. 3.1).

    Bundles soft attention and hard attention over the same input.

    Args:
        in_channels (int): number of input channels.

    Returns (forward):
        tuple: (soft attention map from ``SoftAttn``,
                region translations from ``HardAttn``).
    """

    def __init__(self, in_channels):
        super(HarmAttn, self).__init__()
        self.soft_attn = SoftAttn(in_channels)
        self.hard_attn = HardAttn(in_channels)

    def forward(self, x):
        return self.soft_attn(x), self.hard_attn(x)
class HACNN(nn.Module):
    """Harmonious Attention Convolutional Neural Network.

    Reference:
        Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.

    The model has two branches:
        * Global branch: three Inception blocks, each modulated by the soft
          attention of a ``HarmAttn`` module.
        * Local branch (optional): samples four regions from each block's
          input, using the translations predicted by that block's hard
          attention.
    Inputs must be 160x64.

    Args:
        num_classes (int): number of classes to predict.
        loss (set of str): ``{'xent'}`` (default) or ``{'xent', 'htri'}``;
            selects what ``forward`` returns in training mode.
        nchannels (sequence of 3 ints): number of channels AFTER concatenation
            of each Inception block. Default: ``[128, 256, 384]``.
        feat_dim (int): feature dimension for a single stream.
        learn_region (bool): whether to learn region features (i.e. local branch).
        use_gpu (bool): kept for backward compatibility. The affine matrices
            are created on the device of the predicted parameters, so this
            flag no longer decides tensor placement.
        **kwargs: accepted and ignored.

    Forward returns:
        * eval mode: ``cat([global, local], 1)`` of the L2-normalised features,
          or the (unnormalised) global feature when ``learn_region`` is False.
        * training mode, ``{'xent'}``: global (and local) logits.
        * training mode, ``{'xent', 'htri'}``: (logits, features).
        * any other ``loss`` raises ``KeyError`` in training mode.
    """

    def __init__(self, num_classes, loss=None, nchannels=None, feat_dim=512,
                 learn_region=True, use_gpu=True, **kwargs):
        super(HACNN, self).__init__()
        # Build fresh defaults per instance; mutable default arguments would
        # be shared by every model created with the defaults.
        if loss is None:
            loss = {'xent'}
        if nchannels is None:
            nchannels = [128, 256, 384]
        self.loss = loss
        self.learn_region = learn_region  # whether the hard-attention (local) branch is built and run
        self.use_gpu = use_gpu
        self.conv = ConvBlock(3, 32, 3, s=2, p=1)

        # Construct Inception + HarmAttn blocks
        # ============== Block 1 ==============
        self.inception1 = nn.Sequential(  # (b, 32, h, w) -> (b, nchannels[0], h/2, w/2)
            InceptionA(32, nchannels[0]),
            InceptionB(nchannels[0], nchannels[0]),
        )
        self.ha1 = HarmAttn(nchannels[0])

        # ============== Block 2 ==============
        self.inception2 = nn.Sequential(
            InceptionA(nchannels[0], nchannels[1]),
            InceptionB(nchannels[1], nchannels[1]),
        )
        self.ha2 = HarmAttn(nchannels[1])

        # ============== Block 3 ==============
        self.inception3 = nn.Sequential(
            InceptionA(nchannels[1], nchannels[2]),
            InceptionB(nchannels[2], nchannels[2]),
        )
        self.ha3 = HarmAttn(nchannels[2])

        self.fc_global = nn.Sequential(
            nn.Linear(nchannels[2], feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(),
        )
        self.classifier_global = nn.Linear(feat_dim, num_classes)

        if self.learn_region:
            self.init_scale_factors()
            self.local_conv1 = InceptionB(32, nchannels[0])
            self.local_conv2 = InceptionB(nchannels[0], nchannels[1])
            self.local_conv3 = InceptionB(nchannels[1], nchannels[2])
            self.fc_local = nn.Sequential(
                nn.Linear(nchannels[2] * 4, feat_dim),
                nn.BatchNorm1d(feat_dim),
                nn.ReLU(),
            )
            self.classifier_local = nn.Linear(feat_dim, num_classes)
            self.feat_dim = feat_dim * 2
        else:
            self.feat_dim = feat_dim

    def init_scale_factors(self):
        # Initialize scale factors (s_w, s_h) for the four regions.
        # Each region samples a window spanning the full width and a quarter
        # of the height. These are plain tensors (not buffers), so they do
        # not appear in the state_dict.
        self.scale_factors = [
            torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)
            for _ in range(4)
        ]

    def stn(self, x, theta):
        """Perform a spatial transform of ``x``.

        Args:
            x: (batch, channel, height, width)
            theta: (batch, 2, 3) affine matrices, on the same device as ``x``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # align_corners=False is PyTorch's default since 1.3; passing it
        # explicitly keeps that behaviour and silences the UserWarning.
        grid = F.affine_grid(theta, x.size(), align_corners=False)  # (batch, height, width, 2)
        return F.grid_sample(x, grid, align_corners=False)

    def transform_theta(self, theta_i, region_idx):
        """Combine a region's fixed scale with its predicted translation.

        Args:
            theta_i: (batch, 2) translation (t_x, t_y) predicted for the region.
            region_idx (int): index into ``self.scale_factors``.

        Returns:
            (batch, 2, 3) matrices ``[[s_w, 0, t_x], [0, s_h, t_y]]``, on the
            same device and with the same dtype as ``theta_i``.
        """
        scale = self.scale_factors[region_idx].to(theta_i)  # (2, 2)
        scale = scale.unsqueeze(0).expand(theta_i.size(0), 2, 2)  # (batch, 2, 2)
        return torch.cat([scale, theta_i.unsqueeze(2)], dim=2)

    def _sample_regions(self, feat, theta, size):
        """Crop the four hard-attention regions from ``feat`` and resize them.

        Args:
            feat: (batch, channel, height, width) map to sample from.
            theta: (batch, 4, 2) translations predicted by a ``HardAttn``.
            size: (height, width) every region is bilinearly resized to.

        Returns:
            list of 4 tensors, each (batch, channel, size[0], size[1]).
        """
        regions = []
        for region_idx in range(4):
            theta_i = self.transform_theta(theta[:, region_idx, :], region_idx)
            region = self.stn(feat, theta_i)
            regions.append(F.interpolate(region, size, mode='bilinear', align_corners=True))
        return regions

    def forward(self, x):
        assert x.size(2) == 160 and x.size(3) == 64, \
            "Input size does not match, expected (160, 64) but got ({}, {})".format(x.size(2), x.size(3))
        # Shape comments assume the default nchannels=[128, 256, 384].
        x = self.conv(x)  # (b, 3, 160, 64) -> (b, 32, 80, 32)

        # ============== Block 1 ==============
        # global branch
        x1 = self.inception1(x)  # (b, 128, 40, 16)
        x1_attn, x1_theta = self.ha1(x1)  # attention (b, 128, 40, 16), theta (b, 4, 2)
        x1_out = x1 * x1_attn  # (b, 128, 40, 16)
        # local branch: regions are sampled from this block's input ``x``
        if self.learn_region:
            x1_local_list = [
                self.local_conv1(region)  # (b, 32, 24, 28) -> (b, 128, 12, 14)
                for region in self._sample_regions(x, x1_theta, (24, 28))
            ]

        # ============== Block 2 ==============
        # global branch
        x2 = self.inception2(x1_out)  # (b, 256, 20, 8)
        x2_attn, x2_theta = self.ha2(x2)  # (b, 256, 20, 8), (b, 4, 2)
        x2_out = x2 * x2_attn  # (b, 256, 20, 8)
        # local branch: add the same-index region feature from Block 1
        if self.learn_region:
            x2_local_list = [
                self.local_conv2(region + prev)  # (b, 128, 12, 14) -> (b, 256, 6, 7)
                for region, prev in zip(
                    self._sample_regions(x1_out, x2_theta, (12, 14)), x1_local_list)
            ]

        # ============== Block 3 ==============
        # global branch
        x3 = self.inception3(x2_out)  # (b, 384, 10, 4)
        x3_attn, x3_theta = self.ha3(x3)  # (b, 384, 10, 4), (b, 4, 2)
        x3_out = x3 * x3_attn  # (b, 384, 10, 4)
        # local branch: add the same-index region feature from Block 2
        if self.learn_region:
            x3_local_list = [
                self.local_conv3(region + prev)  # (b, 256, 6, 7) -> (b, 384, 3, 4)
                for region, prev in zip(
                    self._sample_regions(x2_out, x3_theta, (6, 7)), x2_local_list)
            ]

        # ============== Feature generation ==============
        # global branch: global average pooling, then FC
        x_global = F.avg_pool2d(x3_out, x3_out.size()[2:]).view(x3_out.size(0), x3_out.size(1))  # (b, 384)
        x_global = self.fc_global(x_global)  # (b, feat_dim)
        # local branch: pool each region, concatenate the 4 vectors, then FC
        if self.learn_region:
            x_local_list = [
                F.avg_pool2d(feat, feat.size()[2:]).view(feat.size(0), -1)  # (b, 384)
                for feat in x3_local_list
            ]
            x_local = torch.cat(x_local_list, 1)  # (b, 4 * 384)
            x_local = self.fc_local(x_local)  # (b, feat_dim)

        if not self.training:
            # L2 normalization before concatenation; F.normalize clamps the
            # norm with a small eps, so an all-zero (post-ReLU) vector yields
            # zeros instead of NaN.
            if self.learn_region:
                x_global = F.normalize(x_global, p=2, dim=1)
                x_local = F.normalize(x_local, p=2, dim=1)
                return torch.cat([x_global, x_local], 1)
            return x_global

        prelogits_global = self.classifier_global(x_global)
        if self.learn_region:
            prelogits_local = self.classifier_local(x_local)

        if self.loss == {'xent'}:
            if self.learn_region:
                return (prelogits_global, prelogits_local)
            return prelogits_global
        elif self.loss == {'xent', 'htri'}:
            if self.learn_region:
                return (prelogits_global, prelogits_local), (x_global, x_local)
            return prelogits_global, x_global
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))
if __name__ == '__main__':
    # Smoke test: one forward pass of a freshly built model on random input.
    # - Uses CUDA only when it is available.
    # - Uses initialised random data instead of uninitialised memory.
    # - Avoids shadowing the builtin ``input``.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    inputs = torch.randn(32, 3, 160, 64, device=device)  # HACNN expects 160x64 inputs
    cnn = HACNN(751, use_gpu=device.type == 'cuda').to(device)
    # A new module is in training mode; with the default loss {'xent'} and
    # learn_region=True, forward returns (global logits, local logits).
    global_logits, local_logits = cnn(inputs)
    print('global logits:', tuple(global_logits.shape))
    print('local logits:', tuple(local_logits.shape))