知识蒸馏中的温度参数(1)

1,065 阅读5分钟

1、知识蒸馏简介

  知识蒸馏(Knowledge Distillation)是一种模型压缩方法,其来自于15年Hinton的一篇论文《Distilling the Knowledge in a Neural Network》
  其中主要的思想则是把老师的知识传授给学生, 使得学生模型能够拥有更丰富的知识。这里的老师可以指代的是大模型,学生一般就意味着小模型。

erDiagram
Teacher ||--|{ Student : Knowledge

由于其理论简单,效果甚佳,在工业界也得到了广泛的应用。现在就让我介绍一下这篇论文

2、背景介绍

  近几年随着transformer各个变种的提出,比如bert, roberta, GPT。当前人工智能领域已经迈入了大模型(泛指参数量大)时代。甚至想Google在2021年提出的switch transformer参数量级达到了恐怖的1.6万亿级别。但由于模型容量大,导致训练耗时,推理速度慢,占用内存空间高,使其几乎无法在工业界产生应用。无法应对工业界高并发的环境。 针对这些瓶颈:

  1. 推理速度慢
  2. 训练速度慢,占用内存高,cpu消耗大 Hinton提出了“知识蒸馏”这种模型压缩方法(在保证性能的前提下减少模型的参数量)。即把大模型学到的知识“传授”给小模型,使得小模型在参数量不足的情况下,也能够表现出比较好的效果。

3、具体方式

这里会有两种角色模型
1、老师模型(Teacher-Net):作为“知识”的输出者,会将自己从数据中学习到的知识,“传授”给学生模型。其特点就是,参数量大,结构复杂,还可以由一个或者多个大模型共同组成。当然,我们不需要对老师模型做任何限制,但会有训练数据作为输入输出,假设,输入为X, 输出是P
2、学生模型(Student-Net):作为“知识”的接收者,会吸收老师传过来的“知识”。其特点一般是参数量级小,结构简单。但对其也不会有特定的限制。 假设输入也为X, 输出为Q

以上是知识蒸馏的前置条件,当前我们会后续的介绍限定在分类任务,便于大家理解。

3.1 理论知识

  在传统的深度学习的分类任务中,我们会通过已经标注的数据,用以训练模型,使得模型具有一定的泛化能力。在之后遇到类似的样本时,模型能够正确的识别这个样本的类别。
  那么在训练的过程中,怎么去学习到样本的标签呢?
  我们当前的方法是模型的输出经过softmax形成的概率分布,去拟合当前的数据标签,然后利用反向传播来更新模型中的参数。

image.png 其中上面的柱形图即是模型输出经过softmax形成的概率分布。
下面的图即是样本数据的标签

在知识蒸馏中,我们也称上面的概率分布为:soft target或者soft label。下面的分布为hard target或者hard label

所以最终学生模型输出的概率分布,不仅要拟合老师模型的soft target,还需要拟合样本真正的标签,hard target.

疑问:为什么学习soft target就能够使得学生模型学习得更好的?

解答:在真正的分类任务中,除了正标签重要之外,负标签也很重要,他也会带来大量的信息。但是传统的分类训练会直接把所有的负标签同等对待。换句话说,知识蒸馏这种训练方式给每个样本都带来了比传统训练方式更多的信息量。
举个例子: 我们有一个手写字体识别分类任务,0~9的数字中,7和1写起来很像,但是7和5就很不像,GroundTruth只告诉了我们,这个图片是7,但是logit还告诉了我们:这个图片大概率是7,小概率是1,几乎不太像其他数字。这其中携带了的信息量,也就是我们希望Student-Net学到的知识

3.2 知识蒸馏的具体过程

知识蒸馏的训练过程如下:

image.png   从上可以看出,在训练过程中,Student-Net不仅要学习Teacher-Net的知识,还要学习样本标签的知识。(注意:Teacher_Net是已经在样本集上训练好了的模型)

3.2.1 损失函数

具体表示如下:
1、总的损失函数如下:
   Loss=αLsoft+βLhard(1)Loss = \alpha L_{soft} +\beta L_{hard}\qquad\qquad\qquad\qquad\qquad(1)

其中α和β都是超参数,需要自行设定。

2、Lsoft表示Student_Net与Teacher_Net的损失函数。其定义表达式如下:

Lsoft=jNpjTlogqjT(2)\qquad \qquad L_{soft} = -\sum\limits_{j}^Np_j^T\log{q_j^T}\qquad\qquad\qquad\qquad\qquad\quad(2)
其中: pjT=evj/TkNevk/T,qjT=euj/TkNeuk/Tp_j^T = \dfrac{\mathrm{e}^{v_j/T}}{\sum\limits_k^N \mathrm{e}^{v_{k/T}}},\qquad\quad q_j^T = \dfrac{\mathrm{e}^{u_j/T}}{\sum\limits_k^N \mathrm{e}^{u_{k/T}}}
  vj:表示Teacher_Net在第j个类别的logit
  uj:表示Student_Net在第j个类别的logit
 T:代表的是温度常数
注:注意区分v和u, 不要看混淆了。另外logit是模型在经过softmax计算前的输出,不是softmax计算后的输出

3、Lhard表示Student_Net与真实样本标签的损失函数。

其表示如下:
Lsoft=jNcjlogqj1(3)\qquad \qquad L_{soft} = -\sum\limits_{j}^Nc_j\log{q_j^1}\qquad\qquad\qquad\qquad\qquad\quad(3)
注意:q的上角标是1, 表示温度常数设为了1
其中C表示真实标签,而 qj具体如下
qj1=eujkNeukq_j^1 = \dfrac{\mathrm{e}^{u_j}}{\sum\limits_k^N \mathrm{e}^{u_k}}
然后根据Loss进行训练即可, 以上便是知识蒸馏的全部内容,后续会再讨论关于温度常数的选取,以及不同值的温度常数所带来的影响。

4、参考如下

# Distilling the Knowledge in a Neural Network
# 知识蒸馏是什么?一份入门随笔
# 经典简读知识蒸馏(Knowledge Distillation) 经典之作
# 一文搞懂知识蒸馏


注意:因为笔者的能力有限,如有理解错误或考虑不周的地方,欢迎指出。