前言
这篇文章是基于蒸馏的方法,提高ViT的性能,根据最近transformer相关文章的实验对比以及引用可以发现,这篇文章基本上是ViT以后出现的比较早的文章了。
本文已参与「新人创作礼」活动,一起开启掘金创作之路。
出发点
在之前的博文Transformer主干网络——ViT保姆级解析中总结了一下ViT留下的坑,本文是从ViT的留下的第五个坑下手的:
ViT需要现在JFT-300M大型数据集上预训练,然后在ImageNet-1K上训练才能得到出色的结果,这借助了额外的数据。
个人认为这也是留下的一个比较难或者说比较没有解决方向的坑,本文作者主要是通过蒸馏的方法来提高ViT的性能让模型不需要大规模的数据pretrain也能达到甚至超过卷积神经网络的性能。
为什么可以蒸馏来提高呢?
因为ViT的作者实验了用大数据集pretrain然后finetune可以达到比较好的效果,所以说明ViT结构是完全有能力做到更好只是当前的方法或者训练技巧不能让ViT将知识学习充分。所以作者想到了用蒸馏(我猜的哈哈哈哈)
主要贡献
- 只是用imagenet训练不需要额外的数据,并且用单机四卡训练三天就完成了训练。其中DeiT-S和DeiT-Ti参数量可以对比resnet50和resnet18
- 作者提出了一个基于token的蒸馏策略。
完整结构
其实基本上就是ViT的结构,这里直接截图forword函数的代码,圈起来的部分是对比ViT新增的:
究竟怎么做的,先看一下作者总结的一些蒸馏的知识。
soft distillation
软蒸馏是用学生模型经过了softmax层输出的分数和教师模型经过了softmax层输出的分数计算KL散度来进行蒸馏的。
Tips:KL散度
离散型的KL散度其实挺简单的,他在做的事就是比较两个特征的相似程度,公式如下(不看也行,记住就是比较两个特征的相似程度就行了):
比如:
P = [0.2, 0.4, 0.4]
Q = [0.4, 0.2, 0.4]
KL = 0.2 * log(0.2/0.4) + 0.4 * log(0.4/0.2) + 0.4 * log(0.4/0.4)= 0.13862943611198905
对于软蒸馏,设为教师模型输出的分类分数,为学生模型输出的分类分数,是分类的ground truth,是比例缩放的系数,所以第一部分的分类损失就是:
分类损失=
第二部分的蒸馏损失就是:
蒸馏损失=
这一项的话如果和一模一样的话损失就是0,反之越不像越大。
然后用一个参数lambda来调节两个损失的占比就得到了完整的损失:
Hard-label distillation
硬蒸馏就更简单了,直接用teacher模型输出的分类结果监督学生模型:
是教师模型预测的类别。
本文提出的Distillation token
这其实就是本文的主要贡献了,这个token的实现也很简单,就是和cls token一样又拼接了一个token用来计算蒸馏损失,这里的蒸馏损失使用的是Hard-label distillation,也就是说:
图中的1对应Hard-label distillation中的
图中的2对应Hard-label distillation中的
所以cls token是用groundtruth来监督的而dist token使用教师模型预测的结果来监督的。
作者的实验结果发现:
- 最初cls token和dist token区别很大,余弦相似度为0.06
- 随着学生模型和教师模型互相传播和学习,网络逐渐变得相似,到最后一层,余弦相似度为0.93
实验结果
不同教师模型蒸馏的结果:
==作者发现使用卷积网络作为教师网络能够比使用Transformer架构取得更好的效果==
不同蒸馏方法的对比:
==实验证明:Hard-label distillation能够取得更好的结果。==
完整对比结果: