DepGraph结构化剪枝原理与算法步骤

59 阅读11分钟

DepGraph结构化剪枝原理深度解析

aca6d409-449c-49e9-8eb1-583d2b2534d2.png

https://github.com/VainF/Torch-Pruning

核心思想:依赖图(Dependency Graph)

一句话概括:DepGraph通过分析神经网络的计算图依赖关系,自动识别并同步剪枝相互依赖的层,确保剪枝后模型结构完整可运行。

 一、为什么需要依赖图?

传统剪枝的问题

# 简单网络示例
网络结构: Conv1 → BN1 → ReLU → Conv2 → BN2

# 传统剪枝(可能出错):
剪枝 Conv1.out_channels[0,2,4]  # 剪掉第0,2,4个通道
→ BN1.running_mean[0,2,4] 需要同步剪枝 ✅
→ Conv2.in_channels[对应位置] 需要同步剪枝 ❌ 容易遗漏!
→ 模型运行时维度不匹配:RuntimeError

DepGraph的解决方案

DepGraph分析步骤:
1. 构建计算图:识别所有层间连接
2. 标记依赖:Conv1.out→BN1.in, BN1.out→Conv2.in
3. 剪枝传播:剪枝Conv1时,自动标记需要同步剪枝的BN1和Conv2
4. 执行剪枝:一次性剪掉所有相关层的对应通道

 二、DepGraph算法详细原理

1. 依赖类型分类

DepGraph识别三种关键依赖:

c0bcb47f-f558-4c4c-800c-3f041cf0e5b5.png

2. 依赖图构建算法

# 伪代码:构建依赖图
class DependencyGraph:
    def build_dependency(self, model, example_inputs):
        # 1. 前向传播追踪计算图
        self.trace_computational_graph(model, example_inputs)
        
        # 2. 识别层间依赖关系
        for layer in model.modules():
            # 输入依赖:哪些层的输出是本层的输入
            self.find_input_dependencies(layer)
            
            # 输出依赖:本层的输出被哪些层使用  
            self.find_output_dependencies(layer)
            
        # 3. 建立依赖边
        for src_layer, dst_layer in self.dependency_pairs:
            edge_type = self.classify_dependency(src_layer, dst_layer)
            self.add_edge(src_layer, dst_layer, edge_type)

3. 依赖传播规则

依赖传播示例(YOLOv8 C2f模块):
        ┌─── CV0 ───┐
输入 ──┤             ├─→ 拼接 → CV2 → 输出
        └─── CV1 ───┘
              ↓
          Bottleneck
              ↓
           Shortcut

依赖链:
剪枝CV0.out_channels[i] → 
  必须同步剪枝:拼接层输入位置[i] + CV2.in_channels对应位置

剪枝CV1.out_channels[j] → 
  必须同步剪枝:Bottleneck.in_channels[j] + 
              Bottleneck.shortcut对应通道 + 
              拼接层另一个输入位置

三、DepGraph在torch_pruning中的实现

1. 核心数据结构

# torch_pruning中的关键类
class DependencyGraph:
    def __init__(self):
        self.nodes = []      # 图节点(每个层)
        self.edges = []      # 依赖边
        self.dependency_rules = {}  # 依赖规则表
    
    class Node:
        def __init__(self, module):
            self.module = module        # PyTorch模块
            self.dependencies = []      # 依赖此节点的节点
            self.dependents = []        # 此节点依赖的节点
            self.pruning_fn = None      # 剪枝函数
            self.idxs = None            # 要剪枝的索引

2. 依赖规则库

torch_pruning内置了大量依赖规则:

# 部分依赖规则示例
DEPENDENCY_RULES = {
    (nn.Conv2d, nn.BatchNorm2d): {
        'type': 'channel-wise',
        'src_dim': 0,  # Conv的out_channels维度
        'dst_dim': 0,  # BN的num_features维度
        'propagate': True  # 需要传播
    },
    (nn.BatchNorm2d, nn.Conv2d): {
        'type': 'channel-wise', 
        'src_dim': 0,  # BN的num_features
        'dst_dim': 1,  # Conv的in_channels
        'propagate': True
    },
    (nn.Conv2d, nn.Conv2d): {
        'type': 'residual',  # 残差连接
        'requires_same_channels': True
    }
}

