元学习Meta learning深入理解

1,671 阅读11分钟

目录

基本理解

元学习与传统的机器学习不同在哪里?

基本思想

MAML

MAML与pre-training有什么区别呢?

1. 损失函数不同

 2. 优化思想不同

MAML的优点及特点

MAML工作机理

 MAML应用:Toy Example

Reptile


基本理解

Meta Learning,翻译为元学习,也可以认为是learn to learn。

元学习与传统的机器学习不同在哪里?

知乎博主“南有乔木”在理解 元学习与传统的机器学习 这里举了个通俗易懂的例子,拿来给大家分享:

把训练算法类比成学生在学校的学习,传统的机器学习任务对应的是在每个科目上分别训练一个模型,而元学习是提高学生整体的学习能力,学会学习。

学校中 ,有的学生各科成绩都好,有的学生却存在偏科现象。

  • 各科成绩都好,说明学生“元学习”能力强,学会了如何学习,可以迅速适应不同科目的学习任务。
  • 偏科学生“元学习”能力相对较弱,只能某一科学习成绩好,换门科就不行了。不会举一反三,触类旁通。

现在经常使用的深度神经网络都是“偏科生”,分类和回归对应的网络模型完全不同,即使同样是分类任务,把人脸识别的网络架构用在分类ImageNet数据上,也未必能达到很高的准确率。

 

还有一个不同点:

  • 传统的深度学习方法都是从头开始学习(训练),即learning from scratch,对算力和时间都是更大的消耗和考验。
  • 元学习强调从不同的若干小任务小样本来学习一个对未知样本未知类别都有好的判别和泛化能力的模型

基本思想

写在前面:图片均来自李宏毅老师教学视频

图1

图1

对图1的解释:

Meta learning又称为learn to learn,是说让机器“学会学习”,拥有学习的能力。

元学习的训练样本和测试样本都是基于任务的。通过不同类型的任务训练模型,更新模型参数,掌握学习技巧,然后举一反三,更好地学习其他的任务。比如任务1是语音识别,任务2是 图像识别,···,任务100是文本分类,任务101与 前面100个任务类型均不同,训练任务即为这100个不相同的任务,测试任务为第101个任务。

图 2

图2

对图2的解释:

在机器学习中,训练样本中的训练集称为train set,测试集称为test set。元学习广泛应用于小样本学习中,在元学习中,训练样本中的训练集称为support set,训练样本中的测试集叫做query set。

注意 : 在机器学习中,只有一个大样本数据集,将这个一个大数据集分成了两部分,称为train set和test set;

但是在元学习中,不止一个数据集,有多少个不同的任务,就有多少个数据集,然后每个数据集又分成两部分,分别称为support set和query set。

这里没有考虑验证集。

图3

 对图3的解释:

 

图3为传统深度学习的操作方式,即:

  1. 定义一个网络架构;
  2. 初始化参数
  3. 通过自己选择的优化器更新参数;
  4. 通过两次epoch进行参数的更新;
  5. 得到网络最终的输出。

元学习与传统深度学习的联系在哪里?

图3中红色方框中的东西都是人为设计定义的,即我们常说的“超参数”,而元学习的目标就是去自动学习或者说代替方框中的东西,不同的代替方式就发明出不同的元学习算法。

图4

对图4的解释:

图4简单介绍了元学习的原理。

在神经网络算法,都需定义一个损失函数来评价模型好坏,元学习的损失通过N个任务的测试损失相加得到。定义在第n个任务上的测试损失是 ,则对于N个任务来说,总的损失为 ,这就是元学习的优化目标。

假设有两个任务,Task1和Task2,通过训练任务1,得到任务1的损失函数l1,通过训练任务2,得到任务2的损失函数l2,然后将这两个任务的损失函数相加,得到整个训练任务的损失函数,即图4右上角的公式。

 

如果前文对元学习了解还不够,后面有更详细的解释:

Meta Learning 的算法有很多,有些高大上的算法可以针对不同的训练任务,输出不同的神经网络结构和超参数,例如 Neural Architecture Search (NAS) 和 AutoML。这些算法大多都相当复杂,我们普通人难以实现。另外一种比较容易实现的 Meta Learning 算法,就是本文要介绍的 MAML 和 Reptile,它们不改变深度神经网络的结构,只改变网络的初始化参数。

 

MAML

理解MAML算法的损失函数含义和推导过程,首先得与pre-training区分开来。

对图5的解释:

我们定义初始化参数为,其初始化参数为,定义在第n个测试任务上训练之后的模型参数为,于是MAML总的损失函数为 。

图5

MAML与pre-training有什么区别呢?

1. 损失函数不同

MAML的损失函数为 。

pre-training的损失函数是

直观上理解是MAML所评测的损失是在任务训练之后的测试loss,而pre-training是直接在原有基础上求损失没有经过训练。如图6所示。

图6​​​​​

 2. 优化思想不同

