DepGraph结构化剪枝原理深度解析
https://github.com/VainF/Torch-Pruning
核心思想:依赖图(Dependency Graph)
一句话概括:DepGraph通过分析神经网络的计算图依赖关系,自动识别并同步剪枝相互依赖的层,确保剪枝后模型结构完整可运行。
一、为什么需要依赖图?
传统剪枝的问题
# 简单网络示例
网络结构: Conv1 → BN1 → ReLU → Conv2 → BN2
# 传统剪枝(可能出错):
剪枝 Conv1.out_channels[0,2,4] # 剪掉第0,2,4个通道
→ BN1.running_mean[0,2,4] 需要同步剪枝 ✅
→ Conv2.in_channels[对应位置] 需要同步剪枝 ❌ 容易遗漏!
→ 模型运行时维度不匹配:RuntimeError
DepGraph的解决方案
DepGraph分析步骤:
1. 构建计算图:识别所有层间连接
2. 标记依赖:Conv1.out→BN1.in, BN1.out→Conv2.in
3. 剪枝传播:剪枝Conv1时,自动标记需要同步剪枝的BN1和Conv2
4. 执行剪枝:一次性剪掉所有相关层的对应通道
二、DepGraph算法详细原理
1. 依赖类型分类
DepGraph识别三种关键依赖:
2. 依赖图构建算法
# 伪代码:构建依赖图
class DependencyGraph:
def build_dependency(self, model, example_inputs):
# 1. 前向传播追踪计算图
self.trace_computational_graph(model, example_inputs)
# 2. 识别层间依赖关系
for layer in model.modules():
# 输入依赖:哪些层的输出是本层的输入
self.find_input_dependencies(layer)
# 输出依赖:本层的输出被哪些层使用
self.find_output_dependencies(layer)
# 3. 建立依赖边
for src_layer, dst_layer in self.dependency_pairs:
edge_type = self.classify_dependency(src_layer, dst_layer)
self.add_edge(src_layer, dst_layer, edge_type)
3. 依赖传播规则
依赖传播示例(YOLOv8 C2f模块):
┌─── CV0 ───┐
输入 ──┤ ├─→ 拼接 → CV2 → 输出
└─── CV1 ───┘
↓
Bottleneck
↓
Shortcut
依赖链:
剪枝CV0.out_channels[i] →
必须同步剪枝:拼接层输入位置[i] + CV2.in_channels对应位置
剪枝CV1.out_channels[j] →
必须同步剪枝:Bottleneck.in_channels[j] +
Bottleneck.shortcut对应通道 +
拼接层另一个输入位置
三、DepGraph在torch_pruning中的实现
1. 核心数据结构
# torch_pruning中的关键类
class DependencyGraph:
def __init__(self):
self.nodes = [] # 图节点(每个层)
self.edges = [] # 依赖边
self.dependency_rules = {} # 依赖规则表
class Node:
def __init__(self, module):
self.module = module # PyTorch模块
self.dependencies = [] # 依赖此节点的节点
self.dependents = [] # 此节点依赖的节点
self.pruning_fn = None # 剪枝函数
self.idxs = None # 要剪枝的索引
2. 依赖规则库
torch_pruning内置了大量依赖规则:
# 部分依赖规则示例
DEPENDENCY_RULES = {
(nn.Conv2d, nn.BatchNorm2d): {
'type': 'channel-wise',
'src_dim': 0, # Conv的out_channels维度
'dst_dim': 0, # BN的num_features维度
'propagate': True # 需要传播
},
(nn.BatchNorm2d, nn.Conv2d): {
'type': 'channel-wise',
'src_dim': 0, # BN的num_features
'dst_dim': 1, # Conv的in_channels
'propagate': True
},
(nn.Conv2d, nn.Conv2d): {
'type': 'residual', # 残差连接
'requires_same_channels': True
}
}
3. 剪枝计划生成
def get_pruning_plan(self, root_module, pruning_fn, idxs):
"""生成安全的剪枝计划"""
plan = []
# 从根节点开始BFS遍历依赖图
visited = set()
queue = [(root_module, pruning_fn, idxs)]
while queue:
module, fn, idxs = queue.pop(0)
if module in visited:
continue
# 添加当前层剪枝操作
plan.append(PruningAction(module, fn, idxs))
visited.add(module)
# 查找依赖层
for dep_module, rule in self.get_dependencies(module):
# 根据规则计算依赖层需要剪枝的索引
dep_idxs = self.propagate_indices(idxs, rule)
dep_fn = self.get_pruning_fn(dep_module, rule)
queue.append((dep_module, dep_fn, dep_idxs))
return PruningPlan(plan) # 包含所有需要同步剪枝的操作
四、以YOLOv8 C2f模块为例
C2f结构分析
# YOLOv8 C2f模块结构
C2f(
cv1: Conv2d(128, 128, kernel_size=1) # 输入通道128,输出128
m: ModuleList(Bottleneck * 3) # 3个Bottleneck
cv2: Conv2d(256, 128, kernel_size=1) # 输入256,输出128
)
# Bottleneck内部:
Bottleneck(
cv1: Conv2d(64, 64, kernel_size=3) # 分组卷积
cv2: Conv2d(64, 64, kernel_size=3)
add: True # 有shortcut连接
)
DepGraph分析过程
步骤1:构建计算图
输入 → cv1 → split → [part1, part2]
part1 → cv2
part2 → bottleneck1 → bottleneck2 → bottleneck3 → 拼接 → cv2 → 输出
步骤2:识别关键依赖
1. cv1.out_channels 依赖:
- split的输入通道
- part1和part2的对应位置
2. bottleneck中的依赖:
- cv1.out 与 cv2.in 必须匹配(分组卷积)
- shortcut要求输入输出通道相等
步骤3:剪枝传播示例
剪枝cv1的第[0,2,4]通道:
→ split需要移除对应输入位置
→ part1需要移除对应位置
→ part2需要移除对应位置
→ bottleneck1.cv1.in_channels对应位置需要移除
→ bottleneck1.cv2.in_channels对应位置需要移除
→ ... 自动传播到所有相关层
五、剪枝执行机制
1. 索引传播算法
def propagate_indices(src_idxs, dependency_rule, src_module, dst_module):
"""计算依赖层需要剪枝的索引"""
if dependency_rule['type'] == 'channel-wise':
src_dim = dependency_rule['src_dim']
dst_dim = dependency_rule['dst_dim']
# 获取维度信息
src_shape = getattr(src_module, 'weight').shape
dst_shape = getattr(dst_module, 'weight').shape
# 计算映射关系
if src_dim == 0 and dst_dim == 1: # Conv_out → Conv_in
# 需要找到哪些输入通道对应被剪枝的输出通道
mapping = get_channel_mapping(src_module, dst_module)
dst_idxs = [mapping[i] for i in src_idxs]
elif src_dim == dst_dim: # Conv_out → BN
# 直接对应
dst_idxs = src_idxs
elif dependency_rule['type'] == 'residual':
# 残差连接:必须剪枝相同索引
dst_idxs = src_idxs
return sorted(list(set(dst_idxs)))
2. 剪枝操作执行
class PruningPlan:
def exec(self):
"""执行剪枝计划"""
# 步骤1:验证计划可行性
self.validate()
# 步骤2:按拓扑排序执行(避免冲突)
sorted_actions = self.topological_sort()
# 步骤3:执行剪枝
for action in sorted_actions:
module = action.module
pruning_fn = action.fn
idxs = action.idxs
# 调用剪枝函数
new_module = pruning_fn(module, idxs)
# 替换原模块
replace_module(module, new_module)
# 更新依赖图
self.update_dependencies(module, new_module)
六、DepGraph vs 传统方法的对比
性能对比
计算复杂度
设网络有L层,平均每层有D个依赖:
传统方法:O(L) - 只考虑单层
DepGraph:O(L × D) - 但实际中D很小(通常2-4)
实际影响:DepGraph增加~10-20%预处理时间,
但避免100%的运行时错误风险。
七、数学形式化描述
八、实际应用启示
为什么你的YOLOv8项目需要DepGraph
# YOLOv8特有的复杂结构
需要DepGraph处理:
1. C2f中的split-concat操作
2. Bottleneck的shortcut连接
3. PAN-FPN的多尺度融合
4. Detect头的多分支输出
# 没有DepGraph的后果
手动处理这些依赖:
- 需要深入理解YOLOv8源码
- 每个版本更新都要重新分析
- 极易出错且难以调试
DepGraph的局限性
- 计算开销:构建依赖图需要一次前向传播
- 内存占用:存储依赖图需要额外内存
- 动态图支持有限:对条件分支处理不够完美
- 自定义层需要注册:新层类型需要手动添加依赖规则
DepGraph算法完整步骤
步骤①:建模网络层之间的相互依赖关系
1.1 依赖类型识别
DepGraph识别三种核心依赖关系:
1.2 依赖图构建算法
class DependencyGraphBuilder:
def build(self, model, example_input):
"""构建依赖图的核心算法"""
# 1. 前向传播追踪计算图
computation_graph = self.trace_forward(model, example_input)
# 2. 识别所有层的输入输出关系
dependency_pairs = []
for layer_i in model.layers:
for layer_j in model.layers:
if self.has_dependency(layer_i, layer_j, computation_graph):
# 3. 分类依赖类型
dep_type = self.classify_dependency(layer_i, layer_j)
dependency_pairs.append((layer_i, layer_j, dep_type))
# 4. 构建依赖图数据结构
self.graph = DependencyGraph(dependency_pairs)
return self.graph
def has_dependency(self, src, dst, comp_graph):
"""判断是否存在依赖关系"""
# 条件1: dst的输入包含src的输出
if src.output in dst.inputs:
return True
# 条件2: 共享参数(如分组卷积)
if self.share_parameters(src, dst):
return True
# 条件3: 结构约束(如残差连接)
if self.structural_constraint(src, dst):
return True
return False
步骤②:对耦合参数进行分组
2.1 分组原理
核心思想:将必须同时剪枝的参数分为一组,确保剪枝后结构完整性。
# 参数分组算法
def group_coupled_parameters(dependency_graph):
"""基于依赖图进行参数分组"""
groups = []
visited = set()
# 使用并查集(Union-Find)算法分组
uf = UnionFind()
# 根据依赖关系合并参数
for (layer_i, layer_j, dep_type) in dependency_graph.edges:
if dep_type == 'channel-wise' or dep_type == 'residual':
# 获取参数的全局索引
params_i = get_parameter_indices(layer_i)
params_j = get_parameter_indices(layer_j)
# 合并到同一组
for p_i in params_i:
for p_j in params_j:
uf.union(p_i, p_j)
# 生成最终分组
for param_idx in range(total_parameters):
root = uf.find(param_idx)
if root not in visited:
group = [p for p in range(total_parameters) if uf.find(p) == root]
groups.append(group)
visited.add(root)
return groups
2.2 分组示例:YOLOv8 C2f模块
C2f模块的依赖分组示例:
Layer1 (cv1.conv): weight[128,128,1,1]
Layer2 (cv1.bn): weight[128], bias[128], running_mean[128]
Layer3 (cv2.conv): weight[128,256,1,1]
依赖关系:
cv1.conv.out_channels[0:127] ↔ cv1.bn.features[0:127] # 一一对应
cv1.conv.out_channels[0:127] ↔ cv2.conv.in_channels[对应位置] # 部分对应
分组结果:
Group1: [cv1.conv.weight[:,0], cv1.bn.weight[0], cv1.bn.bias[0],
cv2.conv.weight[对应输入位置,:]]
Group2: [cv1.conv.weight[:,1], cv1.bn.weight[1], cv1.bn.bias[1],
cv2.conv.weight[对应输入位置,:]]
...
Group128: [对应第128个通道的所有参数]
2.3 分组数学表达
步骤③:学习组内的一致稀疏性
3.1 组级稀疏性学习
关键创新:不是独立剪枝每个参数,而是以组为单位学习稀疏模式。
def learn_group_sparsity(model, groups, sparsity_target=0.5):
"""学习组内一致的稀疏性"""
# 1. 为每个组计算重要性得分
group_scores = []
for group in groups:
# 方法1: L2范数重要性(常用)
score = 0
for param_idx in group:
param = get_parameter_by_index(model, param_idx)
score += torch.norm(param.data, p=2).item()
group_scores.append(score / len(group))
# 2. 根据重要性排序并选择要剪枝的组
sorted_indices = np.argsort(group_scores) # 分数低的组重要性低
# 3. 确定剪枝阈值
num_prune = int(len(groups) * sparsity_target)
prune_groups = sorted_indices[:num_prune]
# 4. 应用组级剪枝
for group_idx in prune_groups:
group = groups[group_idx]
for param_idx in group:
# 将整个组的参数置零(结构化剪枝)
param = get_parameter_by_index(model, param_idx)
param.data.zero_() # 或直接移除通道
return model
3.2 组内一致性的数学保证
3.3 优化目标函数
3.4 实际执行:软剪枝 vs 硬剪枝
# 软剪枝(训练时)
class GroupSparsityRegularizer:
def __init__(self, groups, strength=0.01):
self.groups = groups
self.strength = strength
def __call__(self, model):
loss = 0
for group in self.groups:
# 计算组内参数的L2范数
group_norm = 0
for param_idx in group:
param = get_parameter(model, param_idx)
group_norm += torch.norm(param, p=2)
# 鼓励整个组趋向于零
loss += self.strength * group_norm
return loss
# 硬剪枝(推理时)
def apply_group_pruning(model, groups, prune_ratio=0.3):
"""执行硬剪枝:直接移除通道"""
for group in groups:
# 计算组重要性
importance = compute_group_importance(group)
if importance < threshold:
# 移除整个组对应的通道
remove_channels(model, group)
完整工作流程示例
以卷积→BN→ReLU→卷积链为例
原始网络:Conv1 → BN1 → ReLU → Conv2
步骤1: 识别依赖
- Conv1.out_channels 与 BN1.features 一一对应
- BN1.features 与 Conv2.in_channels 部分对应
- Conv2的权重矩阵特定列依赖于Conv1的特定通道
步骤2: 参数分组
Group_i (对于第i个通道):
- Conv1.weight[:, i, :, :] # 第i个输出通道的所有权重
- BN1.weight[i], BN1.bias[i] # 对应的BN参数
- Conv2.weight[:, 对应位置, :, :] # Conv2的对应输入通道
步骤3: 组级剪枝
- 计算每个Group_i的重要性得分
- 剪枝得分最低的30%的组
- 组内所有参数同步移除
数学验证剪枝后一致性
DepGraph的优势与创新
相比传统方法的优势
- 理论保证:数学上确保剪枝后模型结构完整
- 自动化程度高:自动识别依赖,无需手动指定
- 适用性广:可处理复杂网络结构(DenseNet, ResNet, YOLO等)
- 剪枝质量高:组级剪枝保留重要的功能单元
在你的YOLOv8项目中的体现
# 实际代码中的体现
pruner = tp.pruner.GroupNormPruner(
model.model,
example_inputs,
importance=tp.importance.MagnitudeImportance(p=2), # 步骤③:学习稀疏性
iterative_steps=1,
pruning_ratio=0.1, # 稀疏度目标
ignored_layers=ignored_layers, # 特殊处理
unwrapped_parameters=unwrapped_parameters # 参数分组信息
)
# torch_pruning内部自动执行步骤①和②
DepGraph剪枝类型定位
DepGraph剪枝属于依赖感知的通道剪枝,使用L2幅度作为重要性依据,但它的实现机制很特殊。
GroupNormPruner:负责执行DepGraph的依赖分析和参数分组MagnitudeImportance(p=2):负责在每个组内部,使用传统的L2范数评估重要性
DepGraph的特殊性
# DepGraph实际上做的是:结构化通道剪枝 + 依赖感知
传统通道剪枝: 独立决定每个通道是否剪枝
DepGraph通道剪枝: 考虑依赖关系,成组剪枝通道
# 示例:卷积→BN→ReLU链
剪枝Conv的通道[0,2,4] → 必须同步剪枝BN的[0,2,4] → 这就是DepGraph的核心
剪枝依据(Importance Criterion)
DepGraph本身是剪枝框架,需要配合具体的重要性评估准则。
1. 幅度重要性(Magnitude)
这是最常用的依据
importance=tp.importance.MagnitudeImportance(p=2) # L2范数
# 原理:权重的L2范数大小决定重要性
重要性分数 = ||W[:,c,:,:]||₂ # 第c个通道所有权重的L2范数
分数低的通道 → 重要性低 → 优先剪枝
2. 其他可选依据
torch_pruning支持多种重要性准则:
# 可选的重要性评估方法
重要性准则 = {
"L1幅度": tp.importance.MagnitudeImportance(p=1), # L1范数
"L2幅度": tp.importance.MagnitudeImportance(p=2), # 你的选择
"BN缩放因子": tp.importance.BNScaleImportance(), # BN的gamma参数
"随机重要性": tp.importance.RandomImportance(), # 随机测试用
"海森矩阵": tp.importance.HessianImportance(), # 二阶信息(计算量大)
}
项目中具体体现
# 你的足球检测模型剪枝过程
剪枝决策 = {
"剪枝类型": "通道剪枝(结构化)",
"依据": "L2幅度重要性", # MagnitudeImportance(p=2)
"评估对象": "卷积层的输出通道",
"分组方式": "DepGraph自动分组(依赖感知)",
"剪枝粒度": "以组为单位,不是单个通道",
"数学原理": "||W_c||₂ 小的通道先剪",
"目标": "从1114万参数减少40-50%"
}
选择L2幅度是合理的选择,因为:
- YOLOv8有大量卷积层,且很多层没有BN
- L2计算简单,适合迭代剪枝(16次迭代)
- 与预训练权重兼容:大模型训练时通常用L2正则化,权重已反映重要性