Skip-Gram 使用中心词 w(t) 预测上下文词 ...,w(t-2),w(t-1),w(t+1),w(t+2),..,可以建立中心词与上下文词之间的共现关系,即 ,其中 。 其中负采样的任务目标是最大化中心词与其上下文词的共现概率,最小化它与非其上下文词的共现概率。
Skip-Gram 模型
公式推导
以下推理公式涉及 sigmoid 函数 ,它满足运算 ;涉及 log 函数,它满足运算 。
-
使用 函数表示中心词 与其上下文词 的共现概率,其中用 D = 1 表示共现:
上述 是 的词向量, 是 的词向量转置(转置就是行列互换)。 -
使用 函数表示中心词 与非其上下文词 的不共现概率,其中用 D = 0 表示不共现:
上述 是 的词向量, 是 的词向量转置(转置就是行列互换)。 -
下述公式近似表示中心词 与 的对数概率,符号 和 分别表示累乘和累加:
-
下述公式表示中心词 与 的损失函数,求最大化概率等同最小化损失误差:
代码实现
查看完整代码可点击 使用 PyTorch 实现基于 Skip-Gram 的 Word2vec 模型 - 掘金 (juejin.cn),下述结合公式推导分析 SkipGramModel.forward 方法的实现。
def forward(self, center_in, positive_out, negative_out):
# 获取表示中心词的词向量
center_emb_in = self.embedding(center_in)
# 获取表示正样本(其中心词的上下文词)的词向量
positive_emb_out = self.embedding_out(positive_out)
# 获取表示负样本(非其中心词的上下文词)的词向量
negative_emb_out = self.embedding_out(negative_out)
# 转置中心词的词向量,以满足下述矩阵相乘的前提条件(第二个矩阵的行数等于第一个矩阵的列数),
# 转置前 center_emb_in 的 shape 属性是 (批次大小,1,词向量维度大小),
# 转置后 center_emb_in 的 shape 属性是 (批次大小,词向量维度大小,1)
center_emb_in = torch.transpose(center_emb_in, dim0=2, dim1=1)
# 正样本的词向量与中心词的词向量做矩阵乘法运算, bmm 是 PyTorch 中批量矩阵乘法的操作,
# positive_emb_out 的 shape 属性是 (批次大小,正样本数,词向量维度大小),
# 批量矩阵乘法运算后 positive_prob 的 shape 属性是 (批次大小,正样本数,1)
positive_prob = torch.bmm(positive_emb_out, center_emb_in)
# 压缩 positive_prob 维度 2 后的 shape 属性是 (批次大小,正样本数)
positive_prob = torch.squeeze(positive_prob, dim=2)
# 下述 logsigmoid 操作对应公式中的 log σ(Vo ⋅ VcT)
positive_prob = nn.functional.logsigmoid(positive_prob)
# 公式中 log σ(-Vneg ⋅ VcT) 有负号,所以先对负样本的词向量做取反操作
negative_emb_out = torch.neg(negative_emb_out)
# 负样本的词向量与中心词的词向量做矩阵乘法运算, bmm 是 PyTorch 中批量矩阵乘法的操作,
# negative_emb_out 的 shape 属性是 (批次大小,负样本数,词向量维度大小),
# 批量矩阵乘法运算后 negative_prob 的 shape 属性是 (批次大小,负样本数,1)
negative_prob = torch.bmm(negative_emb_out, center_emb_in)
# 压缩 negative_prob 维度 2 后的 shape 属性是 (批次大小,负样本数)
negative_prob = torch.squeeze(negative_prob, dim=2)
# 下述 logsigmoid 操作对应公式中的 log σ(-Vneg ⋅ VcT)
negative_prob = nn.functional.logsigmoid(negative_prob)
# 下述 sum 操作对应公式中 ∑ 累加操作
positive_prob = torch.sum(positive_prob, dim=1)
negative_prob = torch.sum(negative_prob, dim=1)
# 对应中心词与其上下文词的对数概率 log σ(Vo ⋅ VcT) + ∑ log σ(-Vo ⋅ VnegT)
prob = positive_prob + negative_prob
# 求最大化概率等同最小化损失误差,所以下述把概率取反操作后的结果当作损失误差
loss = torch.neg(prob)
loss = torch.mean(loss)
return loss
- 代码 torch.bmm(positive_emb_out, center_emb_in) 对应公式
- 代码 torch.bmm(torch.neg(negative_emb_out), center_emb_in) 对应公式
- 代码 torch.neg(prob) 对应公式
由于个人水平限制,难免会有错误。如果发现错误,希望指点出来,共同进步。