[Torch-RecHub Learning] Implementing DIN


1. Overview

github.com/datawhalech… is an open-source recommender system toolkit released by the DataWhale team. It currently implements components for classic models such as LR, FM, and MLP, but still lacks components based on GCN.

This article walks through using the torch-rechub toolkit to implement the classic DIN model.
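Before diving into the model itself, here is roughly what an end-to-end run looks like. This sketch follows the usage pattern of the torch-rechub examples, but treat the module paths, helper classes (DataGenerator, CTRTrainer), and argument names as assumptions that may differ between versions; the feature names and random toy data are purely illustrative.

import numpy as np
from torch_rechub.basic.features import SparseFeature, SequenceFeature
from torch_rechub.models.ranking import DIN
from torch_rechub.trainers import CTRTrainer
from torch_rechub.utils.data import DataGenerator

# random toy data standing in for a real click log:
# x maps feature name -> array, y holds the click labels
n, seq_len = 1000, 20
x = {"user_id": np.random.randint(0, 1000, n),
     "item_id": np.random.randint(0, 500, n),
     "hist_item_id": np.random.randint(0, 500, (n, seq_len))}
y = np.random.randint(0, 2, n)

features = [SparseFeature("user_id", vocab_size=1000, embed_dim=16)]
target_features = [SparseFeature("item_id", vocab_size=500, embed_dim=16)]
history_features = [SequenceFeature("hist_item_id", vocab_size=500, embed_dim=16,
                                    pooling="concat", shared_with="item_id")]

dg = DataGenerator(x, y)
train_dl, val_dl, test_dl = dg.generate_dataloader(split_ratio=[0.7, 0.1], batch_size=256)

model = DIN(features=features, history_features=history_features, target_features=target_features,
            mlp_params={"dims": [256, 128]}, attention_mlp_params={"dims": [36]})

trainer = CTRTrainer(model, optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
                     n_epoch=2, device="cpu", model_path="./")
trainer.fit(train_dl, val_dl)
print(trainer.evaluate(trainer.model, test_dl))  # AUC on the held-out split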

2. DIN in Brief

  • Published by Alibaba in 2018 at KDD, a CCF-A conference
  • Contribution: introduces a local activation unit (a custom attention-like mechanism) that assigns a weight to each item in the behavior sequence, so the model can learn the user's dynamic interests
  • Contribution: two techniques for training industrial-scale deep networks: (a) a mini-batch aware regularizer and (b) a data adaptive activation function, Dice (a minimal sketch follows the paper link below)

Paper: "Deep Interest Network for Click-Through Rate Prediction"
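To make contribution (b) concrete, below is a minimal sketch of Dice in plain PyTorch. This is not torch-rechub's own implementation; the class and argument names are illustrative. The paper's p(s) = sigmoid((s - E[s]) / sqrt(Var[s] + eps)) is realized here with a BatchNorm1d without affine parameters, and alpha is a learnable per-channel parameter.

import torch
import torch.nn as nn

class Dice(nn.Module):
    """Data adaptive activation: f(s) = p(s)*s + (1 - p(s))*alpha*s."""

    def __init__(self, num_features, eps=1e-8):
        super().__init__()
        # batch norm without affine parameters gives the standardized (s - E[s]) / sqrt(Var[s] + eps)
        self.bn = nn.BatchNorm1d(num_features, eps=eps, affine=False)
        self.alpha = nn.Parameter(torch.zeros(num_features))

    def forward(self, s):  # s: (batch_size, num_features); needs batch statistics in training mode
        p = torch.sigmoid(self.bn(s))
        return p * s + (1.0 - p) * self.alpha * s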

2.1 DIN's Base Model

DIN_base_model.png (figure: DIN's base model)

Baseline approach (Embedding + MLP): map the high-dimensional sparse features to low-dimensional embeddings, convert the embeddings into fixed-length vectors (e.g. by pooling the behavior sequence), concatenate these vectors, and feed them into fully connected (FC) layers.

Drawback: for a given user, this representation vector stays the same no matter which candidate ad is being scored, so a user vector of limited dimensionality becomes the bottleneck for expressing the user's diverse interests. Simply enlarging the embedding vectors greatly increases the number of parameters to learn, which leads to overfitting under limited training data and is unacceptable for an industrial online system.
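To make the baseline just described concrete, here is a minimal sketch of an Embedding + sum-pooling + MLP base model in plain PyTorch (names and sizes are illustrative, not torch-rechub code). Note that the pooled user vector is computed once and is independent of the candidate ad, which is exactly the bottleneck DIN addresses.

import torch
import torch.nn as nn

class BaseModel(nn.Module):
    """Embedding + sum-pooling + MLP baseline (illustrative sketch)."""

    def __init__(self, n_items, emb_dim=8, hidden=(64, 32)):
        super().__init__()
        self.item_emb = nn.Embedding(n_items, emb_dim, padding_idx=0)
        layers, in_dim = [], 2 * emb_dim  # [pooled history | candidate ad]
        for h in hidden:
            layers += [nn.Linear(in_dim, h), nn.ReLU()]
            in_dim = h
        layers.append(nn.Linear(in_dim, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, hist_ids, ad_ids):  # hist_ids: (B, L) item ids, ad_ids: (B,) item ids
        user_vec = self.item_emb(hist_ids).sum(dim=1)  # fixed-length, independent of the candidate ad
        ad_vec = self.item_emb(ad_ids)
        return torch.sigmoid(self.mlp(torch.cat([user_vec, ad_vec], dim=1)).squeeze(1))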

2.2 Introducing the Local Activation Unit

Goal: represent a user's diverse interests with a single vector of limited dimensionality.

DIN_model.png (figure: the DIN model and the structure of the activation unit)

The item embeddings of the user's historical behaviors (from the embedding layer) are combined with the candidate ad's embedding via an out product; the result is concatenated with the original Item_Emb and AD_Emb and fed through a small network with PReLU/Dice activation to obtain a relevance score. This follows the idea of the attention mechanism, with the score expressing how relevant each behavior is to the candidate ad, except that the scores are not normalized to sum to 1. (In the torch-rechub implementation below, the interaction is realized as the element-wise difference and element-wise product of the two embeddings.)
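Using these scores as weights, the user representation for a candidate ad A is the weighted-sum pooling over the behavior embeddings, as formulated in the paper:

$$ v_U(A) = f(v_A, e_1, \dots, e_H) = \sum_{j=1}^{H} a(e_j, v_A)\, e_j $$

where $v_A$ is the candidate ad's embedding, $e_1, \dots, e_H$ are the embeddings of the user's $H$ historical behaviors, and $a(\cdot,\cdot)$ is the score produced by the activation unit.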

2.3 Additional Notes

  • pooling layer: turns the user's variable-length history of behavior embeddings into a fixed-length vector, because the FC layers require fixed-length input
  • concat layer: concatenates all feature embeddings to form the MLP input
  • context features: the relevant contextual features

3. Code Implementation

import torch
import torch.nn as nn

from torch_rechub.basic.layers import EmbeddingLayer, MLP  # building blocks provided by torch-rechub


class DIN(nn.Module):
    """Deep Interest Network: attention-pools each history feature against its paired target feature."""

    def __init__(self, features, history_features, target_features, mlp_params, attention_mlp_params):
        super().__init__()
        self.features = features
        self.history_features = history_features
        self.target_features = target_features
        self.num_history_features = len(history_features)
        self.all_dims = sum([fea.embed_dim for fea in features + history_features + target_features])

        self.embedding = EmbeddingLayer(features + history_features + target_features)
        # one ActivationUnit per history feature, scoring each behavior against its paired target feature
        self.attention_layers = nn.ModuleList(
            [ActivationUnit(fea.embed_dim, **attention_mlp_params) for fea in self.history_features])
        self.mlp = MLP(self.all_dims, activation="dice", **mlp_params)

    def forward(self, x):
        embed_x_features = self.embedding(x, self.features)  #(batch_size, num_features, emb_dim)
        embed_x_history = self.embedding(
            x, self.history_features)  #(batch_size, num_history_features, seq_length, emb_dim)
        embed_x_target = self.embedding(x, self.target_features)  #(batch_size, num_target_features, emb_dim)
        attention_pooling = []
        for i in range(self.num_history_features):
            attention_seq = self.attention_layers[i](embed_x_history[:, i, :, :], embed_x_target[:, i, :])
            attention_pooling.append(attention_seq.unsqueeze(1))  #(batch_size, 1, emb_dim)
        attention_pooling = torch.cat(attention_pooling, dim=1)  #(batch_size, num_history_features, emb_dim)

        mlp_in = torch.cat([
            attention_pooling.flatten(start_dim=1),
            embed_x_target.flatten(start_dim=1),
            embed_x_features.flatten(start_dim=1),
        ], dim=1)  #(batch_size, N)

        y = self.mlp(mlp_in)
        return torch.sigmoid(y.squeeze(1))


class ActivationUnit(nn.Module):

    def __init__(self, emb_dim, dims=[36], activation="dice", use_softmax=False):
        super(ActivationUnit, self).__init__()
        self.emb_dim = emb_dim
        self.use_softmax = use_softmax
        self.attention = MLP(4 * self.emb_dim, dims=dims, activation=activation)  # ends with a 1-unit output layer -> one score per position

    def forward(self, history, target):
        seq_length = history.size(1)
        target = target.unsqueeze(1).expand(-1, seq_length, -1)  #(batch_size, seq_length, emb_dim)
        att_input = torch.cat([target, history, target - history, target * history],
                              dim=-1)  #(batch_size, seq_length, 4*emb_dim)
        att_weight = self.attention(att_input.view(-1, 4 * self.emb_dim))  #(batch_size*seq_length, 1)
        att_weight = att_weight.view(-1, seq_length)  #(batch_size*seq_length, 1) -> (batch_size, seq_length)
        if self.use_softmax:
            att_weight = att_weight.softmax(dim=-1)
        # (batch_size, seq_length, 1) * (batch_size, seq_length, emb_dim)
        output = (att_weight.unsqueeze(-1) * history).sum(dim=1)  #(batch_size,emb_dim)
        return output
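Finally, a quick smoke test of the class above, to make the expected input format concrete. This is a sketch under assumptions: it uses torch-rechub's SparseFeature/SequenceFeature definitions and assumes the model input x is a dict mapping feature names to id tensors (shape (batch,) for sparse features, (batch, seq_len) for sequence features); the feature names and sizes are illustrative and may not match the exact API of your torch-rechub version.

from torch_rechub.basic.features import SparseFeature, SequenceFeature

n_users, n_items, emb_dim, seq_len, batch = 50, 100, 8, 5, 4
features = [SparseFeature("user_id", vocab_size=n_users, embed_dim=emb_dim)]
target_features = [SparseFeature("item_id", vocab_size=n_items, embed_dim=emb_dim)]
history_features = [SequenceFeature("hist_item_id", vocab_size=n_items, embed_dim=emb_dim,
                                    pooling="concat", shared_with="item_id")]

model = DIN(features, history_features, target_features,
            mlp_params={"dims": [32, 16]}, attention_mlp_params={"dims": [36]})

x = {  # dict of feature name -> tensor, matching the feature definitions above
    "user_id": torch.randint(0, n_users, (batch,)),
    "item_id": torch.randint(0, n_items, (batch,)),
    "hist_item_id": torch.randint(0, n_items, (batch, seq_len)),
}
print(model(x).shape)  # expected: torch.Size([4]) -- one CTR prediction per sample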