知识蒸馏时,由于我们已经有了一个泛化能力较强的Net-T,我们在利用Net-T来蒸馏训练Net-S时,可以直接让Net-S去学习Net-T的泛化能力。
一个很直白且高效的迁移泛化能力的方法就是:使用softmax层输出的类别的概率来作为“soft target”。
【KD的训练过程和传统的训练过程的对比】
Hard-target:原始数据集标注的 one-shot 标签,除了正标签为 1,其他负标签都是 0。 Soft-target:Teacher模型softmax层输出的类别概率,每个类别都分配了概率,正标签的概率最高。
软目标——一种监督信号的形式
教师模型以标注生语料、构造有标签数据的形式,来为学生模型提供监督信号。4.2节的地(2)步里提到,这种监督信号,类别的概率分布,也就是“软目标”(Soft Target)。为什么不直接使用类标签(对神经网络来说就是独热编码),也就是“硬目标”(Hard Target),来训练学生模型呢?
独热编码是一种“信心很高”的监督信号,认为一个样本“只属于”某个类别(这里记为A类),不可能属于“非A类”。在使用交叉熵损失函数训练的时候,独热编码里只有A类对应的误差会贡献梯度和信息。
而软目标中,样本属于非A类的概率不一定等于零,还保留了样本属于各个类别的概率相对大小。这样,使用交叉熵损失函数训练模型时,软目标的非A类维度,也会贡献梯度;软目标中各个概率值的相对大小,则体现了各个类别之间的相关性等信息。理论上,软目标的信息量是超过硬目标的。
当然,教师模型通常还是非常“自信”的,判断样本类别的结果可能是
这个向量里的元素非常接近0或1,和独热编码没啥区别;正确类别之外的概率值,经常是0.00001或0.00000001,差异太小了,不利于模型学习类别之间的关系规律。
这时候,我们可以用一种策略,把元素的取值向0.5“推一推”:
假设输出层的logits,即最后一个线性变换层的输出为
,经典的概率计算方式为
。我们给z除以一个数,即
,就可以让概率值离0或1稍微远一点。
T这个参数的全称是Temperature,即温度。T是一个大于1的数(一般是整数)。T越大,概率取值被“挤压”得越厉害、越靠近0.5——正确类别之外的概率值差异被放大的程度越大,越有利于模型学习监督信号里非A类的信息。说实话,我至今没有领会到知识蒸馏中,“蒸馏”这个词语与化学老师教的蒸馏之间的相似性,这个参数T也没有让我联想到“温度”。不过这不影响Hinton这项工作的价值。
总之,“软目标”提升了学生模型对监督信号中,非A类维度对应的信息,显著提升了学生模型的学习能力。
另一种方法是直接比较logits来避免这个问题。具体地,对于每一条数据,记原模型产生的某个logits是v,新模型产生的logits是 z,我们需要最小化
文献[2]提出了更通用的一种做法。考虑一个广义的softmax函数
损失函数
不论软目标有多好,多不能掩盖硬目标拥有独特信息的事实。硬目标还是有学习价值的——我们可以使用一种多任务学习的方式,将软目标和硬目标带有的信息都教给学生模型。该学习方式下的损失函数为
- 传统training过程(hard targets): 对ground truth求极大似然
- KD的training过程(soft targets): 用large model的class probabilities作为soft targets
KD的训练过程为什么更有效?
softmax层的输出,除了正例之外,负标签也带有大量的信息,比如某些负标签对应的概率远远大于其他负标签。而在传统的训练过程(hard target)中,所有负标签都被统一对待。也就是说,KD的训练方式使得每个样本给Net-S带来的信息量大于传统的训练方式。
softmax函数
先回顾一下原始的softmax函数:
但要是直接使用softmax层的输出值作为soft target, 这又会带来一个问题: 当softmax输出的概率分布熵相对较小时,负标签的值都很接近0,对损失函数的贡献非常小,小到可以忽略不计。因此**"温度"**这个变量就派上了用场。
下面的公式时加了温度这个变量之后的softmax函数:
- 这里的T就是温度。
- 原来的softmax函数是T = 1的特例。 T越高,softmax的output probability distribution越趋于平滑,其分布的熵越大,负标签携带的信息会被相对地放大,模型训练将更加关注负标签。
4. 关于"温度"的讨论
【问题】 我们都知道“蒸馏”需要在高温下进行,那么这个“蒸馏”的温度代表了什么,又是如何选取合适的温度?
4.1. 温度的特点
在回答这个问题之前,先讨论一下温度T的特点
- 原始的softmax函数是
T=1时的特例,T<1时,概率分布比原始更“陡峭”,T>1时,概率分布比原始更“平缓”。 - 温度越高,softmax上各个值的分布就越平均(思考极端情况: (i)
T =∞, 此时softmax的值是平均分布的;(ii)T->0,此时softmax的值就相当于argmax, 即最大的概率处的值趋近于1,而其他值趋近于0) - 不管温度T怎么取值,Soft target都有忽略相对较小的Pi 携带的信息的倾向
4.2. 温度代表了什么,如何选取合适的温度?
温度的高低改变的是Net-S训练过程中对负标签的关注程度: 温度较低时,对负标签的关注,尤其是那些显著低于平均值的负标签的关注较少;而温度较高时,负标签相关的值会相对增大,Net-S会相对多地关注到负标签。
实际上,负标签中包含一定的信息,尤其是那些值显著高于平均值的负标签。但由于Net-T的训练过程决定了负标签部分比较noisy,并且负标签的值越低,其信息就越不可靠。因此温度的选取比较empirical,本质上就是在下面两件事之中取舍:
- 从有部分信息量的负标签中学习 --> 温度要高一些
- 防止受负标签中噪声的影响 -->温度要低一些
总的来说,T的选择和Net-S的大小有关,Net-S参数量比较小的时候,相对比较低的温度就可以了(因为参数量小的模型不能capture all knowledge,所以可以适当忽略掉一些负标签的信息)