3. 剪枝计划生成

def get_pruning_plan(self, root_module, pruning_fn, idxs):
    """生成安全的剪枝计划"""
    plan = []
    
    # 从根节点开始BFS遍历依赖图
    visited = set()
    queue = [(root_module, pruning_fn, idxs)]
    
    while queue:
        module, fn, idxs = queue.pop(0)
        
        if module in visited:
            continue
            
        # 添加当前层剪枝操作
        plan.append(PruningAction(module, fn, idxs))
        visited.add(module)
        
        # 查找依赖层
        for dep_module, rule in self.get_dependencies(module):
            # 根据规则计算依赖层需要剪枝的索引
            dep_idxs = self.propagate_indices(idxs, rule)
            dep_fn = self.get_pruning_fn(dep_module, rule)
            
            queue.append((dep_module, dep_fn, dep_idxs))
    
    return PruningPlan(plan)  # 包含所有需要同步剪枝的操作

四、以YOLOv8 C2f模块为例

C2f结构分析

# YOLOv8 C2f模块结构
C2f(
  cv1: Conv2d(128, 128, kernel_size=1)  # 输入通道128,输出128
  m: ModuleList(Bottleneck * 3)         # 3个Bottleneck
  cv2: Conv2d(256, 128, kernel_size=1)  # 输入256,输出128
)

# Bottleneck内部:
Bottleneck(
  cv1: Conv2d(64, 64, kernel_size=3)    # 分组卷积
  cv2: Conv2d(64, 64, kernel_size=3)
  add: True  # 有shortcut连接
)

DepGraph分析过程

步骤1:构建计算图
输入 → cv1 → split → [part1, part2]
part1 → cv2
part2 → bottleneck1 → bottleneck2 → bottleneck3 → 拼接 → cv2 → 输出

步骤2:识别关键依赖
1. cv1.out_channels 依赖:
   - split的输入通道
   - part1和part2的对应位置

2. bottleneck中的依赖:
   - cv1.out 与 cv2.in 必须匹配(分组卷积)
   - shortcut要求输入输出通道相等

步骤3:剪枝传播示例
剪枝cv1的第[0,2,4]通道:
→ split需要移除对应输入位置
→ part1需要移除对应位置
→ part2需要移除对应位置
→ bottleneck1.cv1.in_channels对应位置需要移除
→ bottleneck1.cv2.in_channels对应位置需要移除
→ ... 自动传播到所有相关层

 五、剪枝执行机制

1. 索引传播算法

def propagate_indices(src_idxs, dependency_rule, src_module, dst_module):
    """计算依赖层需要剪枝的索引"""
    
    if dependency_rule['type'] == 'channel-wise':
        src_dim = dependency_rule['src_dim']
        dst_dim = dependency_rule['dst_dim']
        
        # 获取维度信息
        src_shape = getattr(src_module, 'weight').shape
        dst_shape = getattr(dst_module, 'weight').shape
        
        # 计算映射关系
        if src_dim == 0 and dst_dim == 1:  # Conv_out → Conv_in
            # 需要找到哪些输入通道对应被剪枝的输出通道
            mapping = get_channel_mapping(src_module, dst_module)
            dst_idxs = [mapping[i] for i in src_idxs]
            
        elif src_dim == dst_dim:  # Conv_out → BN
            # 直接对应
            dst_idxs = src_idxs
            
    elif dependency_rule['type'] == 'residual':
        # 残差连接:必须剪枝相同索引
        dst_idxs = src_idxs
        
    return sorted(list(set(dst_idxs)))

2. 剪枝操作执行

