基于Network Slimming的YOLOv8结构化剪枝

13 阅读16分钟

8f4bd4d4-77cb-4c45-a7c3-e7f08325e648.png

27fdb66d-fcaf-44f5-9052-96970b17b7aa.png

1. 核心思想

Network Slimming 的核心是利用Batch Normalization(BN)层中的缩放因子γ(gamma参数) 作为通道重要性的度量指标,用L1范数将BN的放缩因子推向0,对不重要的通道channel进行剪枝。

2. 基本原理

2.1 BN层的缩放因子

85dc4858-f8cd-4fdc-9b58-9abc6c93e4a4.png

其中γ是可学习的缩放参数

  • γ的大小与对应通道的重要性相关
  • γ值小的通道对网络输出的贡献小,可以将γ值小的通道剪掉
  • gamma值大 → 通道重要(缩放因子大,对输出贡献大)
  • gamma值接近0 → 通道不重要(对输出贡献小)

公式含义:

2.2 L1稀疏化训练

为了让γ更好地反映通道重要性,在训练时对γ施加L1正则化

71ed0e5f-631f-43fe-a4a1-3cb9c0672b9f.png

其中:

  • L_original:原始损失函数
  • λ稀疏化惩罚系数,平衡任务性能模型稀疏度
  • ∑|γ|:所有BN层γ参数的L1范数
# 关键理解点:
1. λ ≈ 对γ参数的惩罚强度
2. λ越大 → γ越容易被推到0 → 更稀疏但可能欠拟合
3. λ越小 → γ保持原始分布 → 更密集但剪枝困难
4. 最优λ = 刚好让不重要通道γ≈0,重要通道γ保持较大

公式含义:

3. 具体步骤

准备工作

(1)划分训练集和验证集

(2)迁移学习预训练模型权重文件

(3)数据集配置文件:cfg/datasets/VOC.yaml

(4)train.py、稀疏训练文件sparsity_train.py、模型剪枝文件、剪枝后微调文件

【正确流程】
yolov8s.pt (COCO)
    ↓
[正常训练] ← 在你的数据集上训练到收敛(重点!)
    ↓
best.pt (你的任务) ← γ已适应你的数据
    ↓
[稀疏化训练] ← 只做稀疏化,不做领域适应
    ↓
γ分布理想(重要通道大,不重要通道被压制)
    ↓
剪枝效果好
    ↓
精度保持好

3.1 阶段一:正常训练自己的数据集

python train.py 或者

yolo detect train data=D:/ultralytics/ultralytics/cfg/datasets/VOC-ball.yaml  
model=D:/ultralytics/ultralytics/weights/yolov8s.pt epochs=100 imgsz=640 batch=16  
workers=4

得到best.pt

3.2 阶段二:稀疏化训练

(1)稀疏化训练时怎么设置λ值(L1稀疏化正则化惩罚系数)?

需要根据数据集调整,可以通过观察tensorboard的map,缩放因子gamma变化直方图等选择。

训练中需要通过tensorboard监控训练过程,特别是map变化,bn分布变化等,在runs/train/目录下有events.out.tfevents.* 文件。

5f036c83-97f7-457f-b6cd-be334efb7578.png

3.2.1 sr设置的黄金经验法则

1. 根据数据集规模调整

def suggest_sr_by_dataset_size(train_size):
    """根据训练集大小推荐sr"""
    if train_size < 1000:      # 小数据集
        return 0.0005
    elif train_size < 10000:   # 中数据集
        return 0.001
    else:                      # 大数据集
        return 0.0015

2. 根据类别数调整

def adjust_sr_by_classes(base_sr, num_classes):
    """类别越多越保守"""
    if num_classes > 50:
        return base_sr * 0.8   # 多类别,减小sr
    elif num_classes < 10:
        return base_sr * 1.2   # 少类别,可增大sr
    return base_sr

3. 根据目标剪枝率调整

