GIT:斯坦福大学提出应对复杂变换的不变性提升方法 | ICLR 2022

959 阅读10分钟

论文对长尾数据集中的复杂变换不变性进行了研究,发现不变性在很大程度上取决于类别的图片数量,实际上分类器并不能将从大类中学习到的不变性转移到小类中。为此,论文提出了GIT生成模型,从数据集中学习到类无关的复杂变换,从而在训练时对小类进行有效增强,整体效果不错

来源:晓飞的算法工程笔记 公众号

论文: Do Deep Networks Transfer Invariances Across Classes?

Introduction


  优秀的泛化能力需要模型具备忽略不相关细节的能力,比如分类器应该对图像的目标是猫还是狗进行响应,而不是背景或光照条件。换句话说,泛化能力需要包含对复杂但不影响预测结果的变换的不变性。在给定足够多的不同图片的情况下,比如训练数据集包含在大量不同背景下的猫和狗的图像,深度神经网络的确可以学习到不变性。但如果狗类的所有训练图片都是草地背景,那分类器很可能会误判房子背景中的狗为猫,这种情况往往就是不平衡数据集存在的问题。
类不平衡在实践中很常见,许多现实世界的数据集遵循长尾分布,除几个头部类有很多图片外,而其余的每个尾部类都有很少的图片。因此,即使长尾数据集中图片总量很大,分类器也可能难以学习尾部类的不变性。虽然常用的数据增强可以通过增加尾部类中的图片数量和多样性来解决这个问题,但这种策略并不能用于模仿复杂变换,如更换图片背景。需要注意的是,像照明变化之类的许多复杂变换是类别无关的,能够类似地应用于任何类别的图片。理想情况下,经过训练的模型应该能够自动将这些不变性转为类无关的不变性,兼容尾部类的预测。
论文通过实验观察分类器跨类迁移学习到的不变性的能力,从结果中发现即使经过过采样等平衡策略后,神经网络在不同类别之间传递学习到的不变性也很差。例如,在一个长尾数据集上,每个图片都是随机均匀旋转的,分类器往往对来自头部类的图片保持旋转不变,而对来自尾部类的图片则不保持旋转不变。
为此,论文提出了一种更有效地跨类传递不变性的简单方法。首先训练一个input conditioned但与类无关的生成模型,该模型用于捕获数据集的复杂变换,隐藏了类信息以便鼓励类之间的变换转移。然后使用这个生成模型来转换训练输入,类似于学习数据增强来训练分类器。论文通过实验证明,由于尾部类的不变性得到显著提升,整体分类器对复杂变换更具不变性,从而有更好的测试准确率。

Measuring Invariance Transfer In Class-Imbalanced Datasets


  论文先对不平衡场景中的不变性进行介绍,随后定义一个用于度量不变性的指标,最后再分析不变性与类别大小之间的关系。

Setup:Classification,Imbalance,and Invariances

  定义输入(x,y)(x,y),标签yy属于{1,,C}\{1,\cdots,C\}CC为类别数。定义训练后的模型的权值ww,用于预测条件概率P~w(y=jx)\tilde{P}_w(y=j|x),分类器将选择概率最大的类别jj作为输出。给定训练集{(x(i),y(i))}i=1NPtrain\{(x^{(i)}, y^{(i)})\}^N_{i=1}\sim \mathbb{P}_{train},通过经验风险最小化(ERM)来最小化训练样本的平均损失。但在不平衡场景下,由于{y(i)}\{y^{(i)}\}的分布不是均匀的,导致ERM在少数类别上表现不佳。
在现实场景中,最理想的是模型在所有类别上都表现得不错。为此,论文采用类别平衡的指标来评价分类器,相当于测试分布Ptest\mathbb{P}_{test}yy上是均匀的。
为了分析不变性,论文假设xx的复杂变换分布为T(x)T(\cdot|x)。对于不影响标签的复杂变换,论文希望分类器是不变的,即预测的概率不会改变:

Measuring Learned Invariacnes

  为了度量分类器学习不变性的程度,论文定义了原输入和变换输入之间的期望KL散度(eKLD):

  这是一个非负数,eKLD越低代表不变性程度就越高,对TT完全不变的分类器的eKLD为0。如果有办法采样xT(x)x^{'}\sim T(\cdot|x),就能计算训练后的分类器的eKLD。此外,为了研究不变性与类图片数量的关系,可以通过分别计算类特定的eKLD进行分析,即将公式2的xx限定为类别jj所属。
计算eKLD的难点在于复杂变化分布TT的获取。对于大多数现实世界的数据集而言,其复杂变化分布是不可知的。为此,论文通过选定复杂分布来生成数据集,如RotMNIST数据集。与数据增强不同,这种生成方式是通过变换对数据集进行扩充,而不是在训练过程对同一图片应用多个随机采样的变换。
论文以Kuzushiji-49作为基础,用三种不同的复杂变换生成了三个不同的数据集:图片旋转(K49-ROT-LT)、不同背景强度(K49-BG-LT)和图像膨胀或侵蚀(K49-DIL-LT)。为了使数据集具有长尾分布(LT),先从大到小随机选择类别,然后有选择地减少类别的图片数直到数量分布符合参数为2.0的Zipf定律,同时强制最少的类为5张图片。重复以上操作30次,构造30个不同的长尾数据集。每个长尾数据集有7864张图片,最多的类有4828张图片,最小的类有5张图片,而测试集则保持原先的不变。

  训练方面,采用标准ERM和CE+DRS两种方法,其中CE+DRS基于交叉熵损失进行延迟的类平衡重采样。DRS在开始阶段跟ERM一样随机采样,随后再切换为类平衡采样进行训练。论文为每个训练集进行两种分类器的训练,随后计算每个分类器每个类别的eKLD指标。结果如图1所示,可以看到两个现象:

  • 在不同变化数据集上,不变性随着类图片数减少都降低了。这表明虽然复杂变换是类无关的,但在不平衡数据集上,模型无法在类之间传递学习到的不变性。
  • 对于图片数量相同的类,使用CE+DRS训练的分类器往往会有较低的eKLD,即更好的不变性。但从曲线上看,DRS仍有较大的提升空间,还没达到类别之间一致的不变性。

Trasnferring Invariances with Generative Models


  从前面的分析可以看到,长尾数据集的尾部类对复杂变换的不变性较差。下面将介绍如何通过生成式不变性变换(GIT)来显式学习数据集中的复杂变换分布T(x)T(\cdot|x),进而在类间转移不变性。

Learning Nuisance Transformations from Data

  如果有数据集实际相关的复杂变换的方法,可以直接将其用作数据增强来加强所有类的不变性,但在实践中很少出现这种情况。于是论文提出GIT,通过训练input conditioned的生成模型T~(x)\tilde{T}(\cdot|x)来近似真实的复杂变换分布T(x)T(\cdot|x)

  论文参考了多模态图像转换模型MUNIT来构造生成模型,该类模型能够从数据中学习到多种复杂变换,然后对输入进行变换生成不同的输出。论文对MUNIT进行了少量修改,使其能够学习单数据集图片之间的变换,而不是两个不同域数据集之间的变换。从图2的生成结果来看,生成模型能够很好地捕捉数据集中的复杂变换,即使是尾部类也有不错的效果。需要注意的是,MUNIT是非必须的,也可以尝试其它可能更好的方法。
在训练好生成模型后,使用GIT作为真实复杂变换的代理来为分类器进行数据增强,希望能够提高尾部类对复杂变换的不变性。给定训练输入{(x(i),y(i))}i=1B\{(x^{(i)}, y^{(i)})\}^{|B|}_{i=1},变换输入x~(i)T~(x(i))\tilde{x}^{(i)}\gets \tilde{T}(\cdot|x^{(i)}),保持标签不变。这样的变换能够提高分类器在训练期间的输入多样性,特别是对于尾部类。需要注意的是,batch可以搭配任意的采样方法(Batch Sampler),比如类平衡采样器。此外,还可以有选择地进行增强,避免由于生成模型的缺陷损害性能的可能性,比如对数量足够且不变性已经很好的头部类不进行增强。

  在训练中,论文设置阈值KK,仅图片数量少于KK的类进行数据增强。此外,仅对每个batch的pp比例进行增强。pp一般取0.5,而KK根据数据集可以设为20-500,整体逻辑如算法1所示。

GIT Improves Invariance on Smaller Classes

  论文基于算法1进行了实验,将Batch Sampler设为延迟重采样(DRS),Update Classifier使用交叉熵梯度更新,整体模型标记为CE+DRS+GIT(allclasses)CE+DRS+GIT(all classes)。all classes表示禁用阈值KK,仅对K49数据集使用。作为对比,Oracle则是用于构造生成数据集的真实变换。从图3的对比结果可以看到,GIT能够有效地增强尾部类的不变性,但同时也损害了图片充裕的头部类的不变性,这表明了阈值KK的必要性。

Experiment


  不同训练策略搭配GIT的效果对比。

  在GTSRB和CIFAR数据集上的变换输出。

  CIFAR-10上每个类的准确率。

  对比实验,包括阈值KK对性能的影响,GTSRB-LT, CIFAR-10 LT和CIFAR-100 LT分别取25、500和100。这里的最好性能貌似都比RandAugment差点,有可能是因为论文还没对实验进行调参,而是直接复用了RandAugment的实验参数。这里比较好奇的是,如果在训练生成模型的时候加上RandAugment,说不定性能会更好。

Conclusion


  论文对长尾数据集中的复杂变换不变性进行了研究,发现不变性在很大程度上取决于类别的图片数量,实际上分类器并不能将从大类中学习到的不变性转移到小类中。为此,论文提出了GIT生成模型,从数据集中学习到类无关的复杂变换,从而在训练时对小类进行有效增强,整体效果不错。



如果本文对你有帮助,复杂点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】