MCLRec:元优化对比学习用于序列推荐

468 阅读5分钟

论文题目:Meta--optimized Contrastive Learning for Sequential Recommendation[1]^{[1]}

论文来源:SIGIR2023

code: github

一、Motivation

  • 推荐面临数据稀疏和噪音的问题,引入对比学习可以很好的缓解这一问题,但是对于如何构建对比学习的正负样本,先前的工作主要依赖于随机数据增强[2]^{[2]},随机数据增强一方面可能会因为改变原始序列的语义信息从而使得随机增强得到的两个视图的分布不一致进而引入噪音损害模型的推荐效果,另外其增强方式很难泛化,针对不同的数据集需要使用不同的方式,并且由于随机增强带来的不确定性,导致其需要更大的batch size才能取得不错的效果。针对上述问题,我们做了以下工作:
    • 首先我们针对每个随机增强的视图引入了一个可学习的模型增强板块,这个板块主要作用是通过生成具有更多丰富信息的特征用于对比学习,一方面缓解随机增强带来的噪音问题,另外一方面通过提取更有用的对比特征减少模型对batch size的依赖;
    • 另外因为引入了新的参数,我们进一步提出了一种元优化策略,也即分两步更新编码器和增强器,这样会使得模型的效果更好,因为这缓解了多任务学习带来的gap问题[3]^{[3]};
    • 为了避免对比视图过多带来坍塌问题(生成的视图本身就特别相似,这样模型学到的表示特别有限),我们加入了一个正则;

二、Model

image.png

先前的对比学习框架如左图所示,其只是依赖于随机增强来构造正样本,这一方面会因为改变原始序列的内在含义从而引入噪音(使得两个随机生成的视图表示与原始视图表示相差很大),另外其很难泛化,每一个数据集需要单独设置一种合适的数据增强方式;右边的图是我们提出的模型框架,相比左图,我们主要引入两个新的板块(可学习的模型增强板块和元优化策略),我们首先会将随机生成的视图h~1\tilde{h}^{1}h~2\tilde{h}^{2}放入增强器得到新的视图z~1\tilde{z}^{1}z~2\tilde{z}^{2},然后引入了交叉对比学习来计算对比损失,损失函数如下所示:

Lrec=1y^[g]+log(iIexp(y^[i]))),\mathcal{L}_{rec}=-1*\hat{\mathbf{y}}[g]+\log(\sum_{i\in \mathcal{I}}\exp(\hat{\mathbf{y}}[i]))),

Lcon(x1,x2)=loges(x1,x2)es(x1,x2)+xneges(x1,x)loges(x2,x1)es(x2,x1)+xneges(x2,x),\mathcal{L}_{con}(\mathbf{x}^{1},\mathbf{x}^{2})=-\log\frac{e^{s(\mathbf{x}^{1}, \mathbf{x}^{2})}}{e^{s(\mathbf{x}^{1}, \mathbf{x}^{2})}+\underset{\mathbf{x} \in neg}{\sum}e^{s(\mathbf{x}^{1},\mathbf{x})}} -\log\frac{e^{s(\mathbf{x}^{2}, \mathbf{x}^{1})}}{e^{s(\mathbf{x}^{2}, \mathbf{x}^{1})}+\underset{\mathbf{x} \in neg}{\sum}e^{s(\mathbf{x}^{2},\mathbf{x})}},

Lcl1=Lcon(h~1,h~2),\mathcal{L}_{cl1}=\mathcal{L}_{con}(\tilde{\mathbf{h}}^{1},\tilde{\mathbf{h}}^{2}),

Lcl2=Lcon(z~1,z~2)+Lcon(h~1,z~2)+Lcon(h~2,z~1).\mathcal{L}_{cl2}=\mathcal{L}_{con}(\tilde{\mathbf{z}}^{1},\tilde{\mathbf{z}}^{2})+\mathcal{L}_{con}(\tilde{\mathbf{h}}^{1},\tilde{\mathbf{z}}^{2})+\mathcal{L}_{con}(\tilde{\mathbf{h}}^{2},\tilde{\mathbf{z}}^{1}).

