【知识蒸馏】利用神经网络蒸馏知识

1,317 阅读3分钟

论文链接

研究背景

一个通用的能提高几乎所有机器学习算法性能的方法就是在相同的数据上训练许多不同的模型,然后取这些模型预测结果的均值。但是,使用这样的预测模型是很麻烦的,特别是当单个模型是大型神经网络时,计算成本过于高昂导致难以部署到大量的用户机器上。Caruana和他的合作者证明,可以将知识压缩到一个单一模型中让模型更容易部署,这篇文章进一步使用了一种不同的压缩技术来延伸这种方法,称之为“蒸馏”。

蒸馏

对于分类问题,神经网络模型的最后会通过“softmax”输出层转换logit(用ziz_i表示)来产生类别概率qiq_i

image.png

此处的TT为温度参数(温度参数介绍见链接)一般设置为1。若增大TT值,可以得到更“软”(平缓)的类别概率分布。下图为不同的TT值得到的概率曲线:

image.png

使用温度参数的原因 当softmax函数输出的概率分布熵较小时,负标签的值都接近于0,通过调节TT值的大小,TT值增大概率分布变得更平缓,也就放大了分布熵,让模型训练更关注负标签。

在最简单的蒸馏形式中,通过在迁移集上训练蒸馏模型并在迁移集中的没一个案例使用软目标分布来迁移知识,迁移集在其softmax中是使用带高温度参数的笨重模型产生的。训练蒸馏模型时使用相同的高温度参数,但经过训练后其使用的温度参数调整为1。

在知识蒸馏中会先训练一个泛化能力较强的复杂模型再通过复杂模型蒸馏训练简单模型,即将复杂模型的知识“迁移”到简单模型中(这里对应了迁移学习的概念),一个简单的迁移方法就是将复杂模型softmax输出层的输出的类别概率qiq_i作为“soft target”。

Soft target的概念 和soft target对应的是hard target,hard target表示一般神经网络训练的目标,其训练目标在于分类准确率尽可能的高(one-hot编码)。soft target对应的是带有TT的目标,其训练目标在于要尽量接近于复杂网络加入TT之后的概率分布。soft target比hard target包含的知识更丰富,hard target只包含类概率,soft target还包含负标签对应的概率(即复杂模型输出的概率向量),因此使用soft target可以获得更好的简化模型。

知识蒸馏的一般方法

知识蒸馏分为两个步骤:

  1. T=1T=1时训练复杂网络模型,也称为教师模型,用Net-T表示。
  2. 在高TT值下将Net-T的知识蒸馏到简单模型(学生模型)Net-S中。

蒸馏的目标函数如下所示,为两个目标函数的加权平均,其中α,β\alpha, \beta为两个权重值。

L=αLsoft+βLhardL=\alpha L_{soft}+\beta L_{hard}

Net-T和Net-S同时输入数据集,用Net-T生成的带高TT值的softmax层输出的概率分布作为soft target,在相同的TT值下,Net-S的softmax层输出的概率分布和soft target的交叉熵共同组成LsoftL_{soft},如下所示。

Lsoft=jNpjTlog(qjT)L_{soft}=-\sum^N_jp^T_jlog(q^T_j)

其中,pjTqjTp^T_j,q^T_j分别为Net-T和Net-S在温度参数TT时softmax层输出的第i类的概率值,NN为总标签数。

LhardL_{hard}则是Net-S在T=1T=1时的softmax层输出和ground truth(可以理解为训练集样本的标签\标准答案)的交叉熵:

Lhard=jNcjlog(qj1),qi1=exp(zi)kNexp(zk)L_{hard}=-\sum^N_jc_jlog(q^1_j), q^1_i=-\frac{exp(z_i)}{\sum^N_kexp(z_k)}

其中,cjc_j为第i类的ground truth值,ci{0,1}c_i\in\{0,1\},正标签为1,负标签为0,ziz_i为Net-S的logits。