class PruningPlan:
    def exec(self):
        """执行剪枝计划"""
        # 步骤1:验证计划可行性
        self.validate()
        
        # 步骤2:按拓扑排序执行(避免冲突)
        sorted_actions = self.topological_sort()
        
        # 步骤3:执行剪枝
        for action in sorted_actions:
            module = action.module
            pruning_fn = action.fn
            idxs = action.idxs
            
            # 调用剪枝函数
            new_module = pruning_fn(module, idxs)
            
            # 替换原模块
            replace_module(module, new_module)
            
            # 更新依赖图
            self.update_dependencies(module, new_module)

六、DepGraph vs 传统方法的对比

性能对比

12b0d66c-6ffd-445e-8652-041306c8b86c.png

计算复杂度

设网络有L层,平均每层有D个依赖:

传统方法:O(L) - 只考虑单层
DepGraph:O(L × D) - 但实际中D很小(通常2-4)

实际影响:DepGraph增加~10-20%预处理时间,
但避免100%的运行时错误风险。

七、数学形式化描述

8cab09f2-8b7d-4953-9a24-dfcb888acc39.png

八、实际应用启示

为什么你的YOLOv8项目需要DepGraph

# YOLOv8特有的复杂结构
需要DepGraph处理:
1. C2f中的split-concat操作
2. Bottleneck的shortcut连接
3. PAN-FPN的多尺度融合
4. Detect头的多分支输出

# 没有DepGraph的后果
手动处理这些依赖:
- 需要深入理解YOLOv8源码
- 每个版本更新都要重新分析
- 极易出错且难以调试

DepGraph的局限性

  1. 计算开销:构建依赖图需要一次前向传播
  2. 内存占用:存储依赖图需要额外内存
  3. 动态图支持有限:对条件分支处理不够完美
  4. 自定义层需要注册:新层类型需要手动添加依赖规则

DepGraph算法完整步骤

deepseek_mermaid_20260128_772399.png

步骤①:建模网络层之间的相互依赖关系

1.1 依赖类型识别

DepGraph识别三种核心依赖关系:

d4f395e8-124d-4901-8eb1-86b7b5b302ff.png

1.2 依赖图构建算法

class DependencyGraphBuilder:
    def build(self, model, example_input):
        """构建依赖图的核心算法"""
        
        # 1. 前向传播追踪计算图
        computation_graph = self.trace_forward(model, example_input)
        
        # 2. 识别所有层的输入输出关系
        dependency_pairs = []
        for layer_i in model.layers:
            for layer_j in model.layers:
                if self.has_dependency(layer_i, layer_j, computation_graph):
                    # 3. 分类依赖类型
                    dep_type = self.classify_dependency(layer_i, layer_j)
                    dependency_pairs.append((layer_i, layer_j, dep_type))
        
        # 4. 构建依赖图数据结构
        self.graph = DependencyGraph(dependency_pairs)
        return self.graph
    
    def has_dependency(self, src, dst, comp_graph):
        """判断是否存在依赖关系"""
        # 条件1: dst的输入包含src的输出
        if src.output in dst.inputs:
            return True
        
        # 条件2: 共享参数(如分组卷积)
        if self.share_parameters(src, dst):
            return True
            
        # 条件3: 结构约束(如残差连接)
        if self.structural_constraint(src, dst):
            return True
            
        return False

步骤②:对耦合参数进行分组

2.1 分组原理

核心思想:将必须同时剪枝的参数分为一组,确保剪枝后结构完整性。

# 参数分组算法
def group_coupled_parameters(dependency_graph):
    """基于依赖图进行参数分组"""
    
    groups = []
    visited = set()
    
    # 使用并查集(Union-Find)算法分组
    uf = UnionFind()
    
    # 根据依赖关系合并参数
    for (layer_i, layer_j, dep_type) in dependency_graph.edges:
        if dep_type == 'channel-wise' or dep_type == 'residual':
            # 获取参数的全局索引
            params_i = get_parameter_indices(layer_i)
            params_j = get_parameter_indices(layer_j)
            
            # 合并到同一组
            for p_i in params_i:
                for p_j in params_j:
                    uf.union(p_i, p_j)
    
    # 生成最终分组
    for param_idx in range(total_parameters):
        root = uf.find(param_idx)
        if root not in visited:
            group = [p for p in range(total_parameters) if uf.find(p) == root]
            groups.append(group)
            visited.add(root)
    
    return groups