def suggest_sr_by_prune_ratio(prune_ratio):
    """
    根据目标剪枝率推荐sr
    prune_ratio: 期望剪枝的比例 (0.3 = 剪掉30%通道)
    """
    if prune_ratio < 0.3:      # 轻度剪枝
        return 0.0005
    elif prune_ratio < 0.5:    # 中度剪枝
        return 0.001
    elif prune_ratio < 0.7:    # 重度剪枝
        return 0.0015
    else:                      # 极限剪枝
        return 0.002           # 但建议迭代剪枝

3.2.2 基于监控的动态调整策略

关键监控指标

监控指标 = {
    "gamma分布": "观察直方图是否出现双峰",
    "稀疏度": "γ < 1e-3的通道比例",
    "mAP变化": "精度是否明显下降",
    "L1损失占比": "L1损失/总损失的比例"
}

9ca49a16-fd9e-4cb2-99ba-2661380ffa2a.png

开始训练 (sr=0.001)
    ↓
监控训练过程
    ↓
检查gamma直方图
    ├──→ 全是小值(>80%接近0) → sr太大 → 减小到0.0003
    ├──→ 全是大值(<20%接近0) → sr太小 → 增大到0.002
    └──→ 双峰分布(30-70%接近0) → 继续监控
        ↓
检查mAP
    ├──→ mAP下降明显 → sr可能偏大 → 适当减小
    └──→ mAP稳定 → sr合适

(2)为什么不直接使用yolov8s.pt进行稀疏化训练?

不要直接用yolov8s.pt(官方预训练权重),而要用在你自己的数据集上正常训练完成的权重进行稀疏化训练。因为自己训练好的best.pt权重的 ①γ分布已经适应你的特定任务②收敛速度快(已在最优点)、③稀疏效果更好、④精度更高

方案A(从官方权重)的γ分布:
    训练前:     ████████░░░░░░░░  均值0.89,稀疏度0%
    稀疏化后:   ████░░░░░░░░░░░░  均值0.32,稀疏度68%
    问题:许多重要通道也被压制,因为领域不适应

方案B(从自己best)的γ分布:
    训练前:     ██████░░░░░░░░░░  均值0.72,稀疏度2%
    稀疏化后:   ██░░░░░░░░░░░░░░  均值0.18,稀疏度72%
    优势:真正不重要的通道被压制,重要通道保留

3.2.3 训练代码sparsity_train.py

python sparsity_train.py 得到sparsity_best.pt

import argparse
from ultralytics import YOLO

def train_model(opt):
    # Load a model using the path specified in opt.weights
    model = YOLO(opt.weights)

    # Train the model with parameters specified in the opt argument
    results = model.train(data=opt.data, 
                          epochs=opt.epochs, 
                          imgsz=opt.imgsz, 
                          batch=opt.batch, 
                          workers=opt.workers, 
                          device=opt.device, 
                          sr=opt.sr,
                          name=opt.name)
    return results

def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='权重', help='path to model weights')
    parser.add_argument('--data', type=str, default='数据集配置文件', help='dataset.yaml path')
    parser.add_argument('--epochs', type=int, default=100, help='number of epochs')
    parser.add_argument('--imgsz', type=int, default=640, help='input image size')
    parser.add_argument('--batch', type=int, default=16, help='batch size')
    parser.add_argument('--workers', type=int, default=4, help='number of workers for dataloader')
    parser.add_argument('--device', nargs='+', type=int, default=[0], help='device id(s) for training, e.g., 0 or 0 1 2 3')
    parser.add_argument('--sr', type=float, default=0.02, help='L1 regularization penalty coefficient (sparsity regularization)')
    parser.add_argument('--name', type=str, default='train-sparse', help='save to project/name')

    opt = parser.parse_args()
    return opt

if __name__ == '__main__':
    opt = parse_opt()
    results = train_model(opt)

3.3 阶段三:网络模型剪枝

python prune.py

3.3.1 prune.py核心函数详解

整体流程概览

剪枝流程 = [
    "Step 1: 准备数据 - 收集BN层信息",
    "Step 2: 计算阈值 - 基于gamma分布确定剪枝阈值",
    "Step 3: 重写配置 - 创建剪枝后的模型结构",
    "Step 4: 生成掩码 - 确定哪些通道保留",
    "Step 5: 构建模型 - 创建剪枝后的空模型",
    "Step 6: 参数赋值 - 将保留的参数复制到新模型",
    "Step 7: 保存模型 - 保存剪枝后的模型"
]

