Word2vec 中 Skip-Gram 模型负采样方法的原理

145 阅读3分钟

Skip-Gram 使用中心词 w(t) 预测上下文词 ...,w(t-2),w(t-1),w(t+1),w(t+2),..,可以建立中心词与上下文词之间的共现关系,即 P(wt+jwt)P(w_{t+j}|w_{t}),其中 j{±1,..,±k}j \in \left\{ \pm 1,..,\pm k \right\}。 其中负采样的任务目标是最大化中心词与其上下文词的共现概率,最小化它与非其上下文词的共现概率。

image.png

Skip-Gram 模型

公式推导

以下推理公式涉及 sigmoid 函数 σ(u)=11+eu\sigma(u) = \frac{1}{1 + e^{-u}} ,它满足运算 σ(u)=1σ(u)\sigma(-u) = 1 - \sigma(u);涉及 log 函数,它满足运算 log(ab)=log(a)+log(b)log(a*b) = log(a) + log(b)

  • 使用 σ\sigma 函数表示中心词 wcw_{c} 与其上下文词 wow_{o} 的共现概率,其中用 D = 1 表示共现:
    P(D=1wo,wc)=σ(vovcT)P(D=1 | w_{o},w_{c}) = \sigma(v_{o} · v_{c}^{T})
    上述 vov_{o}wow_{o} 的词向量,vcTv_{c}^{T}wcw_{c} 的词向量转置(转置就是行列互换)。

  • 使用 σ\sigma 函数表示中心词 wcw_{c} 与非其上下文词 wnegw_{neg} 的不共现概率,其中用 D = 0 表示不共现:
    P(D=0wneg,wc)=1P(D=1wnen,wc)P(D=0 | w_{neg},w_{c}) = 1- P(D=1 | w_{nen},w_{c})
    P(D=0wneg,wc)=1σ(vnegvcT)\Rightarrow P(D=0 | w_{neg},w_{c}) = 1 - \sigma(v_{neg} · v_{c}^{T})
    P(D=0wneg,wc)=σ(vnegvcT)\Rightarrow P(D=0 | w_{neg},w_{c}) = \sigma(-v_{neg} · v_{c}^{T})
    上述 vnegv_{neg}wnegw_{neg} 的词向量,vcTv_{c}^{T}wcw_{c} 的词向量转置(转置就是行列互换)。

  • 下述公式近似表示中心词 wcw_{c}wow_{o} 的对数概率,符号 \prod\sum 分别表示累乘和累加:
    logP(wowc)=log(P(D=1wo,wc)wjwnegP(D=0wj,wc))log\,P(w_{o} | w_{c}) = log(P(D=1 | w_{o},w_{c}) \quad · \prod\limits_{w_{j} \in w_{neg}} P(D=0 | w_{j},w_{c}))
    logP(wowc)=logP(D=1wo,wc)+wjwneglogP(D=0wj,wc)\Rightarrow log\,P(w_{o} | w_{c}) = log\,P(D=1 | w_{o},w_{c}) \quad + \sum\limits_{w_{j} \in w_{neg}} log\,P(D=0 | w_{j},w_{c})
    logP(wowc)=logσ(vovcT)+vjvneglogσ(vjvcT)\Rightarrow log\,P(w_{o} | w_{c}) = log\,\sigma(v_{o} · v_{c}^{T}) \: + \sum\limits_{v_{j} \in v_{neg}} log\,\sigma(-v_{j} · v_{c}^{T})

  • 下述公式表示中心词 wcw_{c}wow_{o} 的损失函数,求最大化概率等同最小化损失误差:
    E=logP(wowc)E = -log\,P(w_{o} | w_{c})
    E=logσ(vovcT)vjvneglogσ(vjvcT)\Rightarrow E = -log\, \sigma(v_{o} · v_{c}^{T}) \: - \sum\limits_{v_{j} \in v_{neg}} log\,\sigma(-v_{j} · v_{c}^{T})

