title: 《LEARNING TO PROPAGATE LABELS· TRANSDUCTIVE PROPAGATION NETWORK FOR FEW-SHOT LEARNING》论文阅读
tags: 小样本学习
summary: 文章提出了用于标签传播的转导传播网络,有四个部分:特征提取,图模型构建,标签传播,损失计算,采用端到端的方式来训练这个网络
问题定义
本文仍然采用episode训练策略。给定一个包含较多的标注数据的训练集,我们来训练一个分类器,它能够较好的识别新类中的样本,并且新类中的带标注样本是很少的。
同样的,从训练集中采样N个类,每个类采样K个样本,构成了一个N-way K-shot的场景。这部分数据作为support set()。再从这N个类中采样,构成query set()。K通常是很小的,也就是标注数据非常少。这导致了很难获得一个好的分类器。作者的想法就是,利用整个query set来进行预测,而不仅仅是每个样本单独的使用。这可以缓解数据少的问题,以及增强泛化性能。
转导传播网络
转导传播网络包括四个部分:特征提取,图模型构建,标签传播,损失计算 特征提取部分是由一个卷积神经网络来实现的。标签传播就是将support set的标签传播到query set里面。损失计算使用的是交叉熵损失,计算的是query set里面实际标签和传播的标签之间的差别。
特征提取
利用一个卷积神经网络来进行特征提取,包括四个卷积块。每个卷积块包括2维的卷积层,使用的是3×3的kernel,共64个;然后跟着Batch-normalization层,relu激活函数和2×2的max-pooling层。这个参数根据最终的预测损失进行调整。
图模型构建
图的结点代表样本,边表示有联系,判断两个结点之间是否有边,使用的高斯相似度函数(公式1)。通过该函数计算出结点之间的相似度,然后选择k个最相近的样本建立边。构建的图称为k-最近邻图。
公式1中的是一个距离衡量函数,用于衡量这两个样本之间的距离,关于这个函数的细节还得参考原论文。
是一个超参数,作者提到并没有很好的调整这个参数的规则,因此作者构建了一个简单的卷积神经网络来计算这个参数。实际上,这就相当于不进行人为的调整参数,而是根据最终预测结果的损失,让调整自身的参数来使得其计算出的能够优化最终的预测结果。对每个样本都会计算自己的参数。
卷积网络 由两个卷积块和两个全连接层组成。第一个卷积块包括64个3×3的filter + batch normalization + ReLU activation + 2×2 max pooling。第二个卷积块包括1个3×3的filter + batch normalization + ReLU activation + 2×2 max pooling。第一个全连接层包括8个神经元,第二个全连接层包括1个神经元。
标准化处理 有了就可以计算边权矩阵了,作者还用标准化的图拉普拉斯算子处理了这个矩阵,如公式2:
标签传播
用表示100*100的非负矩阵的集合。表示标签矩阵,表示样本来自support set并且属于第j个类。其他情况记为。从开始迭代公式3,就可以根据构建的图结构进行标签传播了。其中表示第t个时间戳下的标签预测得分。公式4是公式3收敛时的情况,可以利用公式4直接进行计算预测的标签得分。
损失计算
输入的样本来自support set和query set两个部分。预测结果在的基础上使用了softmax激活函数。损失函数使用的是交叉熵损失函数,计算的是通过标签传播计算得到的标签和真实标签之间的差别。损失函数包括两个参数和。其中表示用于特征提取的卷积神经网络的参数,是用于计算高斯相似度函数中的参数的卷积神经网络的参数。
总览
- 将support set和query set中的样本合并,记为,然后输入到卷积网络,提取到特征。
- 将提取到的特征输入到卷积网络得到每个样本的。
- 根据这个计算出相似度矩阵,再用规范化的图拉普拉斯算子处理一下这个矩阵得到。选择最相似的k个结点建立边。
- 利用构建好的图进行标签传播,这样得到了query set上的预测的标签,然后计算损失。有了损失就可以计算梯度,然后反向传播,进行参数的更新。
附录
流形结构:空间中的点的集合,是一个集合概念。比如:二维空间的一个曲线就是二维空间的一维流形。
转导推理:通过观察特定的训练样本,进而预测特定的训练样本。可以认为它是从特殊到特殊
归纳推理:先从训练样本中学习得到通过的规则,再利用规则判断测试样本。归纳目的在于找出一个一般性的规律,因此这里肯定涉及到了泛化。
端到端的学习:从输入端(输入数据)到输出端会得到一个预测结果,与真实结果相比较会得到一个误差,这个误差会在模型中的每一层传递(反向传播),每一层的表示都会根据这个误差来做调整,直到模型收敛或达到预期的效果才结束,这是端到端的。在本文中,我认为体现最明显的地方就是和这两个参数的学习。