阅读 97

Learning To Rank算法和评价指标

排序学习是推荐、搜索、广告的核心方法,而LTR就是专门做排序任务的一个有监督的机器学习算法。所以,LTR仍然是传统的机器学习处理范式,构造特征,学习目标,训练模型,预测。LTR一般分为三种类型:PointWise,PairWise和ListWise。这三种算法并不是特定的算法,而是三种设计思路,主要区别体现在损失函数、标签标注方式和优化方法的不同。

1. PointWise

以搜索任务为例,PointWise只考虑当前Qeury与每个文档的绝对相关度,而没有考虑其他文档与Qeury的相关度。PW的方法通常将文档编码成特征向量,根据训练数据训练分类模型或者回归模型,在预测阶段,直接对文档进行打分,按照此得分排序就是搜索的结果。

处理逻辑如下图:

PointWise

  • 实施细节

训练数据格式为三元组:(qi,dj,yij)(q_i,d_j,y_{ij}),标签yijy_{ij}是2个数值,表示相关/不相关。 训练一个二分类模型或者回归模型直接拟合yijy_{ij}。 loss函数:分类模型Loss函数可以使用交叉熵,回归模型Loss函数可以使用均方误差(MSE)。 预测阶段;得分直接用作排序。

  • PointWise的问题
  1. PointWise只考虑query和单个文档document之间的相关性simq,dsim_{q,d}没有考虑候选文档之间的关系。既然我们追求的目标是对候选结果进行排序,其实是想计算相对得分,直接使用simq,dsim_{q,d}的大小来排序,往往没有那么准确。实际上simq,dsim_{q,d}只是准确度概率,而不是真正的相对顺序概率。
  2. PointWise没有考虑同一个query对应的文档之间的内部依赖性。这回导致如下问题:1.导致输入空间内的样本不是独立同步分(IID)的,违反了机器学习的基本假设。2.当不同query有不同数量的文档时,整体loss容易被那些有更多文档(训练数据)的query组所支配。
  3. 排序问题关注的是topk的准确率,所以loss函数的设置需要加入相对位置排序的信息。

2. PairWise

PairWise的基本思路是对样本进行两两比较,构建偏序文档对,从比较中学习顺序。正如在PointWise中分析的,对于一个查询来说,我们需要的是检索结果正确的顺序,而不是检索结果与query的相关得分。PairWise就是希望通过正确估计一对文档的顺序,而得到整体的正确顺序。比如一个正确的排序为:“A>B>C”,PairWise通过学习两两之间的关系“A>B”,“B>C”和“A>C”来推断“A>B>C”。

处理逻辑如下: 此处输入图片的描述

  • 实施细节

训练数据格式为(qi,di+,di)(q_i,d^+_i,d^-_i),是一个query的正例和负例。通常又被称为:(anchor,positive,negative)。 PairWise实际上是一种metric learning的思路来直接学习他们的相对距离,而不在乎实际的值。 loss函数:大概有两种, 1:输入pair对的Ranking Loss: L(r0,r1,y)=yd(r0r1)+(1y)max(0,margind(r0r1))L(r_0,r_1,y)=yd(r_0-r_1)+(1-y) max(0,margin-d(r_0-r_1)) 其中y的取值为0或者1。 2: 输入三元组的Triplet Loss或者Contrastive Loss: L(ra,rp,rn)=max(0,margin+d(ra,rp)d(ra,rn))L(r_a,r_p,r_n)=max(0,margin+d(r_a,r_p)-d(r_a,r_n)) 预测阶段;和PointWise一样,得分直接用作排序。

  • PairWise的问题
  1. 由于需要构造pair格式的数据集,数量可能是doc数量的n倍(依据不同的构造策略),而PointWise中存在的*“当不同query有不同数量的文档时,整体loss容易被那些有更多文档(训练数据)的query组所支配”*的问题依然没存在,甚至进一步扩大。
  2. PairWise相对于PointWise对于噪音数据更敏感,即一个错误标注将会导致多个pair的错误。
  3. PairWise仍然只是考虑一对doc的相对位置,损失函数还是没有考虑候选文档之间的关系。可以认为是PointWise的优化版,基本思路没有变化。
  4. 同样的,PairWise没有考虑同一个query对应的文档之间的内部依赖性。导致输入空间内的样本不是独立同步分(IID)的,违反了机器学习的基本假设。

3. ListWise

