Source URL: github.com/KaiyangZhou…
I read through the OSNet model and added some light annotations.
from __future__ import division, absolute_import
import warnings
import torch
from torch import nn
from torch.nn import functional as F
__all__ = ['osnet_x1_0', 'osnet_x0_75', 'osnet_x0_5', 'osnet_x0_25', 'osnet_ibn_x1_0']
pretrained_urls = {
'osnet_x1_0':
'https://drive.google.com/uc?id=1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY',
'osnet_x0_75':
'https://drive.google.com/uc?id=1uwA9fElHOk3ZogwbeY5GkLI6QPTX70Hq',
'osnet_x0_5':
'https://drive.google.com/uc?id=16DGLbZukvVYgINws8u8deSaOqjybZ83i',
'osnet_x0_25':
'https://drive.google.com/uc?id=1rb8UN5ZzPKRc_xvtHlyDh-cSz88YX9hs',
'osnet_ibn_x1_0':
'https://drive.google.com/uc?id=1sr90V6irlYYDd4_4ISU2iruoRG8J__6l'
}
##########
# Basic layers
##########
class ConvLayer(nn.Module):
"""
Convolution layer (conv + bn + relu).
Args:
in_channels -> Int
out_channels -> Int
kernel_size -> Int
stride -> Int
padding -> Int
groups -> Int
IN -> bool: Use Instance Normalization or not
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, IN=False):
super(ConvLayer, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
bias=False, groups=groups)
if IN:
self.bn = nn.InstanceNorm2d(out_channels, affine=True)
else:
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class Conv1x1(nn.Module):
"""1x1 convolution + bn + relu."""
def __init__(self, in_channels, out_channels, stride=1, groups=1):
super(Conv1x1, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 1, stride=stride, padding=0, bias=False, groups=groups)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class Conv1x1Linear(nn.Module):
"""1x1 convolution + bn (w/o non-linearity)."""
def __init__(self, in_channels, out_channels, stride=1):
super(Conv1x1Linear, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 1, stride=stride, padding=0, bias=False)
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class Conv3x3(nn.Module):
"""3x3 convolution + bn + relu."""
def __init__(self, in_channels, out_channels, stride=1, groups=1):
super(Conv3x3, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False, groups=groups)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class LightConv3x3(nn.Module):
"""Lightweight 3x3 convolution.
1x1 (linear) + dw 3x3 (nonlinear).
"""
def __init__(self, in_channels, out_channels):
super(LightConv3x3, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False)
        # depthwise 3x3 convolution (groups=out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False, groups=out_channels)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.bn(x)
x = self.relu(x)
return x
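# Illustrative parameter count (not in the original source): for 64 -> 64
# channels, the pointwise conv has 64*64 = 4096 weights and the depthwise 3x3
# conv has 3*3*64 = 576, versus 3*3*64*64 = 36864 for the plain Conv3x3 above,
# roughly an 8x reduction (ignoring BatchNorm parameters).
def _demo_lightconv_params():
    light = sum(p.numel() for p in LightConv3x3(64, 64).parameters())
    full = sum(p.numel() for p in Conv3x3(64, 64).parameters())
    print(light, full)  # ~4.8k vs ~37k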
##########
# Building blocks for omni-scale feature learning
##########
class ChannelGate(nn.Module):
"""
A mini-network that generates channel-wise gates conditioned on input tensor.
Args:
in_channels -> Int: input channels
        num_gates -> None/Int: default=None; number of gates, i.e. output channels;
            defaults to in_channels when None
        return_gates -> bool: default=False; if True, return the channel attention map itself
            instead of applying it to the input tensor
        gate_activation -> String: default='sigmoid'; activation applied to the gates
        reduction -> Int: default=16; channel reduction ratio in the bottleneck architecture
        layer_norm -> bool: default=False; whether to apply LayerNorm before the ReLU in the bottleneck
"""
def __init__(self, in_channels, num_gates=None, return_gates=False, gate_activation='sigmoid',
reduction=16, layer_norm=False):
super(ChannelGate, self).__init__()
if num_gates is None:
num_gates = in_channels
self.return_gates = return_gates
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1, bias=True, padding=0)
self.norm1 = None
if layer_norm:
            # nn.LayerNorm normalizes each sample over its (C, H, W) dimensions; mainly known to help RNNs
self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1))
self.relu = nn.ReLU(inplace=True)
self.fc2 = nn.Conv2d(in_channels // reduction, num_gates, kernel_size=1, bias=True, padding=0)
if gate_activation == 'sigmoid':
self.gate_activation = nn.Sigmoid()
elif gate_activation == 'relu':
self.gate_activation = nn.ReLU(inplace=True)
elif gate_activation == 'linear':
self.gate_activation = None
else:
raise RuntimeError("Unknown gate activation: {}".format(gate_activation))
def forward(self, x):
input = x
x = self.global_avgpool(x)
        x = self.fc1(x)  # 1x1 conv: in_channels -> in_channels // reduction
        if self.norm1 is not None:  # apply LayerNorm if configured
            x = self.norm1(x)
        x = self.relu(x)
        x = self.fc2(x)  # 1x1 conv: in_channels // reduction -> num_gates
if self.gate_activation is not None:
x = self.gate_activation(x)
if self.return_gates:
return x
return input * x
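# Illustrative sketch (not in the original source): ChannelGate squeezes the
# input to (N, C, 1, 1) gates via global average pooling and, by default,
# rescales each channel of the input, SE-style.
def _demo_channel_gate():
    gate = ChannelGate(64, return_gates=True)
    x = torch.randn(2, 64, 16, 8)
    g = gate(x)
    print(g.shape)  # torch.Size([2, 64, 1, 1]): one sigmoid gate per channel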
class OSBlock(nn.Module):
"""Omni-scale feature learning block."""
def __init__(self, in_channels, out_channels, IN=False, bottleneck_reduction=4, **kwargs):
super(OSBlock, self).__init__()
mid_channels = out_channels // bottleneck_reduction
self.conv1 = Conv1x1(in_channels, mid_channels) # 1x1Conv + bn + relu
self.conv2a = LightConv3x3(mid_channels, mid_channels)
self.conv2b = nn.Sequential(
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
)
self.conv2c = nn.Sequential(
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
)
self.conv2d = nn.Sequential(
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
)
self.gate = ChannelGate(mid_channels)
self.conv3 = Conv1x1Linear(mid_channels, out_channels) # 1x1Conv + bn
self.downsample = None
        if in_channels != out_channels:  # the shortcut cannot pass the channels through unchanged
            self.downsample = Conv1x1Linear(in_channels, out_channels)  # project the identity with 1x1 conv + bn
self.IN = None
if IN:
            # BatchNorm normalizes over (N, H, W) per channel; it degrades with small batch sizes
            # nn.InstanceNorm2d normalizes over (H, W) per channel and per sample; common in style transfer
self.IN = nn.InstanceNorm2d(out_channels, affine=True)
def forward(self, x):
identity = x
x1 = self.conv1(x)
x2a = self.conv2a(x1)
x2b = self.conv2b(x1)
x2c = self.conv2c(x1)
x2d = self.conv2d(x1)
x2 = self.gate(x2a) + self.gate(x2b) + self.gate(x2c) + self.gate(x2d)
x3 = self.conv3(x2)
if self.downsample is not None:
identity = self.downsample(identity)
out = x3 + identity
if self.IN is not None:
out = self.IN(out)
return F.relu(out)
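# Illustrative sketch (not in the original source): the four streams stack
# 1-4 LightConv3x3 units, giving receptive fields of 3x3, 5x5, 7x7 and 9x9,
# and a single shared ChannelGate (the unified aggregation gate) fuses them
# before the residual addition, so the spatial size is preserved.
def _demo_osblock():
    block = OSBlock(64, 256)
    x = torch.randn(2, 64, 32, 16)
    y = block(x)
    print(y.shape)  # torch.Size([2, 256, 32, 16])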
##########
# Network architecture
##########
class OSNet(nn.Module):
"""Omni-Scale Network.
Reference:
- Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
- Zhou et al. Learning Generalisable Omni-Scale Representations
for Person Re-Identification. arXiv preprint, 2019.
Args:
        num_classes -> Int: number of classes (person identities)
        blocks -> List[OSBlock, ..]: block classes used to build each stage of OSNet
        layers -> List[Int, ..]: number of blocks stacked in each stage
        channels -> List[Int, ..]: channel width of the stem and of each stage
        feature_dim -> Int: if a positive Int, deploy an fc layer whose output size is feature_dim;
            if None or negative, no fc layer is deployed
        loss -> Set{String, ..}: set of loss names that selects what forward() returns during training
IN -> bool: Use Instance Normalization or not
"""
def __init__(self, num_classes, blocks, layers, channels, feature_dim=512, loss={'xent'}, IN=False, **kwargs):
super(OSNet, self).__init__()
num_blocks = len(blocks)
        assert num_blocks == len(layers)  # one stage depth per block type
        assert num_blocks == len(channels) - 1  # channels has one extra entry for the stem
self.loss = loss
# convolutional backbone
        self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN)  # 7x7 conv (stride 2, padding 3) + bn + relu
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)  # 3x3 max pool, stride 2, padding 1
self.conv2 = self._make_layer(blocks[0], layers[0], channels[0], channels[1], reduce_spatial_size=True, IN=IN)
self.conv3 = self._make_layer(blocks[1], layers[1], channels[1], channels[2], reduce_spatial_size=True)
self.conv4 = self._make_layer(blocks[2], layers[2], channels[2], channels[3], reduce_spatial_size=False)
self.conv5 = Conv1x1(channels[3], channels[3]) # 1x1Conv + bn + relu
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
# fully connected layer
self.fc = self._construct_fc_layer(feature_dim, channels[3], dropout_p=None)
# identity classification layer
self.classifier = nn.Linear(self.feature_dim, num_classes)
self._init_params()
def _make_layer(self, block, layer, in_channels, out_channels, reduce_spatial_size, IN=False):
        layers = []  # stack blocks to build one stage
        layers.append(block(in_channels, out_channels, IN=IN))  # the first block changes the channel count
        for i in range(1, layer):  # the remaining blocks keep it fixed
            layers.append(block(out_channels, out_channels, IN=IN))
        if reduce_spatial_size:  # append a transition that halves the spatial size
layers.append(
nn.Sequential(Conv1x1(out_channels, out_channels), nn.AvgPool2d(2, stride=2))
)
return nn.Sequential(*layers)
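    # Illustrative (not in the original): with layers=[2, 2, 2] and
    # channels=[64, 256, 384, 512], conv2 becomes
    # OSBlock(64, 256) -> OSBlock(256, 256) -> Conv1x1(256, 256) + AvgPool2d(2),
    # i.e. two blocks followed by a transition that halves H and W.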
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
if fc_dims is None or fc_dims < 0:
self.feature_dim = input_dim
return None
if isinstance(fc_dims, int):
fc_dims = [fc_dims]
layers = []
for dim in fc_dims:
layers.append(nn.Linear(input_dim, dim))
layers.append(nn.BatchNorm1d(dim))
layers.append(nn.ReLU(inplace=True))
if dropout_p is not None:
layers.append(nn.Dropout(p=dropout_p))
input_dim = dim
self.feature_dim = fc_dims[-1]
return nn.Sequential(*layers)
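    # Example (illustrative, not in the original): _construct_fc_layer(512, 512)
    # builds Linear(512, 512) + BatchNorm1d + ReLU and sets self.feature_dim = 512,
    # while _construct_fc_layer(None, 512) returns None and keeps feature_dim = input_dim.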
def _init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def featuremaps(self, x):
x = self.conv1(x)
x = self.maxpool(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.conv5(x)
return x
def forward(self, x, return_featuremaps=False):
x = self.featuremaps(x)
if return_featuremaps:
return x
v = self.global_avgpool(x)
v = v.view(v.size(0), -1)
if self.fc is not None:
v = self.fc(v)
if not self.training:
return v
y = self.classifier(v)
if self.loss == {'xent'}:
return y
elif self.loss == {'xent', 'htri'}:
return y, v
elif self.loss == {'cent'}:
return y, v
elif self.loss == {'ring'}:
return y, v
else:
raise KeyError("Unsupported loss: {}".format(self.loss))
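# Illustrative sketch (not in the original source): what forward() returns in
# train vs. eval mode with the default loss={'xent'}; num_classes=751 is an
# arbitrary example (e.g. the number of Market-1501 identities).
def _demo_osnet_forward():
    net = OSNet(num_classes=751, blocks=[OSBlock, OSBlock, OSBlock],
                layers=[2, 2, 2], channels=[64, 256, 384, 512])
    x = torch.randn(2, 3, 256, 128)
    net.train()
    y = net(x)  # training mode: class logits, shape (2, 751)
    net.eval()
    with torch.no_grad():
        v = net(x)  # eval mode: feature vectors, shape (2, 512)
    print(y.shape, v.shape)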
def init_pretrained_weights(model, key=''):
"""Initializes model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
import os
import errno
import gdown
from collections import OrderedDict
def _get_torch_home():
ENV_TORCH_HOME = 'TORCH_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'
torch_home = os.path.expanduser(
os.getenv(
ENV_TORCH_HOME,
os.path.join(
os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
)
)
)
return torch_home
torch_home = _get_torch_home()
model_dir = os.path.join(torch_home, 'checkpoints')
try:
os.makedirs(model_dir)
except OSError as e:
if e.errno == errno.EEXIST:
# Directory already exists, ignore.
pass
else:
# Unexpected OSError, re-raise.
raise
filename = key + '_imagenet.pth'
cached_file = os.path.join(model_dir, filename)
if not os.path.exists(cached_file):
gdown.download(pretrained_urls[key], cached_file, quiet=False)
state_dict = torch.load(cached_file)
model_dict = model.state_dict()
new_state_dict = OrderedDict()
matched_layers, discarded_layers = [], []
for k, v in state_dict.items():
if k.startswith('module.'):
k = k[7:] # discard module.
if k in model_dict and model_dict[k].size() == v.size():
new_state_dict[k] = v
matched_layers.append(k)
else:
discarded_layers.append(k)
model_dict.update(new_state_dict)
model.load_state_dict(model_dict)
if len(matched_layers) == 0:
warnings.warn(
'The pretrained weights from "{}" cannot be loaded, '
'please check the key names manually '
'(** ignored and continue **)'.format(cached_file)
)
else:
print('Successfully loaded imagenet pretrained weights from "{}"'.format(cached_file))
if len(discarded_layers) > 0:
print(
'** The following layers are discarded '
'due to unmatched keys or layer size: {}'.
format(discarded_layers)
)
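# Usage sketch (assumes the gdown package is installed and the Google Drive
# link above is reachable; weights are cached under $TORCH_HOME/checkpoints):
#     model = osnet_x1_0(num_classes=751, pretrained=False)
#     init_pretrained_weights(model, key='osnet_x1_0')
# The ImageNet classifier weights (1000 classes) are discarded automatically
# whenever num_classes differs, since their size no longer matches.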
##########
# Instantiation
##########
def osnet_x1_0(num_classes=1000, pretrained=False, loss={'xent'}, **kwargs):
# standard size (width x1.0)
model = OSNet(num_classes, blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
channels=[64, 256, 384, 512], loss=loss, **kwargs)
if pretrained:
init_pretrained_weights(model, key='osnet_x1_0')
return model
def osnet_x0_75(num_classes=1000, pretrained=True, loss={'xent'}, **kwargs):
# medium size (width x0.75)
model = OSNet(num_classes, blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
channels=[48, 192, 288, 384], loss=loss, **kwargs)
if pretrained:
init_pretrained_weights(model, key='osnet_x0_75')
return model
def osnet_x0_5(num_classes=1000, pretrained=True, loss={'xent'}, **kwargs):
# tiny size (width x0.5)
model = OSNet(num_classes, blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
channels=[32, 128, 192, 256], loss=loss, **kwargs)
if pretrained:
init_pretrained_weights(model, key='osnet_x0_5')
return model
def osnet_x0_25(num_classes=1000, pretrained=True, loss={'xent'}, **kwargs):
# very tiny size (width x0.25)
model = OSNet(num_classes, blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
channels=[16, 64, 96, 128], loss=loss, **kwargs)
if pretrained:
init_pretrained_weights(model, key='osnet_x0_25')
return model
def osnet_ibn_x1_0(
num_classes=1000, pretrained=True, loss={'xent'}, **kwargs):
# standard size (width x1.0) + IBN layer
# Ref: Pan et al. Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net. ECCV, 2018.
model = OSNet(num_classes, blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
channels=[64, 256, 384, 512], loss=loss, IN=True, **kwargs)
if pretrained:
init_pretrained_weights(model, key='osnet_ibn_x1_0')
return model
if __name__ == '__main__':
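    # Minimal smoke test (an assumed example; the original left this block empty).
    model = osnet_x1_0(num_classes=751, pretrained=False)
    model.eval()
    x = torch.randn(2, 3, 256, 128)  # a typical ReID input resolution (H=256, W=128)
    with torch.no_grad():
        v = model(x)  # in eval mode, forward() returns the feature vectors
    print(v.shape)  # torch.Size([2, 512])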