核心函数详解

prepare_prune_data - 准备剪枝数据

def prepare_prune_data(model):
    """
    功能:收集BN层信息,识别特殊结构
    
    返回值:
    - bn_dict: 所有可剪枝的BN层
    - ignore_bn_list: 不能剪枝的BN层(残差连接)
    - chunk_bn_list: 需要偶数通道的BN层(C2f的chunk操作)
    - sorted_bn: 排序后的gamma值,用于确定阈值
    """

compute_prune_threshold - 计算剪枝阈值

def compute_prune_threshold(bn_dict, sorted_bn, prune_ratio):
    """
    功能:计算安全的剪枝阈值
    
    三个关键值:
    1. highest_thre: 最大安全阈值
       - 取所有BN层gamma最大值的最大值
       - 超过这个值会导致某一层被完全剪掉
    
    2. percent_limit: 最大安全剪枝比例
       - 对应highest_thre的百分位数
       - 剪枝比例不能超过这个值
    
    3. thre: 实际使用的阈值
       - 根据用户指定的prune_ratio计算
       - thre = sorted_bn[int(len(sorted_bn) * prune_ratio)]
    
    示例输出:
    Suggested Gamma threshold should be less than 0.5234, yours is 0.1567
    The corresponding prune ratio should be less than 0.850, yours is 0.300
    """

compute_bn_mask - 计算BN层掩码

def compute_bn_mask(model, ignore_bn_list, chunk_bn_list, thre):
    """
    功能:为每个BN层生成保留通道的掩码
    
    掩码生成逻辑:
    mask = gamma > thre ? 1 : 0
    
    特殊处理:
    1. 忽略剪枝的层:mask全为1
    2. C2f的chunk操作:确保保留通道数为偶数
       如果mask.sum()是奇数,降低阈值直到得到偶数
    
    输出表格示例:
    ================================================================================
    |        layer name          |        origin channels |        remaining channels |
    ================================================================================
    | model.0.conv.bn            |        64               |        42                |
    | model.1.conv.bn            |        128              |        85                |
    | model.2.cv1.bn             |        128              |        86                |  # 偶数!
    | model.2.cv2.bn             |        128              |        86                |
    | ...                        |        ...              |        ...               |
    ================================================================================
    """

assign_pruned_params - 参数赋值(最复杂的部分)

def assign_pruned_params(model, pruned_model, maskbndict):
    """
    功能:将原始模型的参数复制到剪枝后的模型
    
    关键数据结构:current_to_prev
    - 定义在DetectionModelPruned中
    - 记录当前层和前一层的依赖关系
    - 例如:Conv层需要前一层的mask来确定输入通道
    """

参数重映射的4种情况:

# 情况1: 普通ConvConv层需要:
- 输出通道掩码:来自本层的BN层
- 输入通道掩码:来自前一层的BN层

# 情况2: Detect模块的最后一层(无BN)
特殊处理:只有输入通道需要掩码
if pattern_detect.fullmatch(name_org):
    in_channels_mask = maskbndict[prev_bn_layer_name]
    module_pruned.weight = module_org.weight[:, in_channels_mask, :, :]

# 情况3: C2f中Bottleneck的第一个卷积
需要从chunk后的第二部分取掩码
if pattern_c2f.fullmatch(currnet_bn_layer_name):
    in_channels_mask = in_channels_mask.chunk(2, 0)[1]  # 取后半部分

# 情况4: SPPF中的特殊卷积
需要复制4份掩码(因为SPPF4个分支)
if name_org == "model.9.cv2.conv":
    in_channels_mask = torch.cat([in_channels_mask for _ in range(4)], dim=0)

参数复制过程可视化:

