论文阅读 (007): Deep Learning CTR Prediction (DeepFM)

470 阅读3分钟

阅读论文--DeepFM: A Factorization-Machine based Neural Network for CTR Prediction

背景

在CTR任务中,特征的叉乘对于预估任务十分重要,比如用户常常在吃饭时间下载食品相关的app,即app类型和时间的特征叉乘(二阶)对CTR预估重要;男性年轻用户喜欢射击和RPG游戏,即app类型,用户性别和用户年龄的特征叉乘(三阶)对CTR预估十分重要。

模型针对的改进有以下几点

  • 在FM模型中,只考虑到了特征的二阶叉乘,DeepFM使用DNN学习的embedding作为特征的隐向量(FM中的latent vector),可以表示特征的多阶叉乘
  • 在Wide & Deep模型中,使用了DNN提高模型的Generalization,但是需要人工设计特征,DeepFM省略人工设计特征

模型

DeepFM模型如下图所示,总体结构和Wide & Deep十分类似,区别点在于使用FM代替了LR模型,使用dense embedding作为latent vector。最终模型的输出为y^=sigmoid(yFM+yDNN)\hat{y} = sigmoid(y_{FM}+y_{DNN})

模型中有几点需要说明,从下到上开始分析,Saprse Features中输入为Field i,其中一个特征作为一个field,输入特征为x=[xfield1,xfield2,...,xfieldj,...,xfieldm]x = [x_{field_1}, x_{field_2},...,x_{field_j},...,x_{field_m}] ,这么表示的目的是节省存储空间(使用one-hot存储占很大空间)。因此一个xfieldjx_{field_j}学习到一个embedding,学习的模块如下所示

FM模型为yFM=w,x+j1=1dj2=j1+1dVi,Vjxj1xj2y_{F M}=\langle w, x\rangle+\sum_{j_{1}=1}^{d} \sum_{j_{2}=j_{1}+1}^{d}\left\langle V_{i}, V_{j}\right\rangle x_{j_{1}} \cdot x_{j_{2}},文中使用embedding作为ViV_{i}VjV_{j},因此embedding的作用为

  1. 将不同长度field映射为统一长度的embedding
  2. 使用DNN学习时,即使部分特征没在样本中出现,也可以学习到较好的语义信息,捕获更高阶的信息,提升FM的效果

和常见CTR模型的特点对比如下

实验

实验数据为:Criteo Dataset、公司数据,实验默认设置为:

  • dropout:0.5
  • network structure:400-400-400(这边应该是deep部分连接层神经元个数)
  • optimizer:adam

效率对比

对比指标为:training time of deep CTR modeltraining time of LR\frac{|training\ time\ of\ deep\ CTR\ model|}{|training\ time\ of\ LR|},实验结果如下所示

结论有

  1. FNN预训练需要较多的时间,导致总体效率较低
  2. IPNN和PNN*在GPU上相对CPU速度大大提升,但是总时间相对还是较慢,因为需要较多的内积运算
  3. DeepFM在CPU和GPU上效率几乎最高

效果对比

对比模型的AUC和LogLoss,效果如下所示

结论有

  1. 使用特征的叉乘可以大大提升CTR预估结果,在所有模型中,LR模型效果最差
  2. 学习高阶和低阶的特征交叉可以同时提升模型结果,DeepFM效果比FM(只考虑低阶的特征交叉)效果好,DeepFM比FNN, IPNN, OPNN, PNN∗(只考虑高阶的特征交叉)效果好。
  3. 高阶特征交叉和低阶的特征交叉共享同一个embedding时可以提升预估结果,相对于LR & DNN和FM & DNN(embedding是分开学习的)效果好

实验还对比了激活函数、Dropout、DNN中神经元个数、DNN中隐藏层个数、DNN形状等超参数设计,具体可以看文章,此处省略。

讨论

本人觉得文章亮点在于使用DNN学习embedding来捕捉特征之间的语义信息,并将embedding作为FM中latent vector。实现了特征高阶和低阶的交叉表示,以及可以省略人工设计特征。

参考资料

  1. 深度推荐模型之DeepFM
  2. 推荐系统 - DeepFM架构详解