代码实现

查看完整代码可点击 使用 PyTorch 实现基于 Skip-Gram 的 Word2vec 模型 - 掘金 (juejin.cn),下述结合公式推导分析 SkipGramModel.forward 方法的实现。

def forward(self, center_in, positive_out, negative_out):
    # 获取表示中心词的词向量
    center_emb_in = self.embedding(center_in)
    # 获取表示正样本(其中心词的上下文词)的词向量
    positive_emb_out = self.embedding_out(positive_out)
    # 获取表示负样本(非其中心词的上下文词)的词向量
    negative_emb_out = self.embedding_out(negative_out)
    
    # 转置中心词的词向量,以满足下述矩阵相乘的前提条件(第二个矩阵的行数等于第一个矩阵的列数),
    # 转置前 center_emb_in 的 shape 属性是 (批次大小,1,词向量维度大小),
    # 转置后 center_emb_in 的 shape 属性是 (批次大小,词向量维度大小,1)
    center_emb_in = torch.transpose(center_emb_in, dim0=2, dim1=1)

    # 正样本的词向量与中心词的词向量做矩阵乘法运算, bmm 是 PyTorch 中批量矩阵乘法的操作,
    # positive_emb_out 的 shape 属性是 (批次大小,正样本数,词向量维度大小),
    # 批量矩阵乘法运算后 positive_prob 的 shape 属性是 (批次大小,正样本数,1)
    positive_prob = torch.bmm(positive_emb_out, center_emb_in)  
    # 压缩 positive_prob 维度 2 后的 shape 属性是 (批次大小,正样本数)
    positive_prob = torch.squeeze(positive_prob, dim=2)
    # 下述 logsigmoid 操作对应公式中的 log σ(Vo ⋅ VcT)
    positive_prob = nn.functional.logsigmoid(positive_prob)
    
    # 公式中 log σ(-Vneg ⋅ VcT) 有负号,所以先对负样本的词向量做取反操作
    negative_emb_out = torch.neg(negative_emb_out)
    # 负样本的词向量与中心词的词向量做矩阵乘法运算, bmm 是 PyTorch 中批量矩阵乘法的操作,
    # negative_emb_out 的 shape 属性是 (批次大小,负样本数,词向量维度大小),
    # 批量矩阵乘法运算后 negative_prob 的 shape 属性是 (批次大小,负样本数,1)
    negative_prob = torch.bmm(negative_emb_out, center_emb_in)
    # 压缩 negative_prob 维度 2 后的 shape 属性是 (批次大小,负样本数)
    negative_prob = torch.squeeze(negative_prob, dim=2)
    # 下述 logsigmoid 操作对应公式中的 log σ(-Vneg ⋅ VcT)
    negative_prob = nn.functional.logsigmoid(negative_prob)
    
    # 下述 sum 操作对应公式中 ∑ 累加操作
    positive_prob = torch.sum(positive_prob, dim=1)
    negative_prob = torch.sum(negative_prob, dim=1)
    # 对应中心词与其上下文词的对数概率 log σ(Vo ⋅ VcT) + ∑ log σ(-Vo ⋅ VnegT)
    prob = positive_prob + negative_prob
    # 求最大化概率等同最小化损失误差,所以下述把概率取反操作后的结果当作损失误差
    loss = torch.neg(prob)
    loss = torch.mean(loss)

    return loss
  • 代码 torch.bmm(positive_emb_out, center_emb_in) 对应公式 vovcTv_{o} · v_{c}^{T}
  • 代码 torch.bmm(torch.neg(negative_emb_out), center_emb_in) 对应公式 vnegvcT-v_{neg} · v_{c}^{T}
  • 代码 torch.neg(prob) 对应公式 E=logP(wowc)E = -log P(w_{o}|w_{c})

由于个人水平限制,难免会有错误。如果发现错误,希望指点出来,共同进步。

参考资料