2.2 分组示例:YOLOv8 C2f模块

C2f模块的依赖分组示例:

Layer1 (cv1.conv): weight[128,128,1,1]
Layer2 (cv1.bn):   weight[128], bias[128], running_mean[128]
Layer3 (cv2.conv): weight[128,256,1,1]

依赖关系:
cv1.conv.out_channels[0:127] ↔ cv1.bn.features[0:127]  # 一一对应
cv1.conv.out_channels[0:127] ↔ cv2.conv.in_channels[对应位置]  # 部分对应

分组结果:
Group1: [cv1.conv.weight[:,0], cv1.bn.weight[0], cv1.bn.bias[0], 
         cv2.conv.weight[对应输入位置,:]]
Group2: [cv1.conv.weight[:,1], cv1.bn.weight[1], cv1.bn.bias[1],
         cv2.conv.weight[对应输入位置,:]]
...
Group128: [对应第128个通道的所有参数]

2.3 分组数学表达

58d472db-e05a-4af2-b2aa-8b8e882e6655.png

步骤③:学习组内的一致稀疏性

3.1 组级稀疏性学习

关键创新:不是独立剪枝每个参数,而是以组为单位学习稀疏模式

def learn_group_sparsity(model, groups, sparsity_target=0.5):
    """学习组内一致的稀疏性"""
    
    # 1. 为每个组计算重要性得分
    group_scores = []
    for group in groups:
        # 方法1: L2范数重要性(常用)
        score = 0
        for param_idx in group:
            param = get_parameter_by_index(model, param_idx)
            score += torch.norm(param.data, p=2).item()
        group_scores.append(score / len(group))
    
    # 2. 根据重要性排序并选择要剪枝的组
    sorted_indices = np.argsort(group_scores)  # 分数低的组重要性低
    
    # 3. 确定剪枝阈值
    num_prune = int(len(groups) * sparsity_target)
    prune_groups = sorted_indices[:num_prune]
    
    # 4. 应用组级剪枝
    for group_idx in prune_groups:
        group = groups[group_idx]
        for param_idx in group:
            # 将整个组的参数置零(结构化剪枝)
            param = get_parameter_by_index(model, param_idx)
            param.data.zero_()  # 或直接移除通道
    
    return model

3.2 组内一致性的数学保证

f44af9bc-4afe-461e-bc6f-bb1830c29205.png

3.3 优化目标函数

63e95742-58e0-4360-a536-4bf51d747068.png

3.4 实际执行:软剪枝 vs 硬剪枝

# 软剪枝(训练时)
class GroupSparsityRegularizer:
    def __init__(self, groups, strength=0.01):
        self.groups = groups
        self.strength = strength
    
    def __call__(self, model):
        loss = 0
        for group in self.groups:
            # 计算组内参数的L2范数
            group_norm = 0
            for param_idx in group:
                param = get_parameter(model, param_idx)
                group_norm += torch.norm(param, p=2)
            
            # 鼓励整个组趋向于零
            loss += self.strength * group_norm
        return loss

# 硬剪枝(推理时)
def apply_group_pruning(model, groups, prune_ratio=0.3):
    """执行硬剪枝:直接移除通道"""
    for group in groups:
        # 计算组重要性
        importance = compute_group_importance(group)
        
        if importance < threshold:
            # 移除整个组对应的通道
            remove_channels(model, group)

完整工作流程示例

以卷积→BN→ReLU→卷积链为例

原始网络:Conv1 → BN1 → ReLU → Conv2

步骤1: 识别依赖
- Conv1.out_channels 与 BN1.features 一一对应
- BN1.features 与 Conv2.in_channels 部分对应
- Conv2的权重矩阵特定列依赖于Conv1的特定通道

步骤2: 参数分组
Group_i (对于第i个通道):
- Conv1.weight[:, i, :, :]   # 第i个输出通道的所有权重
- BN1.weight[i], BN1.bias[i]  # 对应的BN参数
- Conv2.weight[:, 对应位置, :, :]  # Conv2的对应输入通道