原始Conv层权重 [out_channels, in_channels, k, k]
    │
    ├─ 按out_channels_mask选择保留的输出通道
    │   ↓
    [保留的out_channels, in_channels, k, k]
    │
    ├─ 按in_channels_mask选择保留的输入通道
    │   ↓
    [保留的out_channels, 保留的in_channels, k, k]
    │
    └─ 赋值给剪枝后的Conv

3.3.2 特殊结构处理

C2f模块的chunk操作

# C2f forward中的关键代码
def forward(self, x):
    y = list(self.cv1(x).chunk(2, 1))  # 在通道维度chunk成2份
    # 要求:cv1的输出通道数必须是偶数!
    
# 剪枝时的处理
if name in chunk_bn_list and mask.sum() % 2 == 1:
    # 如果剪枝后剩余通道是奇数
    # 调整阈值直到得到偶数
    flattened_sorted_weight = torch.sort(module.weight.data.abs().view(-1))[0]
    idx = torch.min(torch.nonzero(flattened_sorted_weight.gt(thre))).item()
    thre_ = flattened_sorted_weight[idx - 1] - 1e-6  # 略低于阈值
    mask = module.weight.data.abs().gt(thre_).float()

残差连接的约束

# 残差连接要求:两个分支的输出通道数必须相等
if module.add:  # 有残差连接的Bottleneck
    ignore_bn_list.append(f"{name[:-4]}.cv1.bn")  # 第一个分支的BN
    ignore_bn_list.append(f"{name}.cv2.bn")       # 第二个分支的BN
    # 这两个层都不参与剪枝,保持通道数一致

3.3.3 剪枝前后的模型变化

模块替换

# 原始模块 -> 剪枝后模块
C2f          → C2fPruned
SPPFSPPFPruned
DetectDetectPruned
DetectionModelDetectionModelPruned

# 剪枝后模块的主要变化:
1. 添加了通道数记录功能
2. 支持根据mask动态调整通道数
3. 保存了层间依赖关系(current_to_prev)

通道数变化示例

Layer             原始通道 → 剪枝后通道
----------------------------------------
Conv1              6442
Conv2              12885  
C2f.cv1            12886  (调整为偶数)
C2f.cv2            12886
...                512312
Detect             256168

3.3.4 剪枝常见问题

d3fe3770-6c4e-40ce-9138-c4b21124b984.png

3.3.5 BN层分类与剪枝可行性

3.3.5.1 从代码逻辑看BN层分类
# 代码中的三类BN层
bn_dict = {}           # 所有BN层
ignore_bn_list = []    # 不能剪枝的BN层(残差连接)
chunk_bn_list = []     # 需要保持偶数的BN层(可剪枝但有约束)

# 实际可剪枝的BN层 = bn_dict - ignore_bn_list
# 注意:chunk_bn_list中的层仍然可以剪枝,只是需要保证偶数

(1)为什么残差连接的BN层不能剪枝?

残差连接输出的通道数由BN层的通道数决定,剪枝会改变conv的输出通道数导致x和F(x)无法相加。

# 假设我们错误地剪枝了残差连接的BN层
原始情况:
输入x形状: [B, 128, H, W]
cv1输出: [B, 128, H, W]  # 经过cv1(包含BN)后的特征
cv2输出: [B, 128, H, W]  # 经过cv2(包含BN)后的特征
最终输出: x + cv2输出 = [B, 128, H, W]  # 形状相同才能相加

# 错误剪枝后(只剪cv2的BN,不剪cv1的BN):
输入x形状: [B, 128, H, W]  # 不变
cv1输出: [B, 128, H, W]    # cv1的BN没剪,通道数不变
cv2输出: [B, 64, H, W]     # cv2的BN剪掉一半通道
最终输出: x + cv2输出 → ❌ 错误![B,128,H,W] + [B,64,H,W] 形状不匹配!

# 错误剪枝后(两个BN都剪,但剪枝比例不同):
输入x形状: [B, 128, H, W]
cv1输出: [B, 96, H, W]     # cv1的BN剪掉25%
cv2输出: [B, 64, H, W]     # cv2的BN剪掉50%
最终输出: x + cv2输出 → ❌ 错误![B,128,H,W] + [B,64,H,W] 形状不匹配!

