本文已参与「新人创作礼」活动,一起开启掘金创作之路。
如果直接调包的话很简单,例子如下:
import torch
import torch.nn as nn
torch.manual_seed(1234)
ce_loss = nn.CrossEntropyLoss()
x_input = torch.randn(3,3)
y_target = torch.tensor([1,2,0])
print(ce_loss(x_input, y_target))
这里的x_input可以理解为我们的网络预测结果,而y_target为真值,此处例子的结果是0.8526。那么自己做的话,考虑交叉熵的公式如下: 其中为总类别数,为符号函数(只能为0或1),为预测得到的概率(即此处的网络预测结果x_input)。
首先为了方便"矩阵运算",我们将y给展平为列向量的形式:
y_target = y_target.view(-1, 1)
对于x,将其做softmax压缩至0-1范围内后再进行log运算:
x_input = F.log_softmax(x_input, 1)
接着,利用pytorch的gather函数,找到各标签y所对应的x:
x_input.gather(1, y_target)
这一步是相对最难理解的。回到公式,考虑到我们已经算完了,现在要做的也就是找到x_input中各行所对应的独热的值。例如,这里y_target为:
tensor([[1],
[2],
[0]])
意思就是对于x的第一行,取第1个值;对于x的第二行,取第2个值;对于x的第三行,取第0个值。因为x为:
tensor([[-1.0207, -0.6645, -2.0784],
[-1.0184, -1.8474, -0.7315],
[-1.1617, -0.6996, -1.6595]])
因此取到的结果为:
tensor([[-0.6645],
[-0.7315],
[-1.1617]])
这一过程恰好是可以通过torch中的gather方法解决的。
最后求均值得到最终结果:
res = -1 * res
print(res.mean())
可以发现结果也为0.8526,完整代码如下:
import torch
import torch.nn.functional as F
torch.manual_seed(1234)
x_input = torch.randn(3,3)
y_target = torch.tensor([1,2,0])
y_target = y_target.view(-1, 1)
x_input = F.log_softmax(x_input, 1)
res = x_input.gather(1, y_target)
res = -1 * res
print(res.mean())
参考: zhuanlan.zhihu.com/p/98785902 www.jianshu.com/p/0c159cdd9…