步骤3: 组级剪枝
- 计算每个Group_i的重要性得分
- 剪枝得分最低的30%的组
- 组内所有参数同步移除

数学验证剪枝后一致性

8301b49c-408b-4661-8ed4-4b09c60ebcad.png

DepGraph的优势与创新

相比传统方法的优势

  1. 理论保证:数学上确保剪枝后模型结构完整
  2. 自动化程度高:自动识别依赖,无需手动指定
  3. 适用性广:可处理复杂网络结构(DenseNet, ResNet, YOLO等)
  4. 剪枝质量高:组级剪枝保留重要的功能单元

在你的YOLOv8项目中的体现

# 实际代码中的体现
pruner = tp.pruner.GroupNormPruner(
    model.model,
    example_inputs,
    importance=tp.importance.MagnitudeImportance(p=2),  # 步骤③:学习稀疏性
    iterative_steps=1,
    pruning_ratio=0.1,      # 稀疏度目标
    ignored_layers=ignored_layers,      # 特殊处理
    unwrapped_parameters=unwrapped_parameters  # 参数分组信息
)
# torch_pruning内部自动执行步骤①和②

DepGraph剪枝类型定位

DepGraph剪枝属于依赖感知的通道剪枝,使用L2幅度作为重要性依据,但它的实现机制很特殊。

97dd2fe8-df0f-44b0-b296-e5700ec81fc2.png

  1. GroupNormPruner:负责执行DepGraph的依赖分析参数分组
  2. MagnitudeImportance(p=2) :负责在每个组内部,使用传统的L2范数评估重要性

a0c066f2-3415-4206-a9d0-d1d345c249bd.png

DepGraph的特殊性

# DepGraph实际上做的是:结构化通道剪枝 + 依赖感知
传统通道剪枝: 独立决定每个通道是否剪枝
DepGraph通道剪枝: 考虑依赖关系,成组剪枝通道

# 示例:卷积→BN→ReLU链
剪枝Conv的通道[0,2,4] → 必须同步剪枝BN的[0,2,4] → 这就是DepGraph的核心

剪枝依据(Importance Criterion)

DepGraph本身是剪枝框架,需要配合具体的重要性评估准则

1. 幅度重要性(Magnitude)

这是最常用的依据

importance=tp.importance.MagnitudeImportance(p=2)  # L2范数

# 原理:权重的L2范数大小决定重要性
重要性分数 = ||W[:,c,:,:]||₂  # 第c个通道所有权重的L2范数
分数低的通道 → 重要性低 → 优先剪枝

2. 其他可选依据

torch_pruning支持多种重要性准则:

# 可选的重要性评估方法
重要性准则 = {
    "L1幅度": tp.importance.MagnitudeImportance(p=1),  # L1范数
    "L2幅度": tp.importance.MagnitudeImportance(p=2),  # 你的选择
    "BN缩放因子": tp.importance.BNScaleImportance(),   # BN的gamma参数
    "随机重要性": tp.importance.RandomImportance(),    # 随机测试用
    "海森矩阵": tp.importance.HessianImportance(),     # 二阶信息(计算量大)
}

deepseek_mermaid_20260129_c56512.png

45be2fa2-80ec-4f36-8dfa-5746b842ca90.png

项目中具体体现

# 你的足球检测模型剪枝过程
剪枝决策 = {
    "剪枝类型": "通道剪枝(结构化)",
    "依据": "L2幅度重要性",  # MagnitudeImportance(p=2)
    "评估对象": "卷积层的输出通道",
    "分组方式": "DepGraph自动分组(依赖感知)",
    "剪枝粒度": "以组为单位,不是单个通道",
    "数学原理": "||W_c||₂ 小的通道先剪",
    "目标": "从1114万参数减少40-50%"
}

选择L2幅度是合理的选择,因为:

  1. YOLOv8有大量卷积层,且很多层没有BN
  2. L2计算简单,适合迭代剪枝(16次迭代)
  3. 与预训练权重兼容:大模型训练时通常用L2正则化,权重已反映重要性