大模型参数微调系列教程(三):知识蒸馏Fine Tuning

171 阅读21分钟

自从DeepSeek大火之后,知识蒸馏 (Knowledge Distillation)也接续成为一个十分出圈的理论概念。然而,并未接触开发的同学对这块的认知相对来说还是比较浅显的。很多仅停留在概念层面,并未深刻掌握其精髓,也或者说理解的并未很透彻。

我们这里讲知识蒸馏实际上是用一个小模型(学生模型)学习大模型(教师模型)的知识,减少计算量,并提升推理速度的过程。这里我们选择浅层ResNet、MobileNet、EfficientNet 等轻量级模型作为学生模型。通过调整温度参数 (T)、损失权重等信息,让学生模型学习教师模型的“软标签”信息,提高训练效率。蒸馏后的小模型虽然加速了推理,但可能会有精度损失。可以通过调节如上模型参数信息来改善学生模型的学习效果。

图片

这里可能有读者要问,为什么知识蒸馏中经常用残差网络ResNet来作为教师模型和学生模型的示例进行训练和推理呢?

首先,ResNet(残差网络)在 ImageNet 和 CIFAR 等图像分类任务上是经典强基线。因此,ResNet 在分类任务中效果好。同时,ResNet稳定、易训练、准确率高,很多论文都用它当 benchmark。

其次,ResNet 架构模块化,易于裁剪和替换。比如:ResNet50 → ResNet18,本质上是层数减少而已。很适合用于知识蒸馏这种教师大模型 / 学生小模型对比的结构迁移。还有,就是残差结构(skip connection)本身能缓解训练困难。深层网络往往出现梯度消失/爆炸问题,残差结构可以让梯度“跨层流动”,训练更稳定。蒸馏训练本身就比较敏感,残差网络能提升整体稳定性和精度。

还有一个比较关键的问题,为什么硬标签用交叉熵计算,软标签用KL散度计算?

首先,我们要明白硬标签是啥?就是像[0, 0, 0, 1, 0, 0, 0, 0, 0, 0] 这样的 one-hot 标签,表示图像属于第 4 类。

交叉熵的本质是对概率分布之间差异的度量。在监督学习中,它衡量模型预测分布和真实分布之间的距离,常写作:

图片

其中,p 是真实分布(one-hot),q 是模型输出的softmax概率。

那为什么这种交叉熵又适合于硬标签呢?这是因为one-hot 标签中只有一个 pi=1,其他都是 0,所以交叉熵可简化为:

 

图片

也就是说:交叉熵只惩罚“正确类别的概率低”的情况,这也就是分类任务最标准的做法。

那么,我们又为什么在软标签中用KL 散度(Kullback–Leibler Divergence)来进行损失计算呢?

首先,我们同样要了解软标签是啥?软标签来自于教师模型预测的概率分布,例如:

teacher_pred = [0.1, 0.3, 0.05, 0.05, 0.4, 0.1]

这些概率带有丰富的“类间相似性”,不像 one-hot 那么“绝对”。

KL 散度是啥?KL 散度用于衡量两个概率分布之间的差异,KL 散度回答的是:

“如果现实是由 P 分布产生的,我们却用 Q 分布来压缩信息,那么我们会多花多少 bit 数?”如果 P 和 Q 完全一样,KL=0,说明没有信息损失;如果差距越大,KL 越大,说明 Q 远离了 P,代价越高。

数学公式是:

图片

P 是“教师模型”的 softmax 输出(目标分布),Q 是“学生模型”的 softmax 输出(预测分布)。

你可能会问:“那我直接比较概率差值(比如用 MSE)不行吗?”不行,因为 MSE 没有考虑概率的对数特性。KL 更符合“信息熵”原理,对小概率/高置信度的差异更敏感,这对模型学习到“细节上的犹豫”和“混淆类别”尤其重要。

我们来用图像直观地对比一下,在知识蒸馏中,硬标签(用于交叉熵)和软标签(用于KL散度) 到底有什么差别。这里我们的目标是展示一个样本在教师模型和真实标签下的输出分布对比图。这里我们用条形图对比三种标签/分布:Ground Truth(硬标签)、教师模型 softmax 输出(软标签)、学生模型 softmax 输出(可选对比)。

假设情景(十分类问题):某一张图真实是数字“3”,但教师模型输出了一个更“分布化”的结果:

图片

这里我们绘制一个条形图来可视化上述三种分布(硬标签、教师输出、学生输出)。

图片

 

上面这张图很好地说明了:硬标签(Ground Truth)只有真实类别3的概率是1,其余全是0。这是典型的 one-hot 向量,信息非常“尖锐”。教师模型输出(Soft Label)虽然也预测了类别3为最高概率(0.6),但其他类也有非零概率。表示模型对输入样本的**“模糊认知”**,体现了“它也有可能是4或5,但最像是3”。而学生模型输出则学习了教师的分布,虽然没有完全一致,但结构非常相似。这正是 KL 散度作为蒸馏损失在做的事——鼓励学生的输出逼近教师的分布。

总结起来,就是KL 散度不会只关注某一个类别,而是整体比较两个完整分布的形状是否相似。这非常适合在蒸馏中使用。因为教师模型的输出分布代表了“专家判断”,有些模糊性。学生模型通过 KL 散度学会“模仿”这种分布,继承了教师模型的“类间知识”。

现在要组织代码结构。可能需要以下部分:

1.导入必要的库,如PyTorch。

2.定义教师和学生模型。这里可能需要简化模型结构,比如使用ResNet作为教师,一个较浅的网络如CNN作为学生。

3.加载预训练教师模型的参数(这里需要提前进行预训练,采用交叉熵对比Labels的方式进行有效的分类)。

4.结合温度参数定义蒸馏损失函数,包括软目标和硬标签的损失。

5.设置优化器和训练循环。

6.超参数的设置部分,放在代码的开头,方便调整。

以下是一个基于 PyTorch 的知识蒸馏示例代码(截取关于教师模型到学生模型蒸馏的关键部分),展示如何通过教师模型(大模型)的输出指导学生模型(小模型)的训练,包含关键参数调节部分。其中,超参数设置放在顶部:temperature = 3; alpha = 0.7; beta = 3.0; learning_rate = 1e-3。

这里使用KL散度来计算软目标的损失,因为知识蒸馏中常常用到。最后,通过代码运行结果如下:

 

图片

从上图运行结果可以看到,整个学生模型参数量相对于教师模型参数量减少了一半。同时,学生模型的精度相对于教师模型精度略有一定程度的降低,推理时间上相对于教师模型减少了一些。大家很容易有这样的疑问,为什么学生模型的参数量减少了一半后,其推理时间并未减少一半呢?

首先,推理时间≠ 线性取决于参数量。虽然模型参数量减少了,但推理速度并不会线性加速。这是因为推理时间还受到以下因素影响:

图片

因此,从以上的分析结果可以看出,简单对教师模型蒸馏后生成学生模型时,并不能完全达到我们加速推理时间的预期。想要真正的让学生模型真正更快,则需要进一步对学生模型实行结构简化。比如,使用 MobileNetV3、EfficientNet 等专门设计的轻量模型。也可以考虑进一步对学生模型进行有效的量化/剪枝。比如,在知识蒸馏基础上,加上结构化剪枝和模型量化。进一步的,通过批量推理来增加 batch_size 看是否更能体现速度优势。当然,在部署优化上,可以使用 GPU 优化库:如 TensorRT、ONNX Runtime、OpenVINO 等部署优化工具来进一步提升性能。

接下来我们来对知识蒸馏中的关键要素信息,这里主要讲解知识蒸馏的模型参数如何调节,因此需要基于其中的关键参数进行如下说明。

1.调节学生模型神经网络

基于前序运行结果分析,如果想要进一步减少学生模型参数,则可以考虑替换学生模型本身的神经网络模型。比如,可以更换更轻量的模型结构,如 MobileNet、ShuffleNet、SqueezeNet 等。

图片

如下图所示是采用不同的轻量化学生模型后,不同学生模型参数数量、推理精度、推理时间和训练时间的对比分析。可以看到当不同轻量化学生模型,模型参数量不断减少后,其训练时间和推理精度都会同步减少。这个也很能理解,模型参数减少后,必然导致模型对输入的表达不够准确、完整和清晰了。

图片

当然,细心的同学会发现MobileNetV2 参数比 ShuffleNetV2 多,但精度却更低,这是为什么?

