大家好,这里是好评笔记,本文为试读,查看全文请移步公主号:Goodnote。本文详细介绍模型训练完成后的压缩和优化技术:蒸馏、剪枝、量化。
@[toc]
模型压缩和优化技术是为了在保证模型性能(精度、推理速度等)的前提下,减少模型的体积、降低计算复杂度和内存占用,从而提高模型在资源受限环境中的部署效率。这些技术对于在边缘设备、移动设备等计算资源有限的场景中部署深度学习模型尤为重要。以下是几种常见的模型压缩和优化技术的解释:
1. 知识蒸馏 (Knowledge Distillation)
知识蒸馏是一种通过“教师模型”(通常是一个性能较高但规模较大的模型)来指导“学生模型”(通常是一个较小但高效的模型)训练的技术。其基本思想是让学生模型学习教师模型在输入数据上的输出分布,而不是直接学习真实标签。主要步骤如下:
- 训练教师模型: 首先训练一个大规模的教师模型,该模型通常有很好的性能。
- 蒸馏训练: 使用教师模型的预测结果(软标签)来训练学生模型。通常情况下,学生模型会通过一种称为“蒸馏损失”(Distillation Loss)的函数来最小化其输出与教师模型输出的差异。
- 优势: 知识蒸馏可以有效地提升学生模型的精度,即使学生模型结构相对简单,也能获得接近教师模型的性能。
推荐阅读:一文搞懂【知识蒸馏】【Knowledge Distillation】算法原理
基本概念
知识蒸馏(Knowledge Distillation)是一种将大模型的知识迁移到小模型的方法,旨在保持模型性能的同时,减少模型的参数量和计算复杂度。知识蒸馏广泛用于深度学习中模型压缩和加速的任务,使得小模型能够在有限资源的设备(如手机、嵌入式设备)上高效运行,同时仍保持高精度。
知识蒸馏通过训练一个小模型(学生模型) 来 模仿 一个 大模型(教师模型) 的行为。大模型的输出(通常是类别概率分布或特征表示)作为小模型的“软标签”或监督信号,使小模型能够更好地学习复杂的数据分布。
知识蒸馏可以分为以下几种基本形式:
- 软标签蒸馏:通过教师模型的输出概率作为目标,使得学生模型不仅学习正确的分类,还学习类别之间的相对关系。
- 中间层蒸馏:将教师模型的中间层表示传递给学生模型,使其学习更丰富的特征表示。
- 基于特征的蒸馏:直接从教师模型的隐藏层特征提取知识,并将其应用于学生模型。
工作流程
知识蒸馏的整个流程确保了小模型在有限资源的设备上高效运行,同时保留了教师模型的精度。这种方法被广泛应用于边缘计算、移动应用和其他对计算资源敏感的场景。
| 步骤 | 详细操作 |
|---|---|
| 训练教师模型 | 训练一个高精度的大模型,作为学生模型学习的知识源 |
| 准备软标签 | 通过温度调节生成教师模型的软标签,提供类别间相对关系信息 |
| 构建学生模型 | 设计一个小而高效的模型,用于模仿教师模型的行为 |
| 构建损失函数 | 使用软标签和硬标签损失的组合,以平衡学生模型对硬标签和软标签的学习 |
| 训练学生模型 | 通过前向传播、反向传播和参数更新迭代优化学生模型,模仿教师模型的输出 |
| 评估模型 | 对比教师和学生模型的性能,确保学生模型在效率和精度上的平衡 |
| 部署学生模型 | 导出学生模型到目标平台,进行量化、剪枝等优化,并在真实环境中进行测试 并部署 |
-
训练教师模型
- 目标:知识蒸馏的第一步是训练一个高精度的大模型,也就是教师模型。教师模型通常具有较大的参数量和复杂的结构,能有效学习到数据的复杂模式。
- 训练:教师模型通常在完整数据集上进行标准的监督学习训练,以确保其在任务上的性能足够好(例如分类任务中达到较高的准确率)。教师模型的高精度和强泛化能力为学生模型提供了可靠的“知识源”。
- 优化:教师模型可以使用标准的损失函数(例如分类任务中的交叉熵损失)进行优化。教师模型的最终性能将直接影响学生模型的学习效果,因此需要仔细调优确保教师模型的高质量。
-
准备教师模型的输出
-
目标:在知识蒸馏中,教师模型的输出不再是简单的硬标签(one-hot),而是称为“软标签”的类别概率分布。软标签提供了类别间的细微关系,是学生模型的重要学习目标。
-
温度调节:教师模型的输出通常使用温度调节(temperature scaling)进行平滑。具体来说,教师模型在生成输出的 softmax 概率分布时会加入温度参数 ( ),以平滑各类别之间的概率分布。
-
输出软标签:经过温度调节后的 softmax 输出(软标签)会被保存下来,作为学生模型的目标。软标签比硬标签包含了更多类别间的信息,有助于学生模型更细致地学习数据分布。
-
教师模型生成的软标签的计算公式:
- ( p_i ):第 ( i ) 类的概率(软标签)。
- ( z_i ):第 ( i ) 类的 logit(教师模型输出的未归一化分数)。
- ( T ):温度参数,用于控制软化程度。
-
公式参数解释
Logits( ):Logits 是教师模型在最后一层但是没有经过 softmax的输出(在应用 softmax 之前),通常表示各类别的非归一化得分。
温度参数():温度参数用于调节 softmax 函数的输出分布。在知识蒸馏中,通过调整温度参数 ( ) 的值,教师模型可以生成更加平滑的概率分布,从而帮助学生模型学习类别之间的相对关系。
- 当 ( ) 时,这个公式就变成了普通的 softmax 函数,输出的概率分布直接对应教师模型对各类别的置信度。
- 当 ( ) 时,输出分布变得更加平滑,使得非最大类的概率变得较大,利于学生模型捕捉到类间关系。
温度参数 ( T ) 的作用
- 更高的温度(即)会使得 logits 被缩放得更小,从而使 softmax 函数的输出分布更平滑。这意味着各类别的概率差异会缩小,学生模型可以更好地理解不同类别之间的相对关系,而不仅仅关注于概率最高的类别。
- 通过这种方式,学生模型在训练时不仅学习到正确答案的类别标签,还学习到不同类别之间的关系(即类间相似性)。这有助于学生模型在实际应用中对未见数据具有更好的泛化能力。
-
构建学生模型
- 目标:学生模型通常比教师模型小,具有更少的参数量。它的目的是在保持教师模型精度的同时,显著降低计算和存储需求,以便在资源受限的设备(如手机、嵌入式设备)上高效运行。
- 设计:学生模型可以与教师模型具有相同的结构,但层数、参数量较少;也可以是其他架构,甚至与教师模型完全不同。学生模型的设计通常会根据目标硬件的限制来优化,以在保持精度的前提下达到更高的计算效率。
- 初始化:学生模型的权重可以从头初始化,也可以使用预训练模型的权重作为初始状态,以加快训练收敛速度。
-
构建损失函数
详细全文请移步公主号:Goodnote。