R=1σ+([σ+omin]+)+1σ([omaxσ]+)\mathcal{R}=\frac{1}{|\sigma^{+}|}\sum([\sigma^{+}-o_{min}]_{+})+\frac{1}{|\sigma^{-}|}\sum([o_{max}-\sigma^{-}]_{+})

其中Lrec\mathcal{L}_{rec}是推荐损失,使用的交叉熵损失函数,Lcon\mathcal{L}_{con}是infoNCE4^{4}损失函数,Lcl1\mathcal{L}_{cl1}Lcl2\mathcal{L}_{cl2}分别为原始的随机数据增强对损失和可学习模型增强对比损失,R\mathcal{R}为我们提出的正则,因为引入了新的参数,我们使用了元优化的策略,更新算法如下表所示:

image.png

第一步,我们会冻结增强器的参数,使用三个损失采用多任务联合学习的形式更新编码器; 第二步,我们会冻结编码器,并且使用更新参数后的编码器重新对视图进行编码,然后使用Lcl2\mathcal{L}_{cl2}来更新增强器;

三、Data&Experiments

主要使用了Yelp[5]^{[5]}和Amazon[6]^{[6]}的数据集。

image.png

四、Performance

image.png

从上表可以看出,MCLRec在多个数据集上都取得了SOTA的实验效果,并且其较于baseline有明显的提升。

五、Ablation Study

image.png

首先做了所有部件的实验,从上表可以看出每一个部件对于模型都是有用的(因为MCLRec的效果是最好的),另外基于模型增强得到的对比损失部分明显要比基于随机数据增强得到的板块的损失更有效(去掉之后效果下降更多);另外我们发现使用共享的参数(两个增强器使用共享参数也会带来模型效果的下降),这可能是因为两个增强器生成的视图太相似导致模型效果下降;

image.png

为了进一步分析正则对模型的影响,我们对于使用正则和不使用正则得到的模型增强视图进行了可视化,从上图中,可以明显的看到去掉正则之后生成的视图表示中正负样本的距离特别的远,在投影空间中特别的离散,这可能是对比学习导致增强器生成了坍塌的视图表示(两个正样本距离之间也特别的远);另外引入正则得到的视图表示就更加正常,互为正样本的两个表示之间比较近,互为负样本的两个视图表示之间比较远;

image.png

为了分析元优化策略对模型的影响,我们对于使用该优化策略和直接使用联合学习进行了实验,对于随机增强的视图表示进行了可视化实验,从上图中可以明显看到在使用联合学习生成的表示中,随机数据增强得到的视图表示特别的靠近(正样本与负样本之间的距离也特别近),这可能是因为联合学习导致编码器学到了坍塌的表示,导致生成的视图特别的相似,因为使用联合学习的方法来同时训练编码器和增强器,这会导致模型出现震动现象。引入元优化策略很好地避免了这一情况,让两个模块的参数分别的更新,这缓解了多任务之间gap带来的消极影响;

六、Conclusion

针对我们提出的模型,我们还做了很多鲁棒性实验(对批量大小的依赖、对于稀疏数据、对于噪音数据、对于参数敏感性),都证明了我们方法的有效性。

七、References

[1] Qin X, Yuan H, Zhao P, et al. Meta-optimized Contrastive Learning for Sequential Recommendation[J]. arXiv preprint arXiv:2304.07763, 2023.

[2] Xu Xie, Fei Sun, Zhaoyang Liu, Jinyang Gao, Bolin Ding, and Bin Cui. Contrastive pre-training for sequential recommendation. arXiv preprint arXiv:2010.14395,2020.

[3] Zhuoliang Kang, Kristen Grauman, and Fei Sha. Learning with whom to share in multi-task feature learning. In ICML, pages 521–528, 2011.

[4] Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey E. Hinton. A simple framework for contrastive learning of visual representations. In ICML, pages 1597–1607, 2020.

[5] Yelp

[6] Julian McAuley, Christopher Targett, Qinfeng Shi, and Anton Van Den Hengel.2015. Image-based recommendations on styles and substitutes. In SIGIR. 43–52.