这其实是一个轻量模型结构与任务适应性之间的平衡问题,这是因为模型结构对小图像(如 CIFAR10)的适应性不同。MobileNetV2:原设计用于 224x224 图像,对 ImageNet 更适配。在输入缩小到 32x32 后,其特征图过早地下采样,导致信息丢失严重。ShuffleNetV2:更轻量化,且分组卷积设计更适合小输入图像,有效保留了低层信息。所以,虽然 MobileNetV2 的参数多,但“信息利用率”不高,造成精度反而下降。

其次,模型复杂度与表达能力也不是完全线性相关。MobileNetV2 的参数虽然多一些,但其瓶颈层结构对小任务可能是“过拟合”或“欠表达”。ShuffleNetV2 的架构更简单、计算路径更浅,对 CIFAR10 这类简单分类任务,可能刚刚好,反而更鲁棒。

当然,也有可能是训练轮数较少,复杂模型未能充分训练。目前只训练了 3个epoch,对于参数量更大、网络更深的 MobileNetV2,可能还没来得及充分学习任务模式,而 ShuffleNetV2 结构简单,收敛更快,因此在短时间训练下看起来效果更好。

2.温度参数(TEMPERATURE)

增大温度会使教师的 Softmax 分布更平滑,便于学生模型学习。T 决定了软标签中传递的信息量,且T 只影响蒸馏过程中的 loss 计算,不会改动模型结构本身。也就是说,T 是训练阶段的一个技巧,对部署阶段没有直接影响。那为什么调 T 有时候能提升学生精度?这是因为更大的 T 会给学生模型一个更平滑、信息更丰富的目标去模仿。它可能帮助学生更好地泛化,尤其在样本量有限或模型过小的情况下。但注意T太大会使所有 logits 都很接近,梯度太小,训练可能变慢或不稳定。最优的 T 一般通过 grid search(比如 3~8)去调。

如下图表示了两种不同的T值下,对应到学生模型具体的参数量、推理精度和推理时间对比分析。可以看到,当调大T值后,学生模型可以更多的继承教师模型的模型参数和泛化能力,这样学生模型的整个推理精度也会有一定程度的提高。

图片

 

这里可以看到T=4 的确带来了更好的效果,这是因为更高温度下,soft label 更平滑,提供了更丰富的“类间关系”信息,有助于学生模型学习教师模型的“类间认知”。但同时模型的训练耗时也增加了,虽然不是特别明显,但训练时间略有上升,可能是更平滑的目标函数导致每个 batch 的反向传播需要更精细的调整。此外,还可以观察到推理时间变化极小,这主要与模型结构相关,与 T 值无关。

这里需要注意如果进一步尝试更高温度(如 T=6, 8),但注意温度过高可能会让 logits 过于平滑,反而失去监督信息。我们进一步通过绘制图像分析不同温度值对推理结果精度的影响。

这里我们继续对比T值从1、2、4、6、8这几个不同的维度来看整个T值对模型输出的影响。

 

图片

从上图中可以看出,从 T=1 到 T=6 推理精度略有下降,而 T=8 时突然提升,这是一个很有意思的现象。我们知道T 越低(接近 1):soft label 趋于 one-hot,几乎和 hard label 类似。T 越高:soft label 越“平滑”,表示类别之间相似度的暗示越丰富。这里的中等 T 值(2~6)出现精度下降:可能是因为信息变得模糊但还是不足够丰富。即,产生了 teacher softmax 的平滑概率,但平滑得不够,导致学生模型失去了 hard label 的强监督效果,但又没有学到足够的类间结构知识,反而“中间状态最差”。最终,我们观测到T=8的时候,推理精度达到最佳状态了,这是因为学生模型在这个 T 值下学习到了更“平滑”且有效的指导。也就是说,我们通过对温度值的不断试探,找到了一个最佳温度值T=8,这个值可以获得一个比较好的推理结果。

3.训练轮次(EPOCHS)

基于如上温度值的调节,我们接续调节epoch参数,看下epoch分别在2、5、8下的不同训练结果值。

图片图片

 

可以看到,改变epoch后,整个学生参数量不变,训练时间随着epoch的增加的比较多。另一方面,推理精度随着epoch逐渐增加,但是推理时间基本保持恒定。那为什么训练时间随 Epoch 增加而增长,但推理时间基本不变?

