MobileViT是Apple在2021年提出的轻量级视觉Transformer,结合了CNN的轻量高效与Transformer的自注意力机制和全局视野。它通过巧妙地融合MV2模块的高效性和Transformer的全局建模能力,在移动端设备上实现了出色的性能-效率平衡。
为什么不用纯Transformer架构,前面提到了它很“重”,除此之外还有一些其他的问题,比如说:
1. Vision Transformer结构
unfold和fold操作
2. 模型结构解释
1、MV2 Block (Mobile Inverted Bottleneck)
MV2是MobileNetV2中提出的倒残差结构,是轻量级CNN的核心模块。
MV2结构图解
输入 (H×W×C)
↓
[1x1 Conv] 扩展通道数 (×t)
↓
[BN +激活]
↓
[3x3 Depthwise Conv] 空间特征提取
↓
[BN +激活]
↓
[1x1 Conv] 压缩通道数
↓
[BN]
↓
[残差连接] (如果stride=1且输入输出通道相同)
↓
输出 (H×W×C')
MV2代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed by MobileViTBlock (F.interpolate)
class MV2Block(nn.Module):
    """MobileNetV2 inverted-residual block.

    Widens channels with a pointwise conv, mixes spatial information with a
    3x3 depthwise conv, then projects back down through a linear (activation-
    free) bottleneck. A skip connection is added only when stride is 1 and
    the channel count is unchanged, so shapes match exactly.
    """

    def __init__(self, in_channels, out_channels, stride=1, expansion_factor=4):
        super().__init__()
        expanded = in_channels * expansion_factor
        self.stride = stride
        # Residual addition is only shape-valid for stride-1, same-width blocks.
        self.use_residual = stride == 1 and in_channels == out_channels

        if expansion_factor == 1:
            # Nothing to expand; skip the widening stage entirely.
            self.expand_conv = nn.Identity()
        else:
            # Expansion stage (1x1 conv).
            self.expand_conv = nn.Sequential(
                nn.Conv2d(in_channels, expanded, kernel_size=1, bias=False),
                nn.BatchNorm2d(expanded),
                nn.SiLU(),  # Swish activation
            )

        # Depthwise 3x3: groups == channels, one filter per channel.
        self.depthwise_conv = nn.Sequential(
            nn.Conv2d(expanded, expanded, kernel_size=3,
                      stride=stride, padding=1, groups=expanded, bias=False),
            nn.BatchNorm2d(expanded),
            nn.SiLU(),
        )

        # Linear bottleneck: no activation after projection, preserving
        # information in the narrow representation.
        self.project_conv = nn.Sequential(
            nn.Conv2d(expanded, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        out = self.project_conv(self.depthwise_conv(self.expand_conv(x)))
        return out + x if self.use_residual else out
# Usage example: a stride-2 MV2 block halves the spatial size (56 -> 28)
# while expanding the channel count (32 -> 64).
mv2 = MV2Block(in_channels=32, out_channels=64, stride=2)
x = torch.randn(1, 32, 56, 56)
out = mv2(x)
print(f"输入形状: {x.shape}, 输出形状: {out.shape}")
MV2的关键特点
class MV2Characteristics:
    """Summary of the MV2 block's design traits plus a rough FLOPs
    comparison against a plain 3x3 convolution.

    Fix: the original `compute_complexity` put `# comments` after
    backslash line continuations (`... \\ # 扩展层`), which is a
    SyntaxError in Python; the sum is now parenthesized instead.
    """

    def __init__(self):
        # Human-readable catalog of the block's core design decisions.
        self.features = {
            "倒残差结构": "先扩展通道(×4-6),再压缩,不同于传统残差的先压缩后扩展",
            "深度可分离卷积": "将标准卷积分解为depthwise+pointwise,大幅减少计算量",
            "线性瓶颈": "最后一层不使用激活函数,保留更多信息",
            "残差连接": "当stride=1时使用,帮助梯度流动"
        }

    def compute_complexity(self, in_c, out_c, h, w):
        """Compare multiply-accumulate counts: standard 3x3 conv vs. MV2.

        Args:
            in_c: input channel count.
            out_c: output channel count.
            h, w: spatial size of the feature map (stride 1 assumed).

        Returns:
            dict with ``standard_flops``, ``mv2_flops`` and their ratio
            ``reduction_ratio`` (> 1 means MV2 is cheaper).
        """
        expansion = 4
        hidden_c = in_c * expansion
        # Standard 3x3 conv: every output pixel mixes all input channels.
        standard_flops = h * w * in_c * out_c * 3 * 3
        # MV2 pipeline: 1x1 expansion + 3x3 depthwise + 1x1 projection.
        mv2_flops = (
            (h * w * in_c * hidden_c * 1 * 1)     # expansion layer
            + (h * w * hidden_c * 3 * 3)          # depthwise layer
            + (h * w * hidden_c * out_c * 1 * 1)  # projection layer
        )
        return {
            "standard_flops": standard_flops,
            "mv2_flops": mv2_flops,
            "reduction_ratio": standard_flops / mv2_flops
        }
2、MobileViT Block (核心创新)
1. MobileViT Block结构图解
输入 (H×W×C)
↓
[3x3 Conv] 局部特征提取
↓
[1x1 Conv] 升维 (C → d)
↓
[展开为序列] (H×W → N×d, 其中N=H×W)
↓
[Transformer] × L层
├── Layer Norm
├── Multi-Head Self-Attention
├── Layer Norm
└── MLP
↓
[重塑为图像] (N×d → H×W×d)
↓
[1x1 Conv] 降维 (d → C)
↓
[拼接] 与原始输入
↓
[3x3 Conv] 融合特征
↓
输出 (H×W×C)
2. MobileViT Block代码实现
class MobileViTBlock(nn.Module):
    """
    Core MobileViT module: fuses CNN (local) and Transformer (global) features.

    Flow: depthwise+pointwise conv lifts channels to ``transformer_dim``;
    the map is split into p x p patches; Transformer layers run over each
    patch; patches are folded back into a map, resized if needed, then
    concatenated with the input and fused by convolutions.

    NOTE(review): attention here runs *within* each p*p patch (sequence
    length = patch_area, see forward). The MobileViT paper instead attends
    *across* patches at each pixel position — confirm which is intended.
    NOTE(review): `F` must be `torch.nn.functional`; it is not imported in
    the import lines visible at the top of this file.
    """
    def __init__(self, in_channels, transformer_dim, ffn_dim,
                 num_heads=4, num_transformer_blocks=2, patch_size=2):
        super().__init__()
        self.patch_size = patch_size
        self.transformer_dim = transformer_dim
        # Local feature extraction (CNN part): depthwise 3x3 followed by a
        # pointwise 1x1 that lifts channels to the transformer width.
        self.local_rep = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3,
                      padding=1, groups=in_channels, bias=False),  # depthwise
            nn.BatchNorm2d(in_channels),
            nn.Conv2d(in_channels, transformer_dim, kernel_size=1, bias=False),  # pointwise
            nn.BatchNorm2d(transformer_dim),
            nn.SiLU()
        )
        # Transformer stack (global modeling).
        self.transformer = nn.ModuleList([
            TransformerBlock(transformer_dim, ffn_dim, num_heads)
            for _ in range(num_transformer_blocks)
        ])
        # Fusion: 1x1 conv shrinks the concatenated channels back to
        # in_channels, then depthwise 3x3 + pointwise 1x1 mix the result.
        self.fusion = nn.Sequential(
            nn.Conv2d(transformer_dim + in_channels, in_channels,
                      kernel_size=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, in_channels, kernel_size=3,
                      padding=1, groups=in_channels, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.SiLU()
        )

    def unfolding(self, x):
        """
        Split a feature map into flattened p x p patches.

        Input:  [B, C, H, W]
        Output: ([B, num_patches, p*p, C], (H_pad, W_pad))
        where H_pad/W_pad are H/W rounded up to multiples of patch_size.
        (The code returns this patch-major layout, not
        [B, C, num_patches, patch_h, patch_w] as previously documented.)
        """
        B, C, H, W = x.shape
        p = self.patch_size
        # Round H and W up to multiples of p so patches tile exactly;
        # resize via bilinear interpolation when padding is needed.
        H_new = (H + p - 1) // p * p
        W_new = (W + p - 1) // p * p
        if H != H_new or W != W_new:
            x = F.interpolate(x, size=(H_new, W_new), mode='bilinear')
        # Rearrange into patches.
        x = x.reshape(B, C, H_new // p, p, W_new // p, p)
        x = x.permute(0, 2, 4, 3, 5, 1)  # [B, nh, nw, p, p, C]
        x = x.reshape(B, -1, p * p, C)   # [B, num_patches, p*p, C]
        return x, (H_new, W_new)

    def folding(self, x, original_size):
        """
        Inverse of ``unfolding``: reassemble patches into a [B, C, H, W]
        map. ``original_size`` is the (padded) (H, W) that ``unfolding``
        returned.
        """
        B, num_patches, patch_area, C = x.shape
        p = self.patch_size
        H, W = original_size
        nh = H // p
        nw = W // p
        x = x.reshape(B, nh, nw, p, p, C)
        x = x.permute(0, 5, 1, 3, 2, 4)  # [B, C, nh, p, nw, p]
        x = x.reshape(B, C, H, W)
        return x

    def forward(self, x):
        """Local conv -> per-patch Transformer -> fold -> concat & fuse."""
        identity = x
        # 1. Local feature extraction (CNN).
        x_local = self.local_rep(x)  # [B, transformer_dim, H, W]
        # 2. Convert to patch sequences and apply the Transformer.
        x_patches, size = self.unfolding(x_local)  # [B, num_patches, patch_area, C]
        B, num_patches, patch_area, C = x_patches.shape
        # Merge the patch dimension into the batch so each patch becomes an
        # independent sequence for the Transformer.
        x_patches = x_patches.reshape(B * num_patches, patch_area, C)
        # Apply Transformer blocks.
        for transformer in self.transformer:
            x_patches = transformer(x_patches)
        # Restore the patch grouping.
        x_patches = x_patches.reshape(B, num_patches, patch_area, C)
        # 3. Fold patches back into a feature map.
        x_global = self.folding(x_patches, size)
        # Resize back if unfolding padded the spatial dims.
        if x_global.shape[-2:] != identity.shape[-2:]:
            x_global = F.interpolate(x_global, size=identity.shape[-2:],
                                     mode='bilinear')
        # 4. Concatenate local (identity) and global branches channel-wise.
        x_cat = torch.cat([identity, x_global], dim=1)
        # 5. Fuse.
        out = self.fusion(x_cat)
        return out
class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder layer: multi-head self-attention and an
    MLP, each wrapped in a residual connection with dropout."""

    def __init__(self, dim, ffn_dim, num_heads, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout,
                                          batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_dim, dim),
            nn.Dropout(dropout),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Attention sublayer: normalize once and feed the same tensor as
        # query, key and value (self-attention); attention weights discarded.
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + self.dropout(attn_out)
        # Feed-forward sublayer with its own pre-norm and residual.
        return x + self.dropout(self.ffn(self.norm2(x)))
3、MV2和MobileViT Block的对比
1. 结构对比表
2. 计算量对比
def compare_complexity():
    """
    Contrast the computational cost of an MV2 block and a MobileViT block.

    Illustrative only: both modules are instantiated, but no real FLOPs
    counting is performed — a qualitative summary is printed instead.
    """
    input_size = (1, 32, 56, 56)  # batch, channels, height, width

    # MV2 Block.
    mv2 = MV2Block(32, 64)

    # MobileViT Block, configured via a keyword dict.
    vit_cfg = dict(
        in_channels=32,
        transformer_dim=96,
        ffn_dim=192,
        num_heads=4,
        num_transformer_blocks=2,
    )
    mobilevit = MobileViTBlock(**vit_cfg)

    def count_flops(module, input_size):
        # Placeholder: a real implementation would use the `thop` library.
        pass

    print("MV2 Block: 计算量小,适合浅层特征提取")
    print("MobileViT Block: 计算量大,适合深层语义建模")
4、MobileViT的整体架构
class MobileViT(nn.Module):
    """
    Full MobileViT classifier: conv stem, an all-MV2 stage, two hybrid
    MV2 + MobileViT stages, then global average pooling and a linear head.

    Spatial sizes in the comments assume a 224x224 input. (The original
    comments mixed two conventions — stem output "112x112" implies a 224
    input, but stage comments "56/28/14/7" implied a 112 input; they are
    made consistent here.)
    """
    def __init__(self, num_classes=1000):
        super().__init__()
        # Stem: 3x3 stride-2 conv, 3 -> 16 channels (224 -> 112).
        self.stem = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.SiLU()
        )
        # Stage 1: pure MV2 blocks, one downsampling step (112 -> 56).
        self.stage1 = nn.Sequential(
            MV2Block(16, 32, stride=1),   # 112x112
            MV2Block(32, 32, stride=1),
            MV2Block(32, 64, stride=2),   # 56x56
            MV2Block(64, 64, stride=1)
        )
        # Stage 2: downsample (56 -> 28), then global modeling via
        # MobileViT blocks.
        self.stage2 = nn.Sequential(
            MV2Block(64, 96, stride=2),   # 28x28
            MobileViTBlock(96, 120, 240, num_heads=4, num_transformer_blocks=2),
            MobileViTBlock(96, 120, 240, num_heads=4, num_transformer_blocks=2)
        )
        # Stage 3: downsample (28 -> 14), deeper global modeling.
        self.stage3 = nn.Sequential(
            MV2Block(96, 128, stride=2),  # 14x14
            MobileViTBlock(128, 160, 320, num_heads=4, num_transformer_blocks=3),
            MobileViTBlock(128, 160, 320, num_heads=4, num_transformer_blocks=3)
        )
        # Classification head: global average pool + dropout + linear.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Run backbone and head; returns class logits of shape [B, num_classes]."""
        x = self.stem(x)       # [B, 16, 112, 112]
        x = self.stage1(x)     # [B, 64, 56, 56]
        x = self.stage2(x)     # [B, 96, 28, 28]
        x = self.stage3(x)     # [B, 128, 14, 14]
        x = self.pool(x)       # [B, 128, 1, 1]
        x = x.flatten(1)       # [B, 128]
        x = self.classifier(x) # [B, num_classes]
        return x
5、MV2 vs MobileViT的设计哲学
class DesignPhilosophy:
    """Design philosophy of the two modules, expressed as returned text.

    NOTE(review): the returned triple-quoted strings are runtime values;
    their exact leading whitespace depends on the original source
    indentation, which is not recoverable here — confirm against the
    authoritative copy before relying on exact string content.
    """

    def mv2_philosophy(self):
        # Returns a multi-line summary of MV2's design goals.
        return """
        MV2的设计哲学:
        - 极致轻量化:用深度可分离卷积替代标准卷积
        - 信息保留:线性瓶颈避免信息损失
        - 特征复用:倒残差结构让信息流动更高效
        - 适合移动端:计算量小,速度快
        """

    def mobilevit_philosophy(self):
        # Returns a multi-line summary of MobileViT's design goals.
        return """
        MobileViT的设计哲学:
        - CNN+Transformer融合:结合CNN的局部归纳偏置和Transformer的全局建模
        - 轻量级Transformer:在小patch上应用Transformer,避免长序列
        - 特征重用:通过拼接和融合充分利用不同层次的特征
        - 即插即用:可以替换标准卷积块,提升性能而不显著增加计算量
        """

    def when_to_use(self):
        # Maps each module to the scenarios where it is the better choice.
        return {
            "使用MV2的场景": [
                "网络浅层(大尺寸特征图)",
                "计算资源极度受限",
                "实时性要求极高",
                "只需要局部特征的任务"
            ],
            "使用MobileViT的场景": [
                "网络深层(小尺寸特征图)",
                "需要全局上下文的任务",
                "中等计算资源",
                "精度要求高的任务"
            ]
        }
6、与纯CNN/Transformer的对比
3. 模型详细配置
4. MobileViT v1 vs v2 vs v3 全面对比
MobileViT v1 (2021) MobileViT v2 (2022) MobileViT v3 (2023)
↓ ↓ ↓
[CNN Stem] [CNN Stem] [Efficient Stem]
↓ ↓ ↓
[MV2 + MobileViT] [MV2 + MobileViTv2] [Efficient Block]
↓ ↓ ↓
[全局建模] [可分离自注意力] [融合MBConv + ViT]
↓ ↓ ↓
[分类头] [分类头] [分类头]