多分类的交叉熵损失

4 阅读1分钟

1. 求导

1.1 定义基本公式 

首先,定义 Softmax 函数 pip_{i} 和 交叉熵损失 LL。对于输入向量 z\mathbf{z}

pi=exp(zi)jexp(zj)p_{i}=\frac{\exp (z_{i})}{\sum _{j}\exp (z_{j})}
L=kyklog(pk)L=-\sum _{k}y_{k}\log (p_{k})

其中: 

  • yky_{k} 是真实标签(常数,比如 0011)。
  • pkp_{k} 是模型预测的概率(变量)。
  • k\sum_{k} 表示对所有类别求和。

其中 yky_{k} 是真实标签(在单分类中,仅有一个 yk=1y_{k}=1,其余为 0)。

1.2 应用链式法则 

我们需要计算损失 LL 对每个输入 ziz_{i} 的偏导数。根据链式法则:

Lzi=kLpkpkzi\frac{\partial L}{\partial z_{i}}=\sum _{k}\frac{\partial L}{\partial p_{k}}\frac{\partial p_{k}}{\partial z_{i}}

我们要计算 LL 对某一个特定预测值 pkp_{k} 的偏导数 Lpk\frac{\partial L}{\partial p_{k}}。  根据求导的线性性质(和的导数等于导数的和),我们只需要对求和符号中包含 pkp_{k} 的那一项求导即可,其他项相对于 pkp_{k} 都是常数,导数为 00。 

Lpk=pk(ykln(pk))\frac{\partial L}{\partial p_{k}}=\frac{\partial }{\partial p_{k}}\left(-y_{k}\ln (p_{k})\right)

利用对数求导公式 ddxln(x)=1x\frac{d}{dx}\ln (x)=\frac{1}{x}: 

Lpk=yk1pk=ykpk\frac{\partial L}{\partial p_{k}}=-y_{k}\cdot \frac{1}{p_{k}}=-\frac{y_{k}}{p_{k}}

1.3 计算 Softmax 的偏导数

计算 pkzi\frac{\partial p_{k}}{\partial z_{i}} 时需要分两种情况: 当 k=ik=i 时(即 Softmax 输出与分母中的 Logit 下标相同):pizi=pi(1pi)\frac{\partial p_{i}}{\partial z_{i}}=p_{i}(1-p_{i})kik\ne i 时(下标不同):pkzi=pkpi\frac{\partial p_{k}}{\partial z_{i}}=-p_{k}p_{i}

情况1:i=ki = k

已知: pi=eziS其中 S=ezjp_{i}=\frac{e^{z_{i}}}{S}\quad \text{其中\ }S=\sum e^{z_{j}} 两边直接取自然对数 ln\ln lnpi=ln(ezi)lnS=zilnS\ln p_{i}=\ln (e^{z_{i}})-\ln S=z_{i}-\ln S

对等式两边同时关于 ziz_{i} 求导。 

  • 左边:根据链式法则 ddxln(u)=1uu\frac{d}{dx}\ln (u)=\frac{1}{u}\cdot u^{\prime },得到:1pipizi\frac{1}{p_{i}}\cdot \frac{\partial p_{i}}{\partial z_{i}}

  • 右边:

    ziz_{i} 的导数是 11

    lnS\ln S 的导数是 1SSzi\frac{1}{S}\cdot \frac{\partial S}{\partial z_{i}}。因为 S=ez1+ez2+S=e^{z_{1}}+e^{z_{2}}+\dots ,所以 Szi=ezi\frac{\partial S}{\partial z_{i}}=e^{z_{i}}

    ◦ 故右边为:

    1eziS1-\frac{e^{z_{i}}}{S}

由于 eziS\frac{e^{z_{i}}}{S} 正好就是 pip_{i},所以等式变为:

1pipizi=1pi\frac{1}{p_{i}}\cdot \frac{\partial p_{i}}{\partial z_{i}}=1-p_{i}

把左边的 pip_{i} 乘到右边:

pizi=pi(1pi)\frac{\partial p_{i}}{\partial z_{i}}=p_{i}(1-p_{i})

情况2:kik \ne i

我们看 pkp_{k} 的定义,此时 kik\ne i

pk=ezkjezjp_{k}=\frac{e^{z_{k}}}{\sum _{j}e^{z_{j}}}

ziz_{i} 求导时,pkp_{k} 的分子 ezke^{z_{k}} 是一个常数(因为 kik\ne i),变量 ziz_{i} 只出现在分母里。  根据倒数求导规则 ddx(1u)=1u2u\frac{d}{dx}(\frac{1}{u})=-\frac{1}{u^{2}}\cdot u^{\prime }: 

  1. 分母的导数是 ezie^{z_{i}}
  2. 整个式子的导数就是:
pkzi=ezk(1(ezj)2)ezi\frac{\partial p_{k}}{\partial z_{i}}=e^{z_{k}}\cdot \left(-\frac{1}{(\sum e^{z_{j}})^{2}}\right)\cdot e^{z_{i}}

我们将上面的结果拆开,凑成我们熟悉的 pp

pkzi=(ezkezj)(eziezj)\frac{\partial p_{k}}{\partial z_{i}}=-\left(\frac{e^{z_{k}}}{\sum e^{z_{j}}}\right)\cdot \left(\frac{e^{z_{i}}}{\sum e^{z_{j}}}\right)

左边那一坨就是 pkp_{k},右边那一坨就是 pip_{i}

1.4 合并并化简

将上述结果代入链式法则公式:

Lzi=(yipi)pi(1pi)+ki(ykpk)(pkpi)\frac{\partial L}{\partial z_{i}}=\left(-\frac{y_{i}}{p_{i}}\right)p_{i}(1-p_{i})+\sum _{k\ne i}\left(-\frac{y_{k}}{p_{k}}\right)(-p_{k}p_{i})

化简得:

Lzi=yi+yipi+kiykpi=yi+pikyk\frac{\partial L}{\partial z_{i}}=-y_{i}+y_{i}p_{i}+\sum _{k\ne i}y_{k}p_{i}=-y_{i}+p_{i}\sum _{k}y_{k}

由于真实标签 yk=1\sum y_{k}=1(所有类别的概率和为 1),最终公式简化为:

Lzi=piyi\frac{\partial L}{\partial z_{i}}=p_{i}-y_{i}

这个简洁的结果意味着,梯度的大小仅仅是预测概率与真实目标之间的差距。

2. logSoftMax

LogSoftmax 就是对 Softmax 的输出结果直接取自然对数(ln\ln )。 数学公式如下:

LogSoftmax(zi)=ln(exp(zi)jexp(zj))=ziln(jexp(zj))\text{LogSoftmax}(z_{i})=\ln \left(\frac{\exp (z_{i})}{\sum _{j}\exp (z_{j})}\right)=z_{i}-\ln \left(\sum _{j}\exp (z_{j})\right)

在计算 lnexp(zj)\ln \sum \exp (z_{j}) 时,即使 zjz_{j} 非常大,也不会发生数值溢出(Overflow)

为了防止 ezje^{z_{j}} 爆炸,我们先找出输入向量中的最大值 M=max(z)M=\max (\mathbf{z}),然后进行如下恒等变换: 

lnjezj=ln(jezjM+M)=ln(eMjezjM)\ln \sum _{j}e^{z_{j}}=\ln \left(\sum _{j}e^{z_{j}-M+M}\right)=\ln \left(e^{M}\sum _{j}e^{z_{j}-M}\right)

利用对数性质 ln(ab)=lna+lnb\ln (a\cdot b)=\ln a+\ln b

lnjezj=M+ln(jezjM)\ln \sum _{j}e^{z_{j}}=M+\ln \left(\sum _{j}e^{z_{j}-M}\right)

LogSoftmax 的稳健实现公式为:

LogSoftmax(zi)=zi(M+lnjezjM)\text{LogSoftmax}(z_{i})=z_{i}-\left(M+\ln \sum _{j}e^{z_{j}-M}\right)

假设 LL 是 LogSoftmax 的输出结果向量,zz 是输入的 Logits:

Li=LogSoftmax(zi)=zilnjezjL_{i}=\text{LogSoftmax}(z_{i})=z_{i}-\ln \sum _{j}e^{z_{j}}

ziz_{i} 求偏导的结果如下:  • 当 i=ji=j 时: Lizi=1Pi\frac{\partial L_{i}}{\partial z_{i}}=1-P_{i}

• 当 iji\ne j 时: Lizj=Pj\frac{\partial L_{i}}{\partial z_{j}}=-P_{j} 

其中 PP 就是对应的 Softmax 概率值。

3. DSSM损失函数例子

们假设一个简单的搜索场景,并用数字模拟一遍 DSSM 计算损失函数的过程。  场景设定  • 查询 (Query): "如何做红烧肉"

• 正样本 (D+D^{+}): "经典红烧肉做法大全"(用户点击过的)

• 负样本 (D1,D2D_{1}^{-},D_{2}^{-}): "清蒸鱼怎么做"、"手机维修教程"(随机抽取的)

• 平滑因子 (γ\gamma ): 设定为 55(这是论文中常用的超参数)

3.1 模型输出向量并计算相似度 

假设经过深度神经网络后,Query 和三个 Document 都变成了低维向量,我们计算它们之间的余弦相似度(范围在 1-111 之间): 

文档类型 文档标题余弦相似度 cos(yQ,yD)\cos (y_{Q},y_{D})
正样本 (D+D^{+})经典红烧肉做法大全0.8 (语义非常接近)
负样本 (D1D_{1}^{-})清蒸鱼怎么做0.3 (同属烹饪,但不同菜)
负样本 (D2D_{2}^{-})手机维修教程-0.1 (完全无关)

3.2乘以平滑因子 γ\gamma 并取指数 (Exp) 

这一步是为了放大差异,也就是 Softmax 的分子部分:  • Exp(D+)=exp(5×0.8)=exp(4.0)54.60Exp(D^{+})=\exp (5\times 0.8)=\exp (4.0)\approx \mathbf{54.60}

Exp(D1)=exp(5×0.3)=exp(1.5)4.48Exp(D_{1}^{-})=\exp (5\times 0.3)=\exp (1.5)\approx \mathbf{4.48}

Exp(D2)=exp(5×0.1)=exp(0.5)0.61Exp(D_{2}^{-})=\exp (5\times -0.1)=\exp (-0.5)\approx \mathbf{0.61}

3.3 计算 Softmax 概率

现在我们把这三个值归一化,算出模型认为每个文档是“正确答案”的概率。  • 分母 (Sum): 54.60+4.48+0.61=59.6954.60+4.48+0.61=59.69

• 正样本概率 P(D+Q)P(D^{+}|Q): 54.60/59.690.91554.60/59.69\approx \mathbf{0.915}

• 负样本概率 P(D1Q)P(D_{1}^{-}|Q): 4.48/59.690.0754.48/59.69\approx 0.075

• 负样本概率 P(D2Q)P(D_{2}^{-}|Q): 0.61/59.690.0100.61/59.69\approx 0.010 

3.4 计算损失 (Loss) 

DSSM 的目标是让正样本的概率 P(D+Q)P(D^{+}|Q) 越接近 11 越好。我们使用负对数似然: 

L=log(P(D+Q))=log(0.915)0.089L=-\log (P(D^{+}|Q))=-\log (0.915)\approx \mathbf{0.089}

• 如果预测很准(如上例):P(D+Q)P(D^{+}|Q) 很大,Loss 就会很小(接近 0)。

• 如果预测很差:比如模型觉得“手机维修”更匹配,给正样本的概率只有 0.10.1,那么 L=log(0.1)2.3L=-\log (0.1)\approx 2.3,Loss 就会变大,从而通过反向传播惩罚模型。