# 正确做法:要么都不剪,要么剪相同比例且保证最终形状一致
  • cv1的BN层:决定了主分支第一阶段的通道数
  • cv2的BN层:决定了主分支最终输出的通道数
  • 输入x:来自上一层的输出 这三个通道数必须相等,否则无法相加!

(2)需保持偶数才能剪枝(因为后面有chunk操作)是啥意思?

C2f中的split会把CBS分成两个分支,也就是chunk操作,这两个分支的通道数必须相同。

  • 假设cv1输出通道数为C
  • chunk(2, dim=1)会将C通道分成2份
  • 每份的通道数 = C // 2
# C2f的cv1层剪枝示例

原始cv1输出通道 = 128(偶数)
├── chunk操作 → 2份 × 64通道 ✓

剪枝比例30%后:
目标保留通道 = 128 × 0.7 = 89.6 → 实际保留90通道

# 情况1:保留90通道(偶数)
cv1输出 = 90通道
chunk操作 → 2份 × 45通道 ✓ 没问题!

# 情况2:保留89通道(奇数)
cv1输出 = 89通道
chunk操作 → ❌ 89不能被2整除,程序崩溃!

# 代码中的处理逻辑
if name in chunk_bn_list and mask.sum() % 2 == 1:
    # 如果剪枝后是奇数,调整阈值直到得到偶数
    flattened_sorted_weight = torch.sort(module.weight.data.abs().view(-1))[0]
    idx = torch.min(torch.nonzero(flattened_sorted_weight.gt(thre))).item()
    thre_ = flattened_sorted_weight[idx - 1] - 1e-6  # 略低于阈值
    mask = module.weight.data.abs().gt(thre_).float()
    # 这样会多保留一个通道,使总数变成偶数
3.3.5.2 YOLOv8s各类型BN层详细统计

在 YOLOv8 的官方代码(ultralytics/cfg/models/v8/yolov8.yaml)中,模型的深度和宽度是通过缩放因子控制的:

  • 深度因子 (depth_multiple) :控制网络层数(即 C2f 中 bottleneck 的堆叠数量)。
  • 宽度因子 (width_multiple) :控制通道数。

1a2b5ae3-5c43-4670-9182-e527cc5db1fb.png

6ebb02f5-f112-4473-94e7-0e4638fb8dc1.png

d652e3f0-198d-4f3f-8f61-e4a43eed5c2e.png 6d43dc98-5dae-423f-b958-4b39eceba790.png

BN层数 = 57
不可剪枝BN层 = 8 (残差连接)
可剪枝BN层 = 57 - 8 = 49个

其中:
- 普通可剪枝: 49 - 8 = 41个
- 需保持偶数: 8

YOLOv8s BN层总数 = Backbone(27) + Neck(18) + Head(12) = 57个BN层

可用于Network Slimming剪枝的BN层 = 49个

3.3.6 基于gamma分布确定剪枝阈值

第1步:收集所有可剪枝BN层的gamma值

第2步:根据剪枝比例计算阈值

把所有的gamma值从小到大排序,根据要剪枝的比例进行剪枝

def compute_prune_threshold(bn_dict, sorted_bn, prune_ratio):
    """
    第2步:根据剪枝比例计算阈值
    
    Args:
        sorted_bn: 排序后的gamma值 [γ₁, γ₂, γ₃, ..., γₙ] (γ₁最小, γₙ最大)
        prune_ratio: 要剪枝的比例 (如0.3表示剪掉30%的通道)
    
    计算:
        threshold_idx = int(len(sorted_bn) * prune_ratio)
        threshold = sorted_bn[threshold_idx]
    
    可视化理解:
        
        gamma分布直方图:
        count
         ↑
         |        🟨
         |      🟨🟨🟨
         |    🟨🟨🟨🟨🟨
         |  🟨🟨🟨🟨🟨🟨🟨
         |🟨🟨🟨🟨🟨🟨🟨🟨🟨
         +------------------------→ gamma值
               ↑
             threshold
         (剪掉左边30%)
    
    物理意义:
        - 所有gamma值小于threshold的通道都是"不重要"的
        - 这些通道将被剪掉
        - threshold越大,剪掉的通道越多
    """
    
    # 计算用户指定比例下的剪枝阈值
    thre = sorted_bn[int(len(sorted_bn) * prune_ratio)]
    
    # 安全检查:确保不会剪掉整层
    highest_thre = min([module.weight.data.abs().max() for module in bn_dict.values()])
    percent_limit = (sorted_bn == highest_thre).nonzero()[0, 0].item() / len(sorted_bn)
    
    print(f'阈值: {thre:.4f} (剪掉{prune_ratio:.1%}的通道)')
    print(f'安全警告:阈值不应超过{highest_thre:.4f},否则会剪掉整层')
    
    return thre

