《LEARNING TO PROPAGATE LABELS· TRANSDUCTIVE....》

190 阅读5分钟

title: 《LEARNING TO PROPAGATE LABELS· TRANSDUCTIVE PROPAGATION NETWORK FOR FEW-SHOT LEARNING》论文阅读

tags: 小样本学习

summary: 文章提出了用于标签传播的转导传播网络,有四个部分:特征提取,图模型构建,标签传播,损失计算,采用端到端的方式来训练这个网络

问题定义

本文仍然采用episode训练策略。给定一个包含较多的标注数据的训练集,我们来训练一个分类器,它能够较好的识别新类中的样本,并且新类中的带标注样本是很少的。

同样的,从训练集中采样N个类,每个类采样K个样本,构成了一个N-way K-shot的场景。这部分数据作为support set(S\mathcal{S})。再从这N个类中采样,构成query set(Q\mathcal{Q})。K通常是很小的,也就是标注数据非常少。这导致了很难获得一个好的分类器。作者的想法就是,利用整个query set来进行预测,而不仅仅是每个样本单独的使用。这可以缓解数据少的问题,以及增强泛化性能。

转导传播网络

转导传播网络包括四个部分:特征提取,图模型构建,标签传播,损失计算 特征提取部分是由一个卷积神经网络来实现的。标签传播就是将support set的标签传播到query set里面。损失计算使用的是交叉熵损失,计算的是query set里面实际标签和传播的标签之间的差别。

特征提取

利用一个卷积神经网络fφf_{\varphi}来进行特征提取,fφf_{\varphi}包括四个卷积块。每个卷积块包括2维的卷积层,使用的是3×3的kernel,共64个;然后跟着Batch-normalization层,relu激活函数和2×2的max-pooling层。φ\varphi这个参数根据最终的预测损失进行调整。

图模型构建

图的结点代表样本,边表示有联系,判断两个结点之间是否有边,使用的高斯相似度函数(公式1)。通过该函数计算出结点之间的相似度,然后选择k个最相近的样本建立边。构建的图称为k-最近邻图。

Wij=exp(d(xi,xj)2σ2)(1)W_{ij}=exp(-\frac{d(x_i, x_j)}{2\sigma ^2})\tag{1}

公式1中的d(xi,xj)d(x_i, x_j)是一个距离衡量函数,用于衡量这两个样本之间的距离,关于这个函数的细节还得参考原论文。

σ\sigma是一个超参数,作者提到并没有很好的调整这个参数的规则,因此作者构建了一个简单的卷积神经网络gϕ()g_{\phi}(·)来计算这个参数。实际上,这就相当于不进行人为的调整σ\sigma参数,而是根据最终预测结果的损失,让gϕg_{\phi}调整自身的参数ϕ\phi来使得其计算出的σ\sigma能够优化最终的预测结果。对每个样本都会计算自己的σ\sigma参数。

卷积网络gϕg_{\phi} 由两个卷积块和两个全连接层组成。第一个卷积块包括64个3×3的filter + batch normalization + ReLU activation + 2×2 max pooling。第二个卷积块包括1个3×3的filter + batch normalization + ReLU activation + 2×2 max pooling。第一个全连接层包括8个神经元,第二个全连接层包括1个神经元。

标准化处理 有了σ\sigma就可以计算边权矩阵WW了,作者还用标准化的图拉普拉斯算子处理了这个矩阵,如公式2:

S=D1/2WD1/2(2)S=D^{-1/2}WD^{-1/2}\tag{2}

标签传播

F\mathcal{F}表示100*100的非负矩阵的集合。YFY\in \mathcal{F}表示标签矩阵,Yij=1Y_{ij}=1表示样本xix_i来自support set并且属于第j个类。其他情况记为Yij=0Y_{ij}=0。从YY开始迭代公式3,就可以根据构建的图结构进行标签传播了。其中FtF_{t}表示第t个时间戳下的标签预测得分。公式4是公式3收敛时的情况,可以利用公式4直接进行计算预测的标签得分。

Ft+1=αSFt+(1α)Y(3)F_{t+1}=\alpha SF_t+(1-\alpha)Y\tag{3}
F=(IαS)1Y(4)F^{\star}=(I-\alpha S)^{-1}Y\tag{4}

损失计算

输入的样本来自support set和query set两个部分。预测结果在FF^{\star}的基础上使用了softmax激活函数。损失函数使用的是交叉熵损失函数,计算的是通过标签传播计算得到的标签和真实标签之间的差别。损失函数包括两个参数φ\varphiϕ\phi。其中φ\varphi表示用于特征提取的卷积神经网络的参数,ϕ\phi是用于计算高斯相似度函数中的参数σ\sigma的卷积神经网络的参数。

总览

  1. 将support set和query set中的样本合并,记为XX,然后输入到卷积网络fφf_{\varphi},提取到特征fφ(X)f_{\varphi}(X)
  2. 将提取到的特征fφ(X)f_{\varphi}(X)输入到卷积网络gϕg_{\phi}得到每个样本的σ\sigma
  3. 根据这个σ\sigma计算出相似度矩阵WW,再用规范化的图拉普拉斯算子处理一下这个矩阵得到SS。选择最相似的k个结点建立边。
  4. 利用构建好的图进行标签传播,这样得到了query set上的预测的标签,然后计算损失。有了损失就可以计算梯度,然后反向传播,进行参数的更新。

附录

流形结构:空间中的点的集合,是一个集合概念。比如:二维空间的一个曲线就是二维空间的一维流形。

转导推理:通过观察特定的训练样本,进而预测特定的训练样本。可以认为它是从特殊到特殊

归纳推理:先从训练样本中学习得到通过的规则,再利用规则判断测试样本。归纳目的在于找出一个一般性的规律,因此这里肯定涉及到了泛化。

端到端的学习:从输入端(输入数据)到输出端会得到一个预测结果,与真实结果相比较会得到一个误差,这个误差会在模型中的每一层传递(反向传播),每一层的表示都会根据这个误差来做调整,直到模型收敛或达到预期的效果才结束,这是端到端的。在本文中,我认为体现最明显的地方就是ϕ\phiφ\varphi这两个参数的学习。