PointWise和PairWise都是直接学习每个样本是否相关,或者两个正负样本的相关关系,更像是metric learning的思路,都是试图通过抽样的学习试图推理出全局的排序结果,这种思路是有根本的劣势。而ListWise的基本思路是试图直接优化像NDCG的排序指标,从而学习到最佳的排序结果。

  • 实施细节

输入的一个sample的格式是query以及他所有的候选doc。如给定: qiq_i,和他的候选doc及标签:C(di1,..,dim)C(d_{i1},..,d_{im})Y(yi1,..,yim)Y(y_{i1},..,y_{im})。标签YY的值就是表示所有候选doc的顺序。比如某个候选集为{a,d,c,b,e}\{a,d,c,b,e\},如果就是自然顺序,其对应的标签为{5,2,3,4,1}\{5,2,3,4,1\}。 通过各种ListWise算法训练模型。 预测阶段;根据得分来排序。

  • ListWise三种基本思路:
  1. 第一种为Measure-specific

这种方法就是直接对比如NDCG这样的指标优化。 这种方法是典型的“理想很丰满,现实很骨干”,因为NDCG、MAP和AUC这类排序指标,他们在数学形式上,是“不连续”(Non-Continuous)的,以及“不可微”(Non-Differentiable)的,基于这个现实,通常有三种解决办法: 第一种方法:找到一个近似NDCG指标的“连续”和“可微”的替代函数,通过最优化这个替代函数来优化NDCG。代表算法:SoftRank 和 AppRank。 第二种方法:尝试从数学上写出一个NDCG等指标的“边界”,然后优化这个“边界”。比如,如果推导出一个上界,那就可以通过最小化这个上界来优化 NDCG。代表算法:SVM-MAP 和 SVM-NDCG。 第三种方法:直接优化算法,可以用来处理“不连续”和“不可微”的NDCG类指标。代表算法:AdaRank 和 RankGP。 2. 第二种为Non-Measure-specific 这种方法是根据一个已知的最优排序,尝试重建这个顺序,然后衡量两者的差距,即优化模型来试图减少这个差距,比如使用KL散度作为Loss。 代表算法:ListNet 和 ListMLE 3. 第三种,ListWise和PairWise结合的算法 这类方法的核心目标仍然是优化NDCG类的排序指标,设计出一种替代的目标函数,有了替代函数之后,优化和计算过程直接使用某种PairWise的方式处理。 代表算法: LambdaRank 和 LambdaMART。

  • ListWise的优缺点
    1. 在很多场景构造训练数据比较困难。
    2. 因为要计算排序的loss,通常计算复杂度更高。
    3. 在有充足质量好的数据基础上,ListWise相比较PairWise和PointWise,直接对目标任务,也就是排序,进行学习和优化,往往表现更好。

4. 常用评价指标

nDCG

关于nDCG的解释

Mean Average Precision(MAP)

排序任务中,每个query都会有一个排序列表。顾名思义,MAP,就是测试集上所有query的AP的平均,那我们先看一下AP:

AP(π,l)=k=1mP@kI{lπ1(k)=1}m1AP(\pi,l)=\frac{\sum^m_{k=1}{P@k*I_{\{ l_{\pi^{-1}(k)}=1\}}}}{m_1}

其中,π\pi表示item list,即推送的结果列表。 m表示结果列表总数量,m1m_1表示结果列表中与query相关的item数量。I{lπ1(k)=1}I_{\{l_{\pi^{-1}(k)}=1\}},表示排在位置k处的标签是否相关,1表示相关,0表示不相关。 P@kP@k就是topk的Precision: P@k(π,l)=t<=kI{lπ1(k)=1}kP@k(\pi,l)=\frac{\sum_{t<=k}{I_{\{ l_{\pi^{-1}(k)}=1\}}}}{k}

另附一张图讲的很清楚: map.png-176.4kB

代码实现:

def _ap(ranked_list, ground_truth):
    # ranked_list: 结果列表,如['a', 'b', 'd', 'c', 'e']
    # ground_truth: 相关的item列表,如 ['a', 'd']
    hits = 0
    sum_precs = 0
    for n, item in enumerate(ranked_list):
        if item in ground_truth:
            hits += 1
            sum_precs += hits / (n + 1.0)
    return sum_precs / max(1.0, len(ground_truth))

复制代码

  • 参考
  1. 搜索评价指标——NDCG
文章分类
人工智能