2022CVPR行人重识别文章精读之10.Graph Sampling Based Deep Metric Learning for Generalizable

361 阅读5分钟

开启掘金成长之旅!这是我参与「掘金日新计划 · 12 月更文挑战」的第20天,点击查看活动详情

论文--[1]Shengcai Liao, Ling Shao.Graph Sampling Based Deep Metric Learning for Generalizable Person Re-Identification. In CVPR,2022

贡献

在分类或度量学习中,涉及类参数或者signatures对于大规模的行人重识别训练都是无效的。所以要使用小批量样本之间的成对深度度量关系。

PK采样器--首先随机选择P个类,然后对每个类随机采样K个图像,以构建大小为B=P×K的小批量。由于这是随机执行的,所以小批量内的采样实例均匀分布在整个数据集,对于深度度量学习而言,可能不具有很大的信息量和很高的效率。于是提出了三元组难样本挖掘,但因为PK采样器获得的小批量不考虑样本关系信息,所以还是受到完全随机PK采样器的限制

将硬实例挖掘转移到数据采样阶段,因此提出图采样(GS)--方法基本思想是在每个epoch开始时为所有类构建最近邻关系图。然后,通过随机选择一个类作为锚点及其前k个最近的相邻类来执行小批量采样,每个类具有相同的k个实例。这样,采样的小批量内的实例大多彼此相似,从而为判别学习提供了信息丰富且具有挑战性的实例。从人脸识别损失函数研究中可知,关注边界(硬)示例有助于提高学习模型的判别能力,并有助于生成概括远远超出训练数据的紧凑表示。GS采样器在关注最近的相邻类方面有着相似的想法,因此有可能提高学习模型的辨别和泛化能力。

 

深度度量学习

分类--损失函数设计(成对损失函数,分类或识别损失,三元组损失),深度特征匹配

针对跨域行人重识别,还提出了元学习

深度行人重识别通常有两种学习方法--第一种是基于分类的方法,使用识别损失;第二种是基于三元组损失的方法

本文提出只考虑小批量样本之间的成对匹配,并去除其类内存


方法--图采样

图片.png

图1。两种不同的采样方法:(a) PK采样器;及(b)建议的GS取样器。不同的形状表示不同的类别,不同的颜色表示不同的批次。GS为所有类构造一个图,并且总是对最近的相邻类进行采样。

取样方法本身需要改进,以便为小批量提供有信息量的样品。需要考虑类之间的关系,而不是使用完全随机抽样。因此,在每个epoch的开始为所有类构造一个图,并且总是在一个小批处理中对最近的相邻类进行抽样,以实现判别学习。称之为概念图抽样(GS)

在每个epoch开始,利用最新学习的模型来评估类之间的距离或相似性,然后为所有类构造一个图。这样,类之间的关系可以用于信息抽样。具体来说,每个类随机选择一个图像来构建一个小的子数据集。然后提取具有当前网络的特征嵌入,记为X∈RC×d,其中C为待训练的类总数,d为特征维数。

接下来,计算所有选定样本之间的成对距离,例如通过QAConv。因此,可以得到所有类的距离矩阵dist∈RC×C。

然后,对于每个类c,可以检索到最靠近的P−1个类,表示为N (c) = {xi|i = 1,2,…P−1},其中P是每个迷你批中要采样的类的数量。据此,可以构造图G = (V, E),其中V = {c|c = 1,2,…C}表示顶点,每个类为一个节点,表示边。

最后,对于小批量抽样,对于每个类c作为锚点,在g中检索其所有连接的类,然后与锚点类c一起,得到一个集合A = ,其中| A | = P。接下来,对于A中的每个类,每个类随机抽取K个实例,生成一小批B = P × K个样本用于训练。GS采样器的伪代码如下

注意,与其他小批量采样方法不同,GS采样器每个epoch的小批量或迭代次数总是C,这与参数B, P和k无关。尽管如此,参数B仍然影响每个小批量的计算负荷。

此外,有人可能会担心GS取样器的计算成本会很高。然而,请注意,首先,每个类只有一张图像被随机采样用于图的构造;其次,上述计算每个epoch只执行一次。在实践中发现,与主流欧几里得距离相比,QAConv的GS采样器已经是一个沉重的匹配器,但数千个身份只需要几十秒。

图片.png

损失

利用GS采样器提供的小批量,应用QAConv计算每对图像之间的相似度,并在小批量中提出了一个基于三元组的排序学习问题。单独计算了batch OHEM三元组损失用于度量学习:

图片.png

实验

ResNet50被用作骨干,并附加IBN-b层。使用layer3特征图,加上128通道的颈卷积作为最终的特征图。输入图像大小为384 × 128 。应用了几种常用的数据增强方法,包括随机裁剪、翻转、遮挡和颜色抖动。批大小设置为64。采用SGD优化器对网络进行训练,骨干网学习率为0.0005,新增层学习率为0.005。最长学习周期为60个。当初始损失减少为0.7倍时,学习率衰减为0.1,并且在已经学习的周期的另外一半之后触发早期停止。

梯度裁剪应用于T = 8。利用PyTorch中的自动混合精度(AMP)加速训练。当进一步应用所提出的GS采样器(用QAConv-GS表示)时,使用难三元组损失(m=16)。GS的默认参数为B=64, K=2。


数据集:Market1501,CUHK03,MSMT17,RandPerson

图片.png