确定剪枝比例的主要方法:

  • 基于目标压缩率(最常用)
def determine_prune_ratio_by_target(target_flops_reduction=0.5):
    """
    根据目标FLOPs减少量反推剪枝比例
    
    原理:FLOPs减少 ≈ (剪枝比例)^2
    因为:
    - 卷积FLOPs ≈ 输出通道 × 输入通道 × K² × H × W
    - 剪枝比例p会同时减少输入和输出通道
    - 所以FLOPs减少 ≈ 1 - (1-p)²
    """
    
    # FLOPs减少与剪枝比例的关系
    # 剪枝比例p → FLOPs减少 = 1 - (1-p)²
    # 反推:p = 1 - √(1 - FLOPs减少)
    
    target_flops_reduction = 0.5  # 目标减少50% FLOPs
    p = 1 - math.sqrt(1 - target_flops_reduction)
    print(f"目标FLOPs减少{target_flops_reduction:.0%} → 需要剪枝比例{p:.1%}")
    
    return p

# 经验参考值:
flops_to_prune_ratio = {
    '减少30% FLOPs': 0.16,  # 1 - √(0.7) ≈ 0.16
    '减少40% FLOPs': 0.23,  # 1 - √(0.6) ≈ 0.23
    '减少50% FLOPs': 0.29,  # 1 - √(0.5) ≈ 0.29
    '减少60% FLOPs': 0.37,  # 1 - √(0.4) ≈ 0.37
    '减少70% FLOPs': 0.45,  # 1 - √(0.3) ≈ 0.45
}
  • 基于gamma分布自动确定
def determine_prune_ratio_from_gamma(sorted_gamma):
    """
    根据gamma分布自动确定合适的剪枝比例
    
    原理:寻找gamma分布的"拐点"或"自然分割点"
    """
    
    # 方法1:基于统计分布
    import numpy as np
    
    gamma_array = sorted_gamma.numpy()
    
    # 计算累积分布
    cumulative = np.arange(1, len(gamma_array)+1) / len(gamma_array)
    
    # 方法1.1:找到gamma开始显著增大的点
    # 计算梯度(一阶差分)
    gradients = np.diff(gamma_array)
    
    # 找到梯度突然变大的位置(gamma开始显著增大)
    threshold_idx = np.argmax(gradients > np.percentile(gradients, 90))
    auto_prune_ratio = threshold_idx / len(gamma_array)
    
    # 方法1.2:使用K-means将gamma分成两组
    from sklearn.cluster import KMeans
    
    # 将gamma值reshape成2D数组
    X = gamma_array.reshape(-1, 1)
    
    # 用K-means分成两类(重要和不重要)
    kmeans = KMeans(n_clusters=2, random_state=0).fit(X)
    
    # 找到两个簇的中心
    centers = kmeans.cluster_centers_.flatten()
    
    # 较小的簇中心对应不重要的通道
    small_center = min(centers)
    large_center = max(centers)
    
    # 阈值取两个中心的中间值
    threshold = (small_center + large_center) / 2
    
    # 计算小于阈值的比例
    cluster_prune_ratio = (gamma_array < threshold).mean()
    
    print(f"自动确定的剪枝比例:")
    print(f"  基于梯度: {auto_prune_ratio:.1%}")
    print(f"  基于聚类: {cluster_prune_ratio:.1%}")
    
    return auto_prune_ratio, cluster_prune_ratio

