categories: 论文阅读笔记
tags: 小样本学习
summary: 这篇文章提出了一种半监督小样本学习的迁移学习模式,该方法能够重充分利用有标签的基类和无标签的新类上的信
摘要
这篇文章提出了一种半监督小样本学习的迁移学习模式,它能够充分利用基类和新类上的信息。包括三个部分:1.在基类上预训练的特征提取器;2.使用特征提取器来初始化新类的分类器的权重;3.采用半监督的学习方法来进一步提高这个分类器。作者提出了一种新的方法叫做MixMatch,即利用imprint和MixMatch来实现了这三个部分。
引言
作者首先总结了小样本学习的两大门派:元学习方法,迁移学习方法。
元学习的方法 采用episode训练策略。episode是一种类似于batch的机制,它是从数据集中采样出来的一部分数据,其中只包含极少的基类中的数据,这样模拟了测试时只有极少的标注数据的情形。episode中的标注数据被分成两个部分,即support set和query set。support set用于构建模型,query set用于评估模型的性能。
迁移学习方法 这篇文章的灵感来自于迁移学习的方法,作者企图利用基类和新类的无标签数据来预训练一个模型,然后利用这个模型来学习一个新类的分类器。
主要贡献
1、提出了一种半监督小样本学习的迁移学习模式,它能够充分利用类和无标签新类数据的信息
2、开发了一个叫做TransMatch的方法,它综合了基于迁移学习的小样本学习方法的优势和半监督学习方法的优势
3、在流行的小样本学习数据集上进行了广泛的实验,并且展示了该方法确实能够充分利用无标签数觉得信息
相关工作
1.小样本学习
小样本学习的相关工作可以分为两大类,一类是基于元学习方法的,一类是基于迁移学习方法的。
基于元学习的方法: 基于元学习的小样本学习又称为学会学习,它的目的在于学习一种范式,能够适用于识别只有少量样本的新类的任务场景。元学习有两个阶段组成,元训练阶段和元测试阶段。跟通常的训练阶段和测试阶段基本类似,只不过训练阶段采用episode的策略,测试阶段每个类只有极少的样本。元学习方法又可以分为两类:1.基于指标的方法;2.基于优化的方法。
基于指标的方法 的目的在于学习一个好的指标,用该指标来衡量support set和query set之间的距离,或者说来衡量两者之间的相似性。
基于优化的方法 的目的在于设计一个优化算法,使得训练阶段信息能够适用于测试阶段。(我认为这就是我们常见的那些训练模型的方法)
基于迁移学习的方法: 迁移学习不使用episode训练策略,它现在基类的大量的标注数据上预训练模型,然后让这个预训练模型适应于小样本的新类的相关任务。
2.半监督学习
半监督学习能够学习有标签的数据和无标签的数据。它主要分为两大类:一类是自洽正则化方法,一类是熵最小化方法。
自洽正则化方法 主要就是通过加噪声或者通过数据增强来进行正则化。
熵最小化方法 的目的在于减小无标签数据的熵。
本文中使用的MixMatch方法综合了不同类型的自洽正则化方法和熵最小化方法,可谓性能超强。
3.半监督的小样本学习
当新类中的样本数量很少的时候,非常容易想到的就是利用无标注的数据来提高模型的性能。这样的想法就导致了半监督的小样本学习方法,这方面的工作已经有很多了,但是大部分都是基于元学习的。元学习的episode训练策略直接和半监督的学习方法直接进行整合不太合适,同时,迁移学习方法可以达到和元学习方法同样的性能,这是作者的灵感来源。基于元学习的半监督小样本学习还有以下几点缺点:1.当前的性能还不是最优的;2.更强大的方法,像MixMatch无法被整合进去;3.在测试的时候直接使用半监督的学习方法可能会导致更坏的性能。
问题定义
数据集 :基类数据集,每个类包括许多带标注样本。其中包含的类称为。:新类数据集,每个类包括少量的带标注样本,但是数据集中含有大量的无标注样本。其中包含的类称为。新类中的类和基类中的类是不相交的。
作者的目标 是要学习一个鲁棒性的分类器,主要利用新类中少量的带标注的样本和大量的无标注样本。使用基类作为附属数据集。
方法
作者提出的方法是:首先利用基类的数据来进行模型的预训练。然后,将这个预训练模型作为一个特征提取器,来提取新类中少量的带标签样本的特征。然后将这些特征直接作为新类分类器的初始权重,在这个基础上来做进一步的微调。
预训练的特征提取器 利用基类中的数据来训练这个特征提取器。这跟迁移学习的预训练的目的是一样的,尽可能的提取基类中的知识,然后迁移到到新类的学习上。
Imprint权重 从新类中采样N个类,每个类采样K个带标注样本,这就形成了N-way K-shot问题。这部分回答两个问题:1.如何进行Imprint权重?2.分类器实际上在做什么? Imprint权重的核心公式见公式1
下标c表示第c个类,表示上一阶段得到的特征提取器。表示第c个类的第k个样本。很显然,这就是提取N-way K-shot样本的特征的平均值,将这个平均值作为权重。
分类器实际上在计算一个相似度。见公式2
从公式2和公式3可以看出,新类的分类器实际上是在计算样本x的特征和k-shot的平均特征之间的余弦相似度。取相似度最大的类最为预测的类。但是,这里仅仅是设置了分类器权重的一个初值,在下一阶段还要进行微调。
微调阶段 作者使用MixMatch的方法来微调分类器。一方面是由于MixMatch在半监督学习任务上具有超强的性能,另一方面是因为MixMatch能够很好的利用无标注数据。一个批量的带标注数据记为,一个批量的无标注数据记为。
无标注数据的标签可以通过第二部分的Imprint的分类器来进行估计。首先对无标注数据的每个样本进行数据增强,产生M个增强版本,这样就得到数据集,将这M个版本的样本分别输入到相同的分类器中,将会产生M个不同的预测,取这M个预测值的平均值,见公式4。然后做一个sharpen操作(T=0.5),来最小化未标注数据的熵,sharpen之后的结果将作为最终的估计值,见公式5。
优化目标包括两个部分,一部分是交叉熵损失,一部分是自洽正则化的损失,见公式6。
公式中表示新类分类器,它是用于对无标注数据做出预测的。MixMatch方法采用了Mixup数据增强的方法,即构造混合样本和混合标签。首先将和进行合并(这里的合并应该是在axis=0方向的合并),然后做一个shuffle操作,见公式7,将得到的结果称为,然后将这个划分为两个部分,见公式8。这样的得到了两个增强的数据集和。其中是将数据集和的前个样本混合得到的。是将和的剩余的个样本混合得到的。因此公式6的标签应该是混合标签。但是公式6中的第二部分为什么有个N。
总览
1、采用基类数据集预训练一个特征提取器,这个特征提取器用于提取新类样本的特征,取新类样本特征的平均值来imprint新类分类器的权重。 2、将带标注数据的样本和无标注数据的样本进行合并,做一个shuffle,构成了一个新的集合。无标注样本的标签可以通过imprint过的分类器获得。 3、将带标注数据集和中前个样本进行MixUp操作得到数据集,是将和的剩余的个样本混合得到的。 4、利用Imprint的分类器在计算交叉熵损失,在计算自洽正则化损失。有了损失就可以计算梯度,然后进行反向传播更新模型参数。
附录
MixMatch是怎么样的过程 zhuanlan.zhihu.com/p/66281890