Introduction
在文章开头作者举了一个非常形象的例子说明知识蒸馏的重要性。在自然界中,许多昆虫存在幼虫形态和成虫形态,幼虫形态方便昆虫从环境中提取能量和营养,而成虫形态更多是为了满足旅行和繁殖的需求。
在大规模的机器学习不同过程中,我们同样存在不同的需求。模型训练的开始阶段,我们更多地希望训练出的模型能够具有出色的能力,为此我们愿意花费很多时间训练一个笨重但是效果好的集成模型。但是当我们训练出这样一个大型的模型,我们的需求就变成了:能否得到这样一个模型,它的规模相对小,但是我们能够将大型模型学到的知识迁移到小模型中,使其同样具有不错的效果。
Rich Caruana 等人指出,一个大型集成模型学到的知识可以被迁移到小型的单一模型。这个知识迁移的过程,我们称之为知识蒸馏。
我们通常将知识定义为已训练模型的参数,这一点使得我们很难进一步发现如何改变模型的形态但是保持模型已经学到的知识。实际上,知识的一个更抽象的认识是,将输入向量映射为输出向量的方法。从这个思路出发,输出向量一般就是输出的softmax分布,那么知识蒸馏的过程实际上就是训练小模型学习大模型的输出向量的过程,即让小模型和大模型在输入相同的情况下输出的softmax分布相近。用更形象一点的语言解释知识蒸馏的思路就是:小模型放弃直接从训练集中学习数据的分布,转而学习大模型学习到的 softmax 分布。
因此,在训练集数据相对noisy的条件下,我们可以放弃直接去学习其数据分布,转而学习另一个模型已经学习好的分布。大家可能对这一过程存在许多疑问。这样做的理由是什么呢?大模型无论结构有多复杂,其预测结果相对ground truth 始终是存在误差的,那么小模型将其预测结构作为 ground truth,岂不是在误差的基础上又引入了新的误差,造成误差传递?
其实不是这样的,误差传递传递是存在的,但知识蒸馏的初衷就是解决误差传递的问题,即得到一个小模型,其结构相对简单但却效果能媲美结构复杂的大模型。
以分类任务举例,我们知道,模型通过训练能够区分很多个不同类别,通常分类模型的目标函数是最大化正确答案的的log概率值。但实际上其中还存在一些有用信息,即错误答案概率值的相对大小。这些概率值的相对大小告诉了我们模型是怎样进行判断的。
问题的关键点在于,结构复杂的大模型,由于其相对出色的分类能力,往往会以很高的置信度(log概率值)去进行判断。例如下面三个模型,模型1 和 模型2 都能正确将样本分类为狗,但是模型1是以非常高的置信度输出结果的,它认为该样本是其他两个类别的概率是0,这个样本就只能是狗;与此同时,模型2虽然得出了正确的结果,但其认为该样本还是存在是其他类别的可能性。
| 模型 | 预测(猪 蚂蚁 狗) | 真实 |
|---|---|---|
| 模型1 | 0.0 0.0 1.0 | 001 狗 |
| 模型2 | 0.2 0.3 0.5 | 001 狗 |
| 模型3 | 0.0 1.0 0.0 | 010 蚂蚁 |
那么问题就来了,我们不希望大模型得到这种高置信度的输出(虽然其结果是正确的),我们希望的是大模型的输出向量更加接近与模型2,在得到正确结果的同时。
两者的区别就在于:模型1的结果在得到正确结果的同时,忽略了错误答案概率值的相对大小。
错误答案概率值的相对大小同样包含着关于问题的一些信息,比如,从模型2的输出结果我们可以知道,该样本是狗,不是蚂蚁,但是更不可能是猪。与之相比,从模型1的预测结果,我们无法推断模型对于样本是猪和蚂蚁的比较结果。
当然,在实际分类问题中,模型很少输出 0 0 1这样的结果。但不可否认的是,大模型在以很高的置信度(log概率值)去进行判断的同时,其他类别的概率值会很小,趋近于零。
更重要的一点是,这种以很高的置信度(log概率值)去进行判断的大模型,难以进行知识蒸馏。
我们在文章开头提到,知识的一个更抽象的认识是,将输入向量映射为输出向量的方法。知识蒸馏的过程就是将大型集成模型学习到的知识迁移到小型的单一模型,即尽可能使得小模型的输出向量接近大模型的输出向量。我们回想softmax函数,当输出的概率值趋近于0的时候,梯度也近似消失了。结果就是,小模型难以学习到大模型的输出向量。知识蒸馏过程的误差传递更明显了,这显然是我们不能接受的。
在论文中,类似模型2的输出结果作为小模型的标签,我们叫做“soft target”。如何训练大模型,使其输出结果能够相对soft呢?
Approach
分类任务中的logits
logits实际上是在神经网络的分类任务中,softmax函数之前的一层输出值。它们是未归一化的分数或激活值。通过softmax函数,这些logits会被转换为概率分布,使它们的总和为1。假设我们有一个简单的神经网络模型,用于图片分类任务。这个模型的最后一层有3个神经元,对应3个类别:猫、狗和兔子。 Logits计算: 在网络的最后一层(输出层),模型输出的logits可能是:
- Cat: 2.5
- Dog: 1.0
- Rabbit: 0.5 这些输出值称为logits,它们表示模型对于每个类别的未归一化的信心值。 通过Softmax转换为概率: 我们使用softmax函数将这些logits转换为概率分布。softmax的公式假设温度参数为1:
对于上面的logits,概率计算如下:
所以,通过softmax函数,logits转换为概率如下:
- Cat: 0.74
- Dog: 0.16
- Rabbit: 0.10 在这个示例中,logits (2.5, 1.0, 0.5) 是模型在应用softmax之前的输出值。通过softmax,这些未归一化的分数被转换为一个概率分布(0.74, 0.16, 0.10),总和为1,这表示模型的预测置信度。这些概率可以用来决定模型认为哪一类是最有可能的,通常选择最高概率对应的类别作为最终的预测结果。
Temperature
我们通过 softmax 函数将模型输出的 logit 转化为 probabilities。在此,作者在原有的 softmax 函数上引入温度的概念,以此得到便于小模型学习的 soft target。
在公式中,T 表示温度。通常情况下我们见到的 softmax 函数就是 T=1 的特殊情况。容易证明,当T趋近于0时,其最后输出的概率向量会更偏one-hot;当T趋近于无穷大时,其输出的概率向量中各概率会更加接近,即更“软”。
我们将前文提到的大模型称为 teacher model,小模型称为 student model。知识蒸馏的过程可以描述为:
- 将已经训练好的 teacher model 中的 softmax 温度升为 T,将训练数据重新输入 teacher model 得到 soft target。
- 在 student model 中分别将温度设置为 T 和 1,得到soft prediction 和 hard prediction, 两者分别与soft target 和 ground truth 计算 loss,两种loss(distillation loss + student loss)按照某种方式加权求和作为最终的loss。
Matching logits is a special case of distillation
此时,我们还有一个疑问。既然知识蒸馏的过程是让 student 模型学习 teacher 模型输出的 softmax 分布,那么为什么loss要采用交叉熵函数这种分类任务的损失函数,而不把整个任务当着一个回归任务,直接预测teacher 模型输出的 softmax 分布(即 matching logits)?
在文章中,作者证明了 matching logits 的过程实际上是蒸馏过程的一种特例,在公式2中我们可以看到,当 T 的数值相比logits足够大的时候,公式2可以变形为公式3(分子趋近与e的零次方也就是1,分母趋近于N),同时,我们假设所有的logits均值为0,那么可以得到公式4。因此,在这种情况下,蒸馏的过程就等价于最小化 。
作者证明了当温度足够高的时候,蒸馏的过程实际上等价于matching logits。然而在实际应用中,我们并不需要 student 模型的 softmax 分布与 teacher 模型输出的 softmax 分布完全一致。这是因为 teacher 模型中的logits中,一些极度偏离均值的logits几乎完全不受损失函数的约束,这些 logits 实际上引入了很多噪声。因此,当student 模型结构太简单而不足以完全学习到 teacher 模型的知识的时候,选取一个适中的温度能够取得更好的效果,也就是说,忽略那些偏差很大的 logits 在实际应用中是有利的。
Experiments
作者在 MNIST 数据集上训练了一个规模相对较大的神经网络,这一模型在测试集上有67个分类错误的样例。另一个规模较小的模型在测试集上有146个分类错误的样例。但是当小模型进行知识蒸馏后,分类错误的样例减少为74个。
在此基础上,作者将所有手写字为3的样例从 transfer set 中删除,这就表明对于蒸馏后的模型,3是一个它从来没有见过的数字,尽管如此,蒸馏模型只产生了206个测试失败的样例,其中133个是在测试集的1010个3上。大多数错误是由于3类的学习偏差太低而造成的。如果这个偏差增加3.5(这将优化测试集的整体性能),则经过蒸馏的模型将产生109个错误,其中14个是在3上。有了合适的偏差,经过蒸馏的模型在测试3中获得了98.6%的正确率,尽管在训练中从未见过3。如果迁移集只包含训练集的7s和8s,则蒸馏模型的测试误差为47.3%,但当7和8的偏差减少7.6以优化测试性能时,测试误差降至13.2%。
此外作者在语音识别数据集上同样进行了实验,由于笔者对这个方向了解不是很深入,因此知识简单看了一下实验结果,可以看到蒸馏后的模型很大程度上保持了集成模型的准确率。