3.3.7 如何确定保留哪些通道

生成通道掩码

def compute_bn_mask(model, ignore_bn_list, chunk_bn_list, thre):
    """
    为每个BN层生成保留通道的掩码
    """
    maskbndict = {}
    
    for name, module in model.model.named_modules():
        if isinstance(module, nn.BatchNorm2d):
            # 获取该层原始通道数
            origin_channels = module.weight.data.size()[0]
            
            if name in ignore_bn_list:
                # 情况1:不能剪枝的层(残差连接)
                mask = torch.ones(origin_channels)  # 全1掩码,全部保留
                
            else:
                # 情况2:可以剪枝的层
                # 核心:gamma > threshold 的通道保留,否则剪掉
                mask = module.weight.data.abs().gt(thre).float()
                # .gt(thre) 返回布尔张量,.float() 转为0/1
                
                # 特殊情况:C2f的cv1层需要偶数通道
                if name in chunk_bn_list and mask.sum() % 2 == 1:
                    # 如果保留通道数是奇数,调整阈值
                    mask = adjust_to_even(module, thre)
            
            # 对gamma和beta应用掩码(乘以0相当于剪掉)
            module.weight.data.mul_(mask)   # 不重要的通道gamma变为0
            module.bias.data.mul_(mask)     # 对应的bias也变为0
            
            maskbndict[name] = mask
            print(f"{name}: {origin_channels}{mask.sum().int()}通道")
    
    return maskbndict

掩码的可视化理解

"""
以model.4.cv1.bn为例(128通道):

原始gamma值:
[0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.50, 0.60, 0.70, 0.80, ...]
  ↓     ↓     ↓     ↓     ↓     ↓     ↓     ↓     ↓     ↓
 0.01 < 阈值? 0.04
  ↓
 
比较结果(布尔值):
[True, True, True, False, False, False, False, False, False, False, ...]
  ↓     ↓     ↓     ↓       ↓       ↓       ↓       ↓       ↓       ↓
 
转成float掩码:
[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
  ↑     ↑     ↑     ↑
保留  保留  保留  剪掉

应用掩码后gamma值:
[0.01, 0.02, 0.03, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]

完整流程示例

"""
假设有一个简单的网络:

Layer1: Conv(3→64) + BN1(64)
Layer2: Conv(64→128) + BN2(128)
Layer3: Conv(128→256) + BN3(256)

剪枝比例30%,阈值=0.15

Step 1: 收集gamma
BN1 gamma: [0.1, 0.2, 0.1, 0.3, 0.05, 0.4, ...] 64个
BN2 gamma: [0.2, 0.02, 0.3, 0.01, 0.4, 0.03, ...] 128个
BN3 gamma: [0.01, 0.5, 0.02, 0.6, 0.03, 0.7, ...] 256个

Step 2: 排序所有gamma,找到阈值
所有gamma排序后,第30%分位数 = 0.15

Step 3: 为每层生成掩码
BN1: gamma > 0.15? → 保留45个,剪掉19个
BN2: gamma > 0.15? → 保留90个,剪掉38个
BN3: gamma > 0.15? → 保留180个,剪掉76个

Step 4: 参数重映射
Layer1 Conv: 
  - 原始: [64, 3, 3, 3]
  - 输出掩码保留45通道 → [45, 3, 3, 3]
  - 输入是RGB 3通道,不需要剪
  
Layer2 Conv:
  - 原始: [128, 64, 3, 3]
  - 输出掩码保留90通道 → [90, 64, 3, 3]
  - 输入掩码来自BN1保留的45通道 → [90, 45, 3, 3]
  
Layer3 Conv:
  - 原始: [256, 128, 3, 3]
  - 输出掩码保留180通道 → [180, 128, 3, 3]
  - 输入掩码来自BN2保留的90通道 → [180, 90, 3, 3]
"""

3.4 阶段四:剪枝后的网络模型微调

3.4.1 剪枝后的网络模型微调finetune.py

python finetune.py