1. 核心思想
Network Slimming 的核心是利用Batch Normalization(BN)层中的缩放因子γ(gamma参数) 作为通道重要性的度量指标,用L1范数将BN的放缩因子推向0,对不重要的通道channel进行剪枝。
2. 基本原理
2.1 BN层的缩放因子
其中γ是可学习的缩放参数:
- γ的大小与对应通道的重要性相关
- γ值小的通道对网络输出的贡献小,可以将γ值小的通道剪掉
- gamma值大 → 通道重要(缩放因子大,对输出贡献大)
- gamma值接近0 → 通道不重要(对输出贡献小)
公式含义:
2.2 L1稀疏化训练
为了让γ更好地反映通道重要性,在训练时对γ施加L1正则化:
其中:
L_original:原始损失函数λ:稀疏化惩罚系数,平衡任务性能和模型稀疏度∑|γ|:所有BN层γ参数的L1范数
# 关键理解点:
1. λ ≈ 对γ参数的惩罚强度
2. λ越大 → γ越容易被推到0 → 更稀疏但可能欠拟合
3. λ越小 → γ保持原始分布 → 更密集但剪枝困难
4. 最优λ = 刚好让不重要通道γ≈0,重要通道γ保持较大
公式含义:
3. 具体步骤
准备工作
(1)划分训练集和验证集
(2)迁移学习预训练模型权重文件
(3)数据集配置文件:cfg/datasets/VOC.yaml
(4)train.py、稀疏训练文件sparsity_train.py、模型剪枝文件、剪枝后微调文件
【正确流程】
yolov8s.pt (COCO)
↓
[正常训练] ← 在你的数据集上训练到收敛(重点!)
↓
best.pt (你的任务) ← γ已适应你的数据
↓
[稀疏化训练] ← 只做稀疏化,不做领域适应
↓
γ分布理想(重要通道大,不重要通道被压制)
↓
剪枝效果好
↓
精度保持好
3.1 阶段一:正常训练自己的数据集
python train.py 或者
yolo detect train data=D:/ultralytics/ultralytics/cfg/datasets/VOC-ball.yaml
model=D:/ultralytics/ultralytics/weights/yolov8s.pt epochs=100 imgsz=640 batch=16
workers=4
得到best.pt
3.2 阶段二:稀疏化训练
(1)稀疏化训练时怎么设置λ值(L1稀疏化正则化惩罚系数)?
需要根据数据集调整,可以通过观察tensorboard的map,缩放因子gamma变化直方图等选择。
训练中需要通过tensorboard监控训练过程,特别是map变化,bn分布变化等,在runs/train/目录下有events.out.tfevents.* 文件。
3.2.1 sr设置的黄金经验法则
1. 根据数据集规模调整
def suggest_sr_by_dataset_size(train_size): """根据训练集大小推荐sr""" if train_size < 1000: # 小数据集 return 0.0005 elif train_size < 10000: # 中数据集 return 0.001 else: # 大数据集 return 0.00152. 根据类别数调整
def adjust_sr_by_classes(base_sr, num_classes): """类别越多越保守""" if num_classes > 50: return base_sr * 0.8 # 多类别,减小sr elif num_classes < 10: return base_sr * 1.2 # 少类别,可增大sr return base_sr3. 根据目标剪枝率调整
def suggest_sr_by_prune_ratio(prune_ratio): """ 根据目标剪枝率推荐sr prune_ratio: 期望剪枝的比例 (0.3 = 剪掉30%通道) """ if prune_ratio < 0.3: # 轻度剪枝 return 0.0005 elif prune_ratio < 0.5: # 中度剪枝 return 0.001 elif prune_ratio < 0.7: # 重度剪枝 return 0.0015 else: # 极限剪枝 return 0.002 # 但建议迭代剪枝
3.2.2 基于监控的动态调整策略
关键监控指标
监控指标 = {
"gamma分布": "观察直方图是否出现双峰",
"稀疏度": "γ < 1e-3的通道比例",
"mAP变化": "精度是否明显下降",
"L1损失占比": "L1损失/总损失的比例"
}
开始训练 (sr=0.001)
↓
监控训练过程
↓
检查gamma直方图
├──→ 全是小值(>80%接近0) → sr太大 → 减小到0.0003
├──→ 全是大值(<20%接近0) → sr太小 → 增大到0.002
└──→ 双峰分布(30-70%接近0) → 继续监控
↓
检查mAP
├──→ mAP下降明显 → sr可能偏大 → 适当减小
└──→ mAP稳定 → sr合适
(2)为什么不直接使用yolov8s.pt进行稀疏化训练?
不要直接用yolov8s.pt(官方预训练权重),而要用在你自己的数据集上正常训练完成的权重进行稀疏化训练。因为自己训练好的best.pt权重的 ①γ分布已经适应你的特定任务、②收敛速度快(已在最优点)、③稀疏效果更好、④精度更高。
方案A(从官方权重)的γ分布:
训练前: ████████░░░░░░░░ 均值0.89,稀疏度0%
稀疏化后: ████░░░░░░░░░░░░ 均值0.32,稀疏度68%
问题:许多重要通道也被压制,因为领域不适应
方案B(从自己best)的γ分布:
训练前: ██████░░░░░░░░░░ 均值0.72,稀疏度2%
稀疏化后: ██░░░░░░░░░░░░░░ 均值0.18,稀疏度72%
优势:真正不重要的通道被压制,重要通道保留
3.2.3 训练代码sparsity_train.py
python sparsity_train.py 得到sparsity_best.pt
import argparse
from ultralytics import YOLO
def train_model(opt):
# Load a model using the path specified in opt.weights
model = YOLO(opt.weights)
# Train the model with parameters specified in the opt argument
results = model.train(data=opt.data,
epochs=opt.epochs,
imgsz=opt.imgsz,
batch=opt.batch,
workers=opt.workers,
device=opt.device,
sr=opt.sr,
name=opt.name)
return results
def parse_opt():
parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='权重', help='path to model weights')
parser.add_argument('--data', type=str, default='数据集配置文件', help='dataset.yaml path')
parser.add_argument('--epochs', type=int, default=100, help='number of epochs')
parser.add_argument('--imgsz', type=int, default=640, help='input image size')
parser.add_argument('--batch', type=int, default=16, help='batch size')
parser.add_argument('--workers', type=int, default=4, help='number of workers for dataloader')
parser.add_argument('--device', nargs='+', type=int, default=[0], help='device id(s) for training, e.g., 0 or 0 1 2 3')
parser.add_argument('--sr', type=float, default=0.02, help='L1 regularization penalty coefficient (sparsity regularization)')
parser.add_argument('--name', type=str, default='train-sparse', help='save to project/name')
opt = parser.parse_args()
return opt
if __name__ == '__main__':
opt = parse_opt()
results = train_model(opt)
3.3 阶段三:网络模型剪枝
python prune.py
3.3.1 prune.py核心函数详解
整体流程概览
剪枝流程 = [
"Step 1: 准备数据 - 收集BN层信息",
"Step 2: 计算阈值 - 基于gamma分布确定剪枝阈值",
"Step 3: 重写配置 - 创建剪枝后的模型结构",
"Step 4: 生成掩码 - 确定哪些通道保留",
"Step 5: 构建模型 - 创建剪枝后的空模型",
"Step 6: 参数赋值 - 将保留的参数复制到新模型",
"Step 7: 保存模型 - 保存剪枝后的模型"
]
核心函数详解
prepare_prune_data - 准备剪枝数据
def prepare_prune_data(model):
"""
功能:收集BN层信息,识别特殊结构
返回值:
- bn_dict: 所有可剪枝的BN层
- ignore_bn_list: 不能剪枝的BN层(残差连接)
- chunk_bn_list: 需要偶数通道的BN层(C2f的chunk操作)
- sorted_bn: 排序后的gamma值,用于确定阈值
"""
compute_prune_threshold - 计算剪枝阈值
def compute_prune_threshold(bn_dict, sorted_bn, prune_ratio):
"""
功能:计算安全的剪枝阈值
三个关键值:
1. highest_thre: 最大安全阈值
- 取所有BN层gamma最大值的最大值
- 超过这个值会导致某一层被完全剪掉
2. percent_limit: 最大安全剪枝比例
- 对应highest_thre的百分位数
- 剪枝比例不能超过这个值
3. thre: 实际使用的阈值
- 根据用户指定的prune_ratio计算
- thre = sorted_bn[int(len(sorted_bn) * prune_ratio)]
示例输出:
Suggested Gamma threshold should be less than 0.5234, yours is 0.1567
The corresponding prune ratio should be less than 0.850, yours is 0.300
"""
compute_bn_mask - 计算BN层掩码
def compute_bn_mask(model, ignore_bn_list, chunk_bn_list, thre):
"""
功能:为每个BN层生成保留通道的掩码
掩码生成逻辑:
mask = gamma > thre ? 1 : 0
特殊处理:
1. 忽略剪枝的层:mask全为1
2. C2f的chunk操作:确保保留通道数为偶数
如果mask.sum()是奇数,降低阈值直到得到偶数
输出表格示例:
================================================================================
| layer name | origin channels | remaining channels |
================================================================================
| model.0.conv.bn | 64 | 42 |
| model.1.conv.bn | 128 | 85 |
| model.2.cv1.bn | 128 | 86 | # 偶数!
| model.2.cv2.bn | 128 | 86 |
| ... | ... | ... |
================================================================================
"""
assign_pruned_params - 参数赋值(最复杂的部分)
def assign_pruned_params(model, pruned_model, maskbndict):
"""
功能:将原始模型的参数复制到剪枝后的模型
关键数据结构:current_to_prev
- 定义在DetectionModelPruned中
- 记录当前层和前一层的依赖关系
- 例如:Conv层需要前一层的mask来确定输入通道
"""
参数重映射的4种情况:
# 情况1: 普通Conv层
Conv层需要:
- 输出通道掩码:来自本层的BN层
- 输入通道掩码:来自前一层的BN层
# 情况2: Detect模块的最后一层(无BN)
特殊处理:只有输入通道需要掩码
if pattern_detect.fullmatch(name_org):
in_channels_mask = maskbndict[prev_bn_layer_name]
module_pruned.weight = module_org.weight[:, in_channels_mask, :, :]
# 情况3: C2f中Bottleneck的第一个卷积
需要从chunk后的第二部分取掩码
if pattern_c2f.fullmatch(currnet_bn_layer_name):
in_channels_mask = in_channels_mask.chunk(2, 0)[1] # 取后半部分
# 情况4: SPPF中的特殊卷积
需要复制4份掩码(因为SPPF有4个分支)
if name_org == "model.9.cv2.conv":
in_channels_mask = torch.cat([in_channels_mask for _ in range(4)], dim=0)
参数复制过程可视化:
原始Conv层权重 [out_channels, in_channels, k, k]
│
├─ 按out_channels_mask选择保留的输出通道
│ ↓
[保留的out_channels, in_channels, k, k]
│
├─ 按in_channels_mask选择保留的输入通道
│ ↓
[保留的out_channels, 保留的in_channels, k, k]
│
└─ 赋值给剪枝后的Conv层
3.3.2 特殊结构处理
C2f模块的chunk操作
# C2f forward中的关键代码
def forward(self, x):
y = list(self.cv1(x).chunk(2, 1)) # 在通道维度chunk成2份
# 要求:cv1的输出通道数必须是偶数!
# 剪枝时的处理
if name in chunk_bn_list and mask.sum() % 2 == 1:
# 如果剪枝后剩余通道是奇数
# 调整阈值直到得到偶数
flattened_sorted_weight = torch.sort(module.weight.data.abs().view(-1))[0]
idx = torch.min(torch.nonzero(flattened_sorted_weight.gt(thre))).item()
thre_ = flattened_sorted_weight[idx - 1] - 1e-6 # 略低于阈值
mask = module.weight.data.abs().gt(thre_).float()
残差连接的约束
# 残差连接要求:两个分支的输出通道数必须相等
if module.add: # 有残差连接的Bottleneck
ignore_bn_list.append(f"{name[:-4]}.cv1.bn") # 第一个分支的BN
ignore_bn_list.append(f"{name}.cv2.bn") # 第二个分支的BN
# 这两个层都不参与剪枝,保持通道数一致
3.3.3 剪枝前后的模型变化
模块替换
# 原始模块 -> 剪枝后模块
C2f → C2fPruned
SPPF → SPPFPruned
Detect → DetectPruned
DetectionModel → DetectionModelPruned
# 剪枝后模块的主要变化:
1. 添加了通道数记录功能
2. 支持根据mask动态调整通道数
3. 保存了层间依赖关系(current_to_prev)
通道数变化示例
Layer 原始通道 → 剪枝后通道
----------------------------------------
Conv1 64 → 42
Conv2 128 → 85
C2f.cv1 128 → 86 (调整为偶数)
C2f.cv2 128 → 86
... 512 → 312
Detect 256 → 168
3.3.4 剪枝常见问题
3.3.5 BN层分类与剪枝可行性
3.3.5.1 从代码逻辑看BN层分类
# 代码中的三类BN层
bn_dict = {} # 所有BN层
ignore_bn_list = [] # 不能剪枝的BN层(残差连接)
chunk_bn_list = [] # 需要保持偶数的BN层(可剪枝但有约束)
# 实际可剪枝的BN层 = bn_dict - ignore_bn_list
# 注意:chunk_bn_list中的层仍然可以剪枝,只是需要保证偶数
(1)为什么残差连接的BN层不能剪枝?
残差连接输出的通道数由BN层的通道数决定,剪枝会改变conv的输出通道数导致x和F(x)无法相加。
# 假设我们错误地剪枝了残差连接的BN层
原始情况:
输入x形状: [B, 128, H, W]
cv1输出: [B, 128, H, W] # 经过cv1(包含BN)后的特征
cv2输出: [B, 128, H, W] # 经过cv2(包含BN)后的特征
最终输出: x + cv2输出 = [B, 128, H, W] # 形状相同才能相加
# 错误剪枝后(只剪cv2的BN,不剪cv1的BN):
输入x形状: [B, 128, H, W] # 不变
cv1输出: [B, 128, H, W] # cv1的BN没剪,通道数不变
cv2输出: [B, 64, H, W] # cv2的BN剪掉一半通道
最终输出: x + cv2输出 → ❌ 错误![B,128,H,W] + [B,64,H,W] 形状不匹配!
# 错误剪枝后(两个BN都剪,但剪枝比例不同):
输入x形状: [B, 128, H, W]
cv1输出: [B, 96, H, W] # cv1的BN剪掉25%
cv2输出: [B, 64, H, W] # cv2的BN剪掉50%
最终输出: x + cv2输出 → ❌ 错误![B,128,H,W] + [B,64,H,W] 形状不匹配!
# 正确做法:要么都不剪,要么剪相同比例且保证最终形状一致
- cv1的BN层:决定了主分支第一阶段的通道数
- cv2的BN层:决定了主分支最终输出的通道数
- 输入x:来自上一层的输出 这三个通道数必须相等,否则无法相加!
(2)需保持偶数才能剪枝(因为后面有chunk操作)是啥意思?
C2f中的split会把CBS分成两个分支,也就是chunk操作,这两个分支的通道数必须相同。
- 假设cv1输出通道数为C
- chunk(2, dim=1)会将C通道分成2份
- 每份的通道数 = C // 2
# C2f的cv1层剪枝示例
原始cv1输出通道 = 128(偶数)
├── chunk操作 → 2份 × 64通道 ✓
剪枝比例30%后:
目标保留通道 = 128 × 0.7 = 89.6 → 实际保留90通道
# 情况1:保留90通道(偶数)
cv1输出 = 90通道
chunk操作 → 2份 × 45通道 ✓ 没问题!
# 情况2:保留89通道(奇数)
cv1输出 = 89通道
chunk操作 → ❌ 89不能被2整除,程序崩溃!
# 代码中的处理逻辑
if name in chunk_bn_list and mask.sum() % 2 == 1:
# 如果剪枝后是奇数,调整阈值直到得到偶数
flattened_sorted_weight = torch.sort(module.weight.data.abs().view(-1))[0]
idx = torch.min(torch.nonzero(flattened_sorted_weight.gt(thre))).item()
thre_ = flattened_sorted_weight[idx - 1] - 1e-6 # 略低于阈值
mask = module.weight.data.abs().gt(thre_).float()
# 这样会多保留一个通道,使总数变成偶数
3.3.5.2 YOLOv8s各类型BN层详细统计
在 YOLOv8 的官方代码(ultralytics/cfg/models/v8/yolov8.yaml)中,模型的深度和宽度是通过缩放因子控制的:
- 深度因子 (depth_multiple) :控制网络层数(即 C2f 中 bottleneck 的堆叠数量)。
- 宽度因子 (width_multiple) :控制通道数。
总BN层数 = 57
不可剪枝BN层 = 8 (残差连接)
可剪枝BN层 = 57 - 8 = 49个
其中:
- 普通可剪枝: 49 - 8 = 41个
- 需保持偶数: 8个
YOLOv8s BN层总数 = Backbone(27) + Neck(18) + Head(12) = 57个BN层
可用于Network Slimming剪枝的BN层 = 49个
3.3.6 基于gamma分布确定剪枝阈值
第1步:收集所有可剪枝BN层的gamma值
第2步:根据剪枝比例计算阈值
把所有的gamma值从小到大排序,根据要剪枝的比例进行剪枝
def compute_prune_threshold(bn_dict, sorted_bn, prune_ratio):
"""
第2步:根据剪枝比例计算阈值
Args:
sorted_bn: 排序后的gamma值 [γ₁, γ₂, γ₃, ..., γₙ] (γ₁最小, γₙ最大)
prune_ratio: 要剪枝的比例 (如0.3表示剪掉30%的通道)
计算:
threshold_idx = int(len(sorted_bn) * prune_ratio)
threshold = sorted_bn[threshold_idx]
可视化理解:
gamma分布直方图:
count
↑
| 🟨
| 🟨🟨🟨
| 🟨🟨🟨🟨🟨
| 🟨🟨🟨🟨🟨🟨🟨
|🟨🟨🟨🟨🟨🟨🟨🟨🟨
+------------------------→ gamma值
↑
threshold
(剪掉左边30%)
物理意义:
- 所有gamma值小于threshold的通道都是"不重要"的
- 这些通道将被剪掉
- threshold越大,剪掉的通道越多
"""
# 计算用户指定比例下的剪枝阈值
thre = sorted_bn[int(len(sorted_bn) * prune_ratio)]
# 安全检查:确保不会剪掉整层
highest_thre = min([module.weight.data.abs().max() for module in bn_dict.values()])
percent_limit = (sorted_bn == highest_thre).nonzero()[0, 0].item() / len(sorted_bn)
print(f'阈值: {thre:.4f} (剪掉{prune_ratio:.1%}的通道)')
print(f'安全警告:阈值不应超过{highest_thre:.4f},否则会剪掉整层')
return thre
确定剪枝比例的主要方法:
- 基于目标压缩率(最常用)
def determine_prune_ratio_by_target(target_flops_reduction=0.5):
"""
根据目标FLOPs减少量反推剪枝比例
原理:FLOPs减少 ≈ (剪枝比例)^2
因为:
- 卷积FLOPs ≈ 输出通道 × 输入通道 × K² × H × W
- 剪枝比例p会同时减少输入和输出通道
- 所以FLOPs减少 ≈ 1 - (1-p)²
"""
# FLOPs减少与剪枝比例的关系
# 剪枝比例p → FLOPs减少 = 1 - (1-p)²
# 反推:p = 1 - √(1 - FLOPs减少)
target_flops_reduction = 0.5 # 目标减少50% FLOPs
p = 1 - math.sqrt(1 - target_flops_reduction)
print(f"目标FLOPs减少{target_flops_reduction:.0%} → 需要剪枝比例{p:.1%}")
return p
# 经验参考值:
flops_to_prune_ratio = {
'减少30% FLOPs': 0.16, # 1 - √(0.7) ≈ 0.16
'减少40% FLOPs': 0.23, # 1 - √(0.6) ≈ 0.23
'减少50% FLOPs': 0.29, # 1 - √(0.5) ≈ 0.29
'减少60% FLOPs': 0.37, # 1 - √(0.4) ≈ 0.37
'减少70% FLOPs': 0.45, # 1 - √(0.3) ≈ 0.45
}
- 基于gamma分布自动确定
def determine_prune_ratio_from_gamma(sorted_gamma):
"""
根据gamma分布自动确定合适的剪枝比例
原理:寻找gamma分布的"拐点"或"自然分割点"
"""
# 方法1:基于统计分布
import numpy as np
gamma_array = sorted_gamma.numpy()
# 计算累积分布
cumulative = np.arange(1, len(gamma_array)+1) / len(gamma_array)
# 方法1.1:找到gamma开始显著增大的点
# 计算梯度(一阶差分)
gradients = np.diff(gamma_array)
# 找到梯度突然变大的位置(gamma开始显著增大)
threshold_idx = np.argmax(gradients > np.percentile(gradients, 90))
auto_prune_ratio = threshold_idx / len(gamma_array)
# 方法1.2:使用K-means将gamma分成两组
from sklearn.cluster import KMeans
# 将gamma值reshape成2D数组
X = gamma_array.reshape(-1, 1)
# 用K-means分成两类(重要和不重要)
kmeans = KMeans(n_clusters=2, random_state=0).fit(X)
# 找到两个簇的中心
centers = kmeans.cluster_centers_.flatten()
# 较小的簇中心对应不重要的通道
small_center = min(centers)
large_center = max(centers)
# 阈值取两个中心的中间值
threshold = (small_center + large_center) / 2
# 计算小于阈值的比例
cluster_prune_ratio = (gamma_array < threshold).mean()
print(f"自动确定的剪枝比例:")
print(f" 基于梯度: {auto_prune_ratio:.1%}")
print(f" 基于聚类: {cluster_prune_ratio:.1%}")
return auto_prune_ratio, cluster_prune_ratio
3.3.7 如何确定保留哪些通道
生成通道掩码
def compute_bn_mask(model, ignore_bn_list, chunk_bn_list, thre):
"""
为每个BN层生成保留通道的掩码
"""
maskbndict = {}
for name, module in model.model.named_modules():
if isinstance(module, nn.BatchNorm2d):
# 获取该层原始通道数
origin_channels = module.weight.data.size()[0]
if name in ignore_bn_list:
# 情况1:不能剪枝的层(残差连接)
mask = torch.ones(origin_channels) # 全1掩码,全部保留
else:
# 情况2:可以剪枝的层
# 核心:gamma > threshold 的通道保留,否则剪掉
mask = module.weight.data.abs().gt(thre).float()
# .gt(thre) 返回布尔张量,.float() 转为0/1
# 特殊情况:C2f的cv1层需要偶数通道
if name in chunk_bn_list and mask.sum() % 2 == 1:
# 如果保留通道数是奇数,调整阈值
mask = adjust_to_even(module, thre)
# 对gamma和beta应用掩码(乘以0相当于剪掉)
module.weight.data.mul_(mask) # 不重要的通道gamma变为0
module.bias.data.mul_(mask) # 对应的bias也变为0
maskbndict[name] = mask
print(f"{name}: {origin_channels} → {mask.sum().int()}通道")
return maskbndict
掩码的可视化理解
"""
以model.4.cv1.bn为例(128通道):
原始gamma值:
[0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.50, 0.60, 0.70, 0.80, ...]
↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓
0.01 < 阈值? 0.04
↓
比较结果(布尔值):
[True, True, True, False, False, False, False, False, False, False, ...]
↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓
转成float掩码:
[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
↑ ↑ ↑ ↑
保留 保留 保留 剪掉
应用掩码后gamma值:
[0.01, 0.02, 0.03, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
完整流程示例
"""
假设有一个简单的网络:
Layer1: Conv(3→64) + BN1(64)
Layer2: Conv(64→128) + BN2(128)
Layer3: Conv(128→256) + BN3(256)
剪枝比例30%,阈值=0.15
Step 1: 收集gamma
BN1 gamma: [0.1, 0.2, 0.1, 0.3, 0.05, 0.4, ...] 64个
BN2 gamma: [0.2, 0.02, 0.3, 0.01, 0.4, 0.03, ...] 128个
BN3 gamma: [0.01, 0.5, 0.02, 0.6, 0.03, 0.7, ...] 256个
Step 2: 排序所有gamma,找到阈值
所有gamma排序后,第30%分位数 = 0.15
Step 3: 为每层生成掩码
BN1: gamma > 0.15? → 保留45个,剪掉19个
BN2: gamma > 0.15? → 保留90个,剪掉38个
BN3: gamma > 0.15? → 保留180个,剪掉76个
Step 4: 参数重映射
Layer1 Conv:
- 原始: [64, 3, 3, 3]
- 输出掩码保留45通道 → [45, 3, 3, 3]
- 输入是RGB 3通道,不需要剪
Layer2 Conv:
- 原始: [128, 64, 3, 3]
- 输出掩码保留90通道 → [90, 64, 3, 3]
- 输入掩码来自BN1保留的45通道 → [90, 45, 3, 3]
Layer3 Conv:
- 原始: [256, 128, 3, 3]
- 输出掩码保留180通道 → [180, 128, 3, 3]
- 输入掩码来自BN2保留的90通道 → [180, 90, 3, 3]
"""
3.4 阶段四:剪枝后的网络模型微调
3.4.1 剪枝后的网络模型微调finetune.py
python finetune.py