Pytorch 手工复现交叉熵损失(Cross Entropy Loss)

724 阅读1分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

如果直接调包的话很简单,例子如下:

import torch
import torch.nn as nn

torch.manual_seed(1234)
ce_loss = nn.CrossEntropyLoss()
x_input = torch.randn(3,3)
y_target = torch.tensor([1,2,0])
print(ce_loss(x_input, y_target))

这里的x_input可以理解为我们的网络预测结果,而y_target为真值,此处例子的结果是0.8526。那么自己做的话,考虑交叉熵的公式如下:L=1NiLi=1Nic=1Myiclog(pic)L=\frac{1}{N} \sum_{i} L_{i}=-\frac{1}{N} \sum_{i} \sum_{c=1}^{M} y_{i c} \log \left(p_{i c}\right) 其中MM为总类别数,yicy_{ic}为符号函数(只能为0或1),picp_{ic}为预测得到的概率(即此处的网络预测结果x_input)。

首先为了方便"矩阵运算",我们将y给展平为列向量的形式:

y_target = y_target.view(-1, 1)

对于x,将其做softmax压缩至0-1范围内后再进行log运算:

x_input = F.log_softmax(x_input, 1)

接着,利用pytorch的gather函数,找到各标签y所对应的x:

x_input.gather(1, y_target)

这一步是相对最难理解的。回到公式,考虑到我们已经算完了log(pic)\log (p_{i c}),现在要做的也就是找到x_input中各行所对应的独热的值。例如,这里y_target为:

tensor([[1],
        [2],
        [0]])

意思就是对于x的第一行,取第1个值;对于x的第二行,取第2个值;对于x的第三行,取第0个值。因为x为:

tensor([[-1.0207, -0.6645, -2.0784],
        [-1.0184, -1.8474, -0.7315],
        [-1.1617, -0.6996, -1.6595]])

因此取到的结果为:

tensor([[-0.6645],
        [-0.7315],
        [-1.1617]])

这一过程恰好是可以通过torch中的gather方法解决的。

最后求均值得到最终结果:

res = -1 * res
print(res.mean())

可以发现结果也为0.8526,完整代码如下:

import torch
import torch.nn.functional as F

torch.manual_seed(1234)
x_input = torch.randn(3,3)
y_target = torch.tensor([1,2,0])
y_target = y_target.view(-1, 1)
x_input = F.log_softmax(x_input, 1)
res = x_input.gather(1, y_target)
res = -1 * res
print(res.mean())

参考: zhuanlan.zhihu.com/p/98785902 www.jianshu.com/p/0c159cdd9…