这是一个非常常见的现象,这是训练时间增加的原因。每一个 epoch 都要对整个训练集进行一次完整的正向传播 + 反向传播 + 权重更新,这些操作计算量大。所以训练时间和 epoch 成正比,epoch 越多训练时间自然也就越长。

这里也需要注意推理时间几乎不变。这是因为推理时只进行正向传播,不涉及反向传播、梯度计算、参数更新。模型结构(即网络参数量和前向传播逻辑)在训练过程中不会变,所以只要模型结构不变,推理时间就基本稳定。这里测量的是整个 test_loader 推理一轮的时间,这个数据主要取决于模型复杂度和硬件环境,而不是训练轮次。

更通俗的讲,可以把训练看成是「学习写作业」,每次学习都得花一定时间思考和总结(训练时间会变长),而推理就像是「直接考试答题」,一次只考一份卷子(推理结构不变,所以答题时间也差不多)。

4.损失权重平衡

这里,需要注意代码中的参数命名和注释,清楚每个参数的作用。例如,温度T一般大于1,用来平滑概率分布,alpha和beta控制软硬损失的权重,总损失L = HARD_LOSS_WEIGHT* 硬损失 + DISTILL_LOSS_WEIGHT* 软损失。过程中,可能还需要指出关键参数的调整建议,例如温度参数和alpha/beta的平衡,以获得更好的蒸馏效果。不同任务或模型组合下,soft/hard loss 的重要程度是不一样的,调节权重的目的就是为了避免 soft loss 被掩盖,因为教师输出 softmax 后的概率通常更加“平”,值比较小,所以其 KL loss 的梯度也小,如果不放大(比如加上 T^2,或加大 β),在总损失中就不够有影响力。同时,考虑到适配不同模型容量差异,如果学生模型特别小,过多强调 hard loss 可能会让它难以收敛。如果教师模型特别强,就可以提高 soft loss 的比例,让学生多模仿老师。

此外,还可以通过对如上参数的调节实现训练目标的灵活性。比如有些任务更重视精度(hard label),有些更重视泛化(soft label),所以这个加权组合更像是“调味”行为,而不是概率加权,值不必加和为 1!因为软损失经过温度平滑和 KL 散度之后,它的数值通常 很小,远远小于交叉熵的值。 因此,为了让它在总体损失中有足够影响力,我们通常人为放大其权重。假如 soft loss 只有 0.1,而 hard loss 是 1.2,直接加权平均会导致 soft loss 几乎没影响。 所以尽量要设置大一些的蒸馏权重比如(DISTILL_LOSS_WEIGHT = 3.0)来增强其作用。

 

图片

可以看到,DISTILL_LOSS_WEIGHT(知识蒸馏损失权重),设置更大的值可以提升知识转移效果。这样生成的学生模型精度相对较HARD_LOSS_WEIGHT(原始标签损失权重),确保模型仍然关注真实标签。因此,建议保持DISTILL_LOSS_WEIGHT > HARD_LOSS_WEIGHT。这里为了方便对比我们设置几组不同的软标签和硬标签后对整个模型在精度、推理时间及参数数量上的差异对比。

  # 对比三组超参数   settings = [       {'name': '软标签主导', 'distill_weight': 3.0, 'hard_weight': 0.7},       {'name': '硬标签主导', 'distill_weight': 1.0, 'hard_weight': 3.0},       {'name': '均衡组合', 'distill_weight': 1.0, 'hard_weight': 1.0},   ]

相应的运行结果如下:

图片

从如上图所示的运行结果中,我们可以看到:学生模型在某些设置下甚至可以精度超过教师?这是为什么呢?

其实这在知识蒸馏中是常见现象,尤其是当教师模型 overfit 或比较大、复杂但没完全拟合训练集时更容易出现。学生模型在蒸馏中,通过学习教师的软目标,获得了更好的泛化能力。训练 epoch 虽然短(比如你只跑了几个 epoch),但学生模型更快达到收敛。

当然,教师模型过拟合≠ 教师“没用”。虽然教师模型可能在训练集上过拟合,它的输出(尤其是 soft target)仍然包含了很多有价值的信息。这里我们举例说明一下:

