A Deep Dive into the MobileViT v1/v2/v3 Architecture


MobileViT is a lightweight vision Transformer proposed by Apple in 2021. It combines the lightness and efficiency of CNNs with the self-attention and global receptive field of Transformers: by pairing the efficient MV2 block with Transformer layers, MobileViT strikes an excellent accuracy-efficiency balance on mobile devices.

Why not use a pure Transformer architecture? As noted above, it is "heavy", and it has other problems as well: it lacks the spatial inductive biases of CNNs, so it typically needs large-scale pre-training and heavy data augmentation to perform well.
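To make the "heavy" point concrete, here is a rough back-of-the-envelope sketch (my own illustrative figures, counting only the two attention matrix products) of how the attention cost explodes as the token grid gets finer:

```python
# Back-of-the-envelope cost of self-attention: counts only the QK^T and
# attention @ V matrix products, ignoring the Q/K/V and output projections
def attention_matmul_macs(h, w, patch, dim):
    n = (h // patch) * (w // patch)  # number of tokens
    return 2 * n * n * dim           # two n*n*dim-sized matrix products

print(attention_matmul_macs(224, 224, 16, 768))  # ViT-B/16: 196 tokens -> ~5.9e7 MACs
print(attention_matmul_macs(224, 224, 2, 96))    # 2x2 patches: 12544 tokens -> ~3.0e10 MACs
```

Shrinking the patch side from 16 to 2 multiplies the token count by 64 and the attention cost by roughly 4096 (the smaller embedding dimension claws some of that back), which is why full-resolution pure self-attention is prohibitive on mobile hardware.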

1. Vision Transformer Structure

(Figure: the Vision Transformer structure)

(Figure: the unfold and fold operations)

2. Model Structure Explained

1. MV2 Block (Mobile Inverted Bottleneck)

MV2 is the inverted residual block introduced in MobileNetV2 and is the core building block of lightweight CNNs.

MV2 structure diagram

Input (H×W×C)
    ↓
[1x1 Conv] expand channels (×t)
    ↓
[BN + activation]
    ↓
[3x3 Depthwise Conv] spatial feature extraction
    ↓
[BN + activation]
    ↓
[1x1 Conv] compress channels
    ↓
[BN]
    ↓
[Residual connection] (if stride=1 and input/output channels match)
    ↓
Output (H×W×C')

MV2 code implementation

import torch
import torch.nn as nn
import torch.nn.functional as F  # used by MobileViTBlock below

class MV2Block(nn.Module):
    """
    MobileNetV2 Inverted Residual Block
    """
    def __init__(self, in_channels, out_channels, stride=1, expansion_factor=4):
        super().__init__()
        
        hidden_channels = in_channels * expansion_factor
        self.stride = stride
        self.use_residual = stride == 1 and in_channels == out_channels
        
        # Expansion layer (1x1 conv)
        self.expand_conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(hidden_channels),
            nn.SiLU()  # Swish activation
        ) if expansion_factor != 1 else nn.Identity()
        
        # Depthwise convolution
        self.depthwise_conv = nn.Sequential(
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, 
                     stride=stride, padding=1, groups=hidden_channels, bias=False),
            nn.BatchNorm2d(hidden_channels),
            nn.SiLU()
        )
        
        # Projection layer (1x1 conv compresses channels)
        self.project_conv = nn.Sequential(
            nn.Conv2d(hidden_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )
    
    def forward(self, x):
        identity = x
        
        # Expand
        x = self.expand_conv(x)
        
        # Depthwise conv
        x = self.depthwise_conv(x)
        
        # Project
        x = self.project_conv(x)
        
        # Residual connection
        if self.use_residual:
            x = x + identity
        
        return x

# Usage example
mv2 = MV2Block(in_channels=32, out_channels=64, stride=2)
x = torch.randn(1, 32, 56, 56)
out = mv2(x)
print(f"input shape: {x.shape}, output shape: {out.shape}")

Key characteristics of MV2

class MV2Characteristics:
    """Core characteristics of the MV2 block"""
    
    def __init__(self):
        self.features = {
            "Inverted residual": "Expand channels first (x4-6), then compress; the opposite of the classic residual's compress-then-expand",
            "Depthwise separable conv": "Factor a standard conv into depthwise + pointwise, sharply cutting the compute cost",
            "Linear bottleneck": "No activation after the last layer, preserving more information",
            "Residual connection": "Used when stride=1; helps gradients flow"
        }
    
    def compute_complexity(self, in_c, out_c, h, w):
        """Compare multiply-accumulate counts"""
        expansion = 4
        hidden_c = in_c * expansion
        
        # Standard 3x3 convolution
        standard_flops = h * w * in_c * out_c * 3 * 3
        
        # MV2 block
        mv2_flops = ((h * w * in_c * hidden_c * 1 * 1)    # expansion layer
                     + (h * w * hidden_c * 3 * 3)         # depthwise
                     + (h * w * hidden_c * out_c * 1 * 1))  # projection layer
        
        return {
            "standard_flops": standard_flops,
            "mv2_flops": mv2_flops,
            "reduction_ratio": standard_flops / mv2_flops
        }
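As a worked example of the formula above (the layer sizes here are illustrative, not taken from the paper):

```python
# Worked example: a 56x56 feature map, 32 -> 64 channels, expansion factor 4,
# using the same per-layer MAC formula as compute_complexity
h = w = 56
in_c, out_c, hidden_c = 32, 64, 32 * 4

standard = h * w * in_c * out_c * 3 * 3   # one standard 3x3 conv
mv2 = (h * w * in_c * hidden_c            # 1x1 expansion
       + h * w * hidden_c * 3 * 3         # 3x3 depthwise
       + h * w * hidden_c * out_c)        # 1x1 projection

print(f"{standard:,} vs {mv2:,} MACs, ratio {standard / mv2:.2f}x")
# 57,802,752 vs 42,147,840 MACs, ratio 1.37x
```

The saving over a single standard 3x3 conv looks modest here because the expanded width (128) exceeds both the input and output widths; the ratio grows toward 9/4 as the output channel count increases.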

2. MobileViT Block (the core innovation)

1. MobileViT Block structure diagram

Input (H×W×C)
    ↓
[3x3 Conv] local feature extraction
    ↓
[1x1 Conv] project up (C → d)
    ↓
[Unfold into sequences] (H×W → N×d, where N = H×W)
    ↓
[Transformer] × L layers
    ├── Layer Norm
    ├── Multi-Head Self-Attention
    ├── Layer Norm  
    └── MLP
    ↓
[Fold back into a feature map] (N×d → H×W×d)
    ↓
[1x1 Conv] project down (d → C)
    ↓
[Concatenate] with the original input
    ↓
[3x3 Conv] fuse features
    ↓
Output (H×W×C)

2. MobileViT Block code implementation

class MobileViTBlock(nn.Module):
    """
    The core MobileViT module: fuses a CNN with a Transformer
    """
    def __init__(self, in_channels, transformer_dim, ffn_dim, 
                 num_heads=4, num_transformer_blocks=2, patch_size=2):
        super().__init__()
        
        self.patch_size = patch_size
        self.transformer_dim = transformer_dim
        
        # Local representation (CNN part)
        self.local_rep = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, 
                     padding=1, groups=in_channels, bias=False),  # Depthwise
            nn.BatchNorm2d(in_channels),
            nn.Conv2d(in_channels, transformer_dim, kernel_size=1, bias=False),  # Pointwise
            nn.BatchNorm2d(transformer_dim),
            nn.SiLU()
        )
        
        # Transformer part (global modeling)
        self.transformer = nn.ModuleList([
            TransformerBlock(transformer_dim, ffn_dim, num_heads)
            for _ in range(num_transformer_blocks)
        ])
        
        # Fusion layer
        self.fusion = nn.Sequential(
            nn.Conv2d(transformer_dim + in_channels, in_channels, 
                     kernel_size=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, in_channels, kernel_size=3,
                     padding=1, groups=in_channels, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.SiLU()
        )
    
    def unfolding(self, x):
        """
        Split the feature map into a batch of patches
        Input:  [B, C, H, W]
        Output: [B, num_patches, patch_area, C]
        """
        B, C, H, W = x.shape
        p = self.patch_size
        
        # Make sure H and W are divisible by p
        H_new = (H + p - 1) // p * p
        W_new = (W + p - 1) // p * p
        
        if H != H_new or W != W_new:
            x = F.interpolate(x, size=(H_new, W_new), mode='bilinear')
        
        # Reshape into patches
        x = x.reshape(B, C, H_new // p, p, W_new // p, p)
        x = x.permute(0, 2, 4, 3, 5, 1)  # [B, nh, nw, p, p, C]
        x = x.reshape(B, -1, p * p, C)   # [B, num_patches, p*p, C]
        
        return x, (H_new, W_new)
    
    def folding(self, x, original_size):
        """
        Fold the patch sequence back into a feature map
        """
        B, num_patches, patch_area, C = x.shape
        p = self.patch_size
        H, W = original_size
        
        nh = H // p
        nw = W // p
        
        x = x.reshape(B, nh, nw, p, p, C)
        x = x.permute(0, 5, 1, 3, 2, 4)  # [B, C, nh, p, nw, p]
        x = x.reshape(B, C, H, W)
        
        return x
    
    def forward(self, x):
        identity = x
        
        # 1. Local feature extraction (CNN)
        x_local = self.local_rep(x)  # [B, transformer_dim, H, W]
        
        # 2. Convert to sequences and apply the Transformer
        x_patches, size = self.unfolding(x_local)  # [B, num_patches, patch_area, C]
        B, num_patches, patch_area, C = x_patches.shape
        
        # Attend across patches at the same pixel position: fold the patch_area
        # dimension into the batch so each sequence has length num_patches.
        # This cross-patch attention is what gives MobileViT its global receptive field.
        x_patches = x_patches.permute(0, 2, 1, 3).reshape(B * patch_area, num_patches, C)
        
        # Apply the Transformer blocks
        for transformer in self.transformer:
            x_patches = transformer(x_patches)
        
        # Restore the original layout
        x_patches = x_patches.reshape(B, patch_area, num_patches, C).permute(0, 2, 1, 3)
        
        # 3. Fold back into a feature map
        x_global = self.folding(x_patches, size)
        
        # Resize to the original input resolution if needed
        if x_global.shape[-2:] != identity.shape[-2:]:
            x_global = F.interpolate(x_global, size=identity.shape[-2:], 
                                     mode='bilinear')
        
        # 4. Concatenate local and global features
        x_cat = torch.cat([identity, x_global], dim=1)
        
        # 5. Fuse features
        out = self.fusion(x_cat)
        
        return out

class TransformerBlock(nn.Module):
    """A minimal pre-norm Transformer block"""
    def __init__(self, dim, ffn_dim, num_heads, dropout=0.1):
        super().__init__()
        
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout,
                                         batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_dim, dim),
            nn.Dropout(dropout)
        )
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        # Self-attention with residual (normalize once, reuse for Q, K, V)
        x_norm = self.norm1(x)
        x = x + self.dropout(self.attn(x_norm, x_norm, x_norm)[0])
        
        # FFN with residual
        x = x + self.dropout(self.ffn(self.norm2(x)))
        
        return x
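The unfolding/folding pair used by MobileViTBlock is lossless whenever the spatial size divides evenly by the patch size. A standalone round-trip check (the helper names here are my own, mirroring the tensor layout above):

```python
import torch

def unfold_patches(x, p):
    """[B, C, H, W] -> [B, num_patches, p*p, C], same layout as MobileViTBlock.unfolding"""
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // p, p, W // p, p)
    x = x.permute(0, 2, 4, 3, 5, 1)   # [B, nh, nw, p, p, C]
    return x.reshape(B, -1, p * p, C)

def fold_patches(x, p, H, W):
    """Inverse of unfold_patches: [B, num_patches, p*p, C] -> [B, C, H, W]"""
    B, num_patches, area, C = x.shape
    x = x.reshape(B, H // p, W // p, p, p, C)
    x = x.permute(0, 5, 1, 3, 2, 4)   # [B, C, nh, p, nw, p]
    return x.reshape(B, C, H, W)

x = torch.randn(2, 8, 4, 4)
assert torch.equal(fold_patches(unfold_patches(x, 2), 2, 4, 4), x)
print("round trip is exact")
```

Because both directions are pure reshape/permute operations, no information is created or destroyed; interpolation only enters when the input size is not a multiple of the patch size.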

3. Comparison of MV2 and MobileViT Blocks

(Figure: MV2 Block vs. MobileViT Block, side by side)

Complexity comparison

def compare_complexity():
    """
    Compare the computational cost of an MV2 Block and a MobileViT Block
    """
    input_size = (1, 32, 56, 56)  # batch, channels, height, width
    
    # MV2 Block
    mv2 = MV2Block(32, 64)
    
    # MobileViT Block
    mobilevit = MobileViTBlock(
        in_channels=32,
        transformer_dim=96,
        ffn_dim=192,
        num_heads=4,
        num_transformer_blocks=2
    )
    
    # Count FLOPs (simplified)
    def count_flops(module, input_size):
        # In practice, use a profiling library such as thop
        pass
    
    print("MV2 Block: cheap, suited to shallow-layer feature extraction")
    print("MobileViT Block: more expensive, suited to deep semantic modeling")

4. The Overall MobileViT Architecture

class MobileViT(nn.Module):
    """
    A simplified end-to-end MobileViT architecture
    """
    def __init__(self, num_classes=1000):
        super().__init__()
        
        # Stem convolution
        self.stem = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.SiLU()
        )
        
        # Stage 1: MV2 blocks (downsampling)
        self.stage1 = nn.Sequential(
            MV2Block(16, 32, stride=1),   # 112x112
            MV2Block(32, 32, stride=1),
            MV2Block(32, 64, stride=2),   # 56x56
            MV2Block(64, 64, stride=1)
        )
        
        # Stage 2: MV2 + MobileViT blocks
        self.stage2 = nn.Sequential(
            MV2Block(64, 96, stride=2),    # 28x28
            MobileViTBlock(96, 120, 240, num_heads=4, num_transformer_blocks=2),
            MobileViTBlock(96, 120, 240, num_heads=4, num_transformer_blocks=2)
        )
        
        # Stage 3: MV2 + MobileViT blocks
        self.stage3 = nn.Sequential(
            MV2Block(96, 128, stride=2),   # 14x14
            MobileViTBlock(128, 160, 320, num_heads=4, num_transformer_blocks=3),
            MobileViTBlock(128, 160, 320, num_heads=4, num_transformer_blocks=3)
        )
        
        # Classification head
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )
    
    def forward(self, x):
        x = self.stem(x)       # [B, 16, 112, 112]
        x = self.stage1(x)     # [B, 64, 56, 56]
        x = self.stage2(x)     # [B, 96, 28, 28]
        x = self.stage3(x)     # [B, 128, 14, 14]
        
        x = self.pool(x)       # [B, 128, 1, 1]
        x = x.flatten(1)       # [B, 128]
        x = self.classifier(x) # [B, num_classes]
        
        return x

5. Design Philosophy: MV2 vs. MobileViT

class DesignPhilosophy:
    """The design philosophy behind the two modules"""
    
    def mv2_philosophy(self):
        return """
        MV2 design philosophy:
        - Extreme lightness: depthwise separable convs replace standard convs
        - Information preservation: the linear bottleneck avoids information loss
        - Feature reuse: the inverted residual keeps information flowing efficiently
        - Mobile friendly: low compute, fast inference
        """
    
    def mobilevit_philosophy(self):
        return """
        MobileViT design philosophy:
        - CNN + Transformer fusion: combine the CNN's local inductive bias with the Transformer's global modeling
        - Lightweight Transformer: apply attention over small patches to avoid long sequences
        - Feature reuse: concatenation and fusion exploit features at different levels
        - Plug and play: can replace standard conv blocks, improving accuracy without a large compute increase
        """
    
    def when_to_use(self):
        return {
            "Use MV2 when": [
                "In shallow layers (large feature maps)",
                "Compute is extremely constrained",
                "Hard real-time requirements apply",
                "The task only needs local features"
            ],
            "Use MobileViT when": [
                "In deep layers (small feature maps)",
                "The task needs global context",
                "A moderate compute budget is available",
                "High accuracy is required"
            ]
        }

6. Comparison with Pure CNN / Pure Transformer

(Figure: comparison of MobileViT with pure CNN and pure Transformer architectures)

3. Detailed Model Configurations

4. MobileViT v1 vs v2 vs v3: A Full Comparison

MobileViT v1 (2021)          MobileViT v2 (2022)          MobileViT v3 (2022)
     ↓                               ↓                               ↓
[CNN Stem]                    [CNN Stem]                    [Efficient Stem]
     ↓                               ↓                               ↓
[MV2 + MobileViT]            [MV2 + MobileViTv2]          [Efficient Block]
     ↓                               ↓                               ↓
[Global modeling]            [Separable self-attention]   [Fused MBConv + ViT]
     ↓                               ↓                               ↓
[Classifier head]            [Classifier head]            [Classifier head]
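MobileViTv2's headline change is replacing multi-head self-attention with separable self-attention, whose cost is linear rather than quadratic in the token count. A minimal sketch of the idea (layer names and details are my own simplification, not the official implementation):

```python
import torch
import torch.nn as nn

class SeparableSelfAttention(nn.Module):
    """Sketch of MobileViTv2-style separable self-attention, O(N) in sequence length."""
    def __init__(self, dim):
        super().__init__()
        self.to_scores = nn.Linear(dim, 1)    # per-token context score
        self.to_key = nn.Linear(dim, dim)
        self.to_value = nn.Linear(dim, dim)
        self.out = nn.Linear(dim, dim)

    def forward(self, x):                     # x: [B, N, d]
        scores = torch.softmax(self.to_scores(x), dim=1)              # [B, N, 1]
        # Single global context vector: score-weighted sum of keys
        context = (scores * self.to_key(x)).sum(dim=1, keepdim=True)  # [B, 1, d]
        # Broadcast the context back over every token
        return self.out(torch.relu(self.to_value(x)) * context)       # [B, N, d]

attn = SeparableSelfAttention(64)
print(attn(torch.randn(2, 49, 64)).shape)  # torch.Size([2, 49, 64])
```

Instead of an N×N attention matrix, each token gets one scalar score; the keys are pooled into a single context vector that modulates all values, so the cost grows linearly with N.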
