1. 求导
1.1 定义基本公式
首先,定义 Softmax 函数 pi 和 交叉熵损失 L。对于输入向量 z:
pi=∑jexp(zj)exp(zi)
L=−k∑yklog(pk)
其中:
- yk 是真实标签(常数,比如 0 或 1)。
- pk 是模型预测的概率(变量)。
- ∑k 表示对所有类别求和。
其中 yk 是真实标签(在单分类中,仅有一个 yk=1,其余为 0)。
1.2 应用链式法则
我们需要计算损失 L 对每个输入 zi 的偏导数。根据链式法则:
∂zi∂L=k∑∂pk∂L∂zi∂pk
我们要计算 L 对某一个特定预测值 pk 的偏导数 ∂pk∂L。
根据求导的线性性质(和的导数等于导数的和),我们只需要对求和符号中包含 pk 的那一项求导即可,其他项相对于 pk 都是常数,导数为 0。
∂pk∂L=∂pk∂(−ykln(pk))
利用对数求导公式 dxdln(x)=x1:
∂pk∂L=−yk⋅pk1=−pkyk
1.3 计算 Softmax 的偏导数
计算 ∂zi∂pk 时需要分两种情况: 当 k=i 时(即 Softmax 输出与分母中的 Logit 下标相同):∂zi∂pi=pi(1−pi)当 k=i 时(下标不同):∂zi∂pk=−pkpi
情况1:i=k
已知:
pi=Sezi其中 S=∑ezj
两边直接取自然对数 ln:
lnpi=ln(ezi)−lnS=zi−lnS
对等式两边同时关于 zi 求导。
-
左边:根据链式法则 dxdln(u)=u1⋅u′,得到:pi1⋅∂zi∂pi
-
右边:
◦ zi 的导数是 1。
◦ lnS 的导数是 S1⋅∂zi∂S。因为 S=ez1+ez2+…,所以 ∂zi∂S=ezi。
◦ 故右边为:
1−Sezi
由于 Sezi 正好就是 pi,所以等式变为:
pi1⋅∂zi∂pi=1−pi
把左边的 pi 乘到右边:
∂zi∂pi=pi(1−pi)
情况2:k=i
我们看 pk 的定义,此时 k=i:
pk=∑jezjezk
对 zi 求导时,pk 的分子 ezk 是一个常数(因为 k=i),变量 zi 只出现在分母里。
根据倒数求导规则 dxd(u1)=−u21⋅u′:
- 分母的导数是 ezi。
- 整个式子的导数就是:
∂zi∂pk=ezk⋅(−(∑ezj)21)⋅ezi
我们将上面的结果拆开,凑成我们熟悉的 p:
∂zi∂pk=−(∑ezjezk)⋅(∑ezjezi)
左边那一坨就是 pk,右边那一坨就是 pi。
1.4 合并并化简
将上述结果代入链式法则公式:
∂zi∂L=(−piyi)pi(1−pi)+k=i∑(−pkyk)(−pkpi)
化简得:
∂zi∂L=−yi+yipi+k=i∑ykpi=−yi+pik∑yk
由于真实标签 ∑yk=1(所有类别的概率和为 1),最终公式简化为:
∂zi∂L=pi−yi
这个简洁的结果意味着,梯度的大小仅仅是预测概率与真实目标之间的差距。
2. logSoftMax
LogSoftmax 就是对 Softmax 的输出结果直接取自然对数(ln)。 数学公式如下:
LogSoftmax(zi)=ln(∑jexp(zj)exp(zi))=zi−ln(j∑exp(zj))
在计算 ln∑exp(zj) 时,即使 zj 非常大,也不会发生数值溢出(Overflow)
为了防止 ezj 爆炸,我们先找出输入向量中的最大值 M=max(z),然后进行如下恒等变换:
lnj∑ezj=ln(j∑ezj−M+M)=ln(eMj∑ezj−M)
利用对数性质 ln(a⋅b)=lna+lnb:
lnj∑ezj=M+ln(j∑ezj−M)
LogSoftmax 的稳健实现公式为:
LogSoftmax(zi)=zi−(M+lnj∑ezj−M)
假设 L 是 LogSoftmax 的输出结果向量,z 是输入的 Logits:
Li=LogSoftmax(zi)=zi−lnj∑ezj
对 zi 求偏导的结果如下:
• 当 i=j 时: ∂zi∂Li=1−Pi
• 当 i=j 时: ∂zj∂Li=−Pj
其中 P 就是对应的 Softmax 概率值。
3. DSSM损失函数例子
们假设一个简单的搜索场景,并用数字模拟一遍 DSSM 计算损失函数的过程。
场景设定
• 查询 (Query): "如何做红烧肉"
• 正样本 (D+): "经典红烧肉做法大全"(用户点击过的)
• 负样本 (D1−,D2−): "清蒸鱼怎么做"、"手机维修教程"(随机抽取的)
• 平滑因子 (γ): 设定为 5(这是论文中常用的超参数)
3.1 模型输出向量并计算相似度
假设经过深度神经网络后,Query 和三个 Document 都变成了低维向量,我们计算它们之间的余弦相似度(范围在 −1 到 1 之间):
| 文档类型 | 文档标题 | 余弦相似度 cos(yQ,yD) |
|---|
| 正样本 (D+) | 经典红烧肉做法大全 | 0.8 (语义非常接近) |
| 负样本 (D1−) | 清蒸鱼怎么做 | 0.3 (同属烹饪,但不同菜) |
| 负样本 (D2−) | 手机维修教程 | -0.1 (完全无关) |
3.2乘以平滑因子 γ 并取指数 (Exp)
这一步是为了放大差异,也就是 Softmax 的分子部分:
• Exp(D+)=exp(5×0.8)=exp(4.0)≈54.60
• Exp(D1−)=exp(5×0.3)=exp(1.5)≈4.48
• Exp(D2−)=exp(5×−0.1)=exp(−0.5)≈0.61
3.3 计算 Softmax 概率
现在我们把这三个值归一化,算出模型认为每个文档是“正确答案”的概率。
• 分母 (Sum): 54.60+4.48+0.61=59.69
• 正样本概率 P(D+∣Q): 54.60/59.69≈0.915
• 负样本概率 P(D1−∣Q): 4.48/59.69≈0.075
• 负样本概率 P(D2−∣Q): 0.61/59.69≈0.010
3.4 计算损失 (Loss)
DSSM 的目标是让正样本的概率 P(D+∣Q) 越接近 1 越好。我们使用负对数似然:
L=−log(P(D+∣Q))=−log(0.915)≈0.089
• 如果预测很准(如上例):P(D+∣Q) 很大,Loss 就会很小(接近 0)。
• 如果预测很差:比如模型觉得“手机维修”更匹配,给正样本的概率只有 0.1,那么 L=−log(0.1)≈2.3,Loss 就会变大,从而通过反向传播惩罚模型。