教师模型在训练集中学到了“高阶模式”和“类别之间的关系”,即使它记得太死(过拟合),这些模式在学生学习过程中仍然可以变成“泛化指导”。比如教师对某张模糊图片输出 [dog: 0.5, wolf: 0.4, cat: 0.1],教师模型通常采用是非题的判断方式把它选错了为dog,但是学生在继承过程中会通过提升软标签的权重接受教师模型关于这种不确定性的分布,即告诉学生如下事实:

• “虽然它是 dog,但狼也很像,你要小心这些边界情况。”

这就可以理解为:学生学得更「聪明」,而不是更「全」。教师给出模糊但有指导价值的信息,学生用更轻量模型「巧妙泛化」了这些信息。

同时,我们还可以看到,均衡组合精度(71.78%)不仅高于软标签(69.46%)主导,且精度还高于硬标签(70.50%)主导,那么,为什么均衡组合精度高于硬标签单独使用呢?

这是因为硬标签(交叉熵CrossEntropy)让学生模型学习 ground truth 的分类。软标签(KL散度)是来自教师模型的输出概率分布,提供了「类别间相似度」信息,比如「狗」和「狼」相似,而不是只知道是不是「狗」。均衡组合 = 学生同时学习「明确目标」(硬标签)+「教师思考路径」(软标签),这比只学硬标签要更全面、更泛化,从而导致最终精度可能超过单纯的硬标签模型,甚至超过教师模型本身。

此外,还有一个运行结果也可以看到一个特别的现象:即为什么均衡组合的推理时间比硬标签主导还更短? 其实,这个和精度关系不大,主要是运行时的随机波动,原因包括:

○批处理数据载入的 IO 差异(CPU预处理稍慢时影响整体时间)

○CUDA kernel 启动、内核调度的系统层随机性

○模型 forward 时 Tensor 在缓存里的命中情况不同

也就是说,两个模型的结构和参数量一样,时间差 <0.05s 是正常波动,不必认为是超参造成的。

最后,我们也关注到为什么不同软硬标签权重组合(比如 soft 标签主导、hard 标签主导、均衡组合)下,最终的模型参数量和推理时间都差不多?

原因一:软标签和硬标签只是训练过程中的损失函数组合方式不同,但教师模型结构没变,学生模型结构也没变。这就意味着模型的参数量固定(比如学生 ResNet-18 固定就是 11.18M 参数)。推理时间也基本一致,因为模型大小、网络结构、前向路径均没变化。

原因二:软/硬标签只是训练目标,不影响推理结构。换句话说,你训练的时候怎么学的,和你推理时怎么跑的,是两码事。更通俗一点讲,不同权重组合的效果类似于“不同学习方法”——有的人靠看书记,有的人靠做题练,但最终考试的时候,大家都是用一个标准试卷答题,花的时间和写字手速没关系。

5.学习率调节

学习率(Learning Rate)主要控制的是每次参数更新的“步长”。也就是在梯度方向上走多远。具体来说:

假设你当前参数是θ, 损失函数的梯度是 ∇L(θ),那么每一步更新就是:θnew=θ−η⋅∇L(θ),其中 η 就是学习率。

不同学习率的效果对比如下:

图片

这里我们设置三种不同的学习率做一个简单的模型对比实验,如下图:

 

图片

这里我们可以明确的观察到为什么训练时间差别不大而推理精度变化大呢?

训练时间: 每轮训练的数据和模型参数量是一样的,只是学习率影响的是权重更新的幅度,不是计算量。因此在epoch 数相同的情况下,训练时间变化不大。

推理时间: 推理是前向计算,模型结构固定,所以不同学习率训练得到的学生模型在参数量一致时,推理时间也几乎不变。

从如上运行结果,可以看出当学习率为0.0001太小,训练进展缓慢,模型几乎没学到有效知识。而当学习率增大为0.001时,比较理想的学习率,能够稳定收敛。当学习率继续增加到0.01时,更高的学习率加快了收敛速度,效果更好,但也要警惕不稳定或震荡风险。

如下图所示,可以看出当学习率设置的不合理时,会引起相应的震荡。

图片

图片

从结果来看,我们可以得出几个结论:

学习率过小(如 0.0001)时,训练效果较差,模型没有充分学习。

学习率适中(如 0.01)时,精度最高,说明模型有效收敛。

学习率过大(如 0.1)时,训练极度不稳定,模型几乎学不到东西。

原文地址:https://mp.weixin.qq.com/s/an8ftvQ5ET2FTifhap5okA