这里先分享一下我看到的对损失函数最恰当的描述:(zhuanlan.zhihu.com/p/72920138

损失函数的奥妙:初始化参数掌控全场,分任务参数各自为营

图7

 

图8

 如图7和图8所示:

上图中横坐标代表网络参数,纵坐标代表损失函数。浅绿和墨绿两条曲线代表两个 task 的损失函数随参数变化曲线。

假设模型参数的向量都是一维的,

model pre-training的初衷是寻找一个从一开始就让所有任务的损失之和处于最小状态,它并不保证所有任务都能训练到最好的,如上图所示,  即收敛到局部最优。从图7中看就是,loss值按照计算公式达到了最小值,但此时task2(浅绿)线只能收敛到左边的绿点处,即局部最小处,而从整体看来,全局最小处在的右边出现。

而MAML的初衷是找到一个不偏不倚的,使得不管是在任务1的loss曲线还是任务2的loss曲线上,都能下降到分别的全局最优。从图8中看就是,loss值按照计算公式到达了最小值​​,此时,task1可以收敛到左边绿点处,task2可以收敛到右边绿点处,二者均为全局最小值。

李宏毅老师在这里举了个很生动的比喻:他把MAML比作选择读博,即更在意的是学生的以后的发展潜力;而model pre-training就相当于选择毕业直接去大厂工作,马上就把所学技能兑现金钱,在意的是当下表现如何。如图9所示。

图9

MAML的优点及特点

如图10所示:MAML

  1. 计算速度快
  2. 所有的更新参数步骤都被限制在了一次,即one-step
  3. 在用这个算法时,即测试新任务的表现时可以更新更多次
  4. 适用于数据有限的情况

图10

MAML工作机理

 在介绍MAML的论文中,给出的算法如图11所示:

图11

 下面给出每步的详细解释:参考(zhuanlan.zhihu.com/p/57864886

  • Require1:task的分布,即随机抽取若干个task组成任务池
  • Require2:step size是学习率,MAML基于二重梯度,每次迭代包括两次参数更新的过程,所以有两个学习率可以调整。
  1. 随机初始化模型的参数
  2. 循环,可以理解为一轮迭代过程或一个epoch
  3. 随机对若干个task采样,形成一个batch。
  4. 对batch中的每一个task进行循环
  5. 对利用batch中的某一个task中的support set,计算每个参数的梯度。在N-way K-shot的设置下,这里的support set应该有NK个。(N-way K-shot意思是有N种不同的任务,每个任务有K个不同的样本)。
  6. 第一次梯度的更新。
  7. 结束第一次梯度更新
  8. 第二次梯度更新。这里用的样本是query set。步骤8结束后,模型结束在该batch中的训练,开始回到步骤3,继续采样下一个batch。

有一个对MAML过程更直观的图:

图12

对图12的解释为:

 MAML应用:Toy Example

该 toy example 的目标是拟合正弦曲线:  ,其中 a、b 都是随机数,每一组 a、b 对应一条正弦曲线,从该正弦曲线采样 K 个点,用它们的横纵坐标作为一组 task,横坐标为神经网络的输入,纵坐标为神经网络的输出。

我们希望通过在很多 task 上的学习,学到一组神经网络的初始化参数,再输入测试 task 的 K 个点时,经过快速学习,神经网络能够拟合测试 task 对应的正弦曲线。

图13

左侧是用常规的 fine-tune 算法初始化神经网络参数。我们观察发现,当把所有训练 task 的损失函数之和作为总损失函数,来直接更新网络参数,会导致无论测试 task 输入什么坐标,预测的曲线始终是 0 附近的曲线,因为 a 和 b 可以任意设置,所以所有可能的正弦函数加起来,它们的期望值为 0,因此为了获得所有训练 task 损失函数之和的 global minima,不论什么输入坐标,神经网络都将输出 0。

右侧是通过 MAML 训练的网络,MAML的初始化结果是绿色的线,和橘黄色的线有差异。但是随着finetuning的进行,结果与橘黄色的线更加接近。

 

针对前面介绍的MAML,提出一个问题:

在更新训练任务的网络时,只走了一步,然后更新meta网络。为什么是一步,可以是多步吗?

李宏毅老师的课程中提到:

  • 只更新一次,速度比较快;因为meta learning中,子任务有很多,都更新很多次,训练时间比较久。
  • MAML希望得到的初始化参数在新的任务中finetuning的时候效果好。如果只更新一次,就可以在新任务上获取很好的表现。把这件事情当成目标,可以使得meta网络参数训练是很好(目标与需求一致)。
  • 当初始化参数应用到具体的任务中时,也可以finetuning很多次。
  • Few-shot learning往往数据较少。

Reptile

Reptile与MAML类似,其算法图如下:

图14

Reptile 中,每更新一次  ,需要 sample 一个 batch 的 task(图中 batchsize=1),并在各个 task 上施加多次梯度下降,得到各个 task 对应的  。然后计算  和主任务的参数的差向量,作为更新  的方向。这样反复迭代,最终得到全局的初始化参数。

 其伪代码如下:

Reptile,每次sample出1个训练任务

 

Reptile,每次sample出1个batch训练任务

 

在Reptile中:

  • 训练任务的网络可以更新多次
  • reptile不再像MAML一样计算梯度(因此带来了工程性能的提升),而是直接用一个参数  乘以meta网络与训练任务的网络参数的差来更新meta网络参数
  • 从效果上来看,Reptile效果与MAML基本持平

 

以上为对元学习的深入理解,后续可能出MAML数学公式推导,感兴趣的读者留言~


参考资料

【1】zhuanlan.zhihu.com/p/72920138

【2】zhuanlan.zhihu.com/p/57864886

【3】zhuanlan.zhihu.com/p/108503451

【4】MAML论文arxiv.org/pdf/1703.03…

【5】 zhuanlan.zhihu.com/p/136975128