本文已参与「新人创作礼」活动,一起开启掘金创作之路。
引言
本文旨在对pytorch中常用于分类问题的损失函数BinaryCrossEntropy(), CrossEntropy()用法进行一个简要的介绍。常见的文章主要是对这些损失函数的原理进行了数学推导,而本文主要介绍了其输入输出的shape和格式要求,作为一个工具存在。
损失函数
本文涉及到的损失函数有BCELoss()、BCEWithLogitsLoss()、NLLLOSS()、CrossEntropyLoss(),前两者是二分类问题常用的损失函数,后两者是多分类问题常用的损失函数。列出格式表如下:
| 输入格式 | label的dtype | 是否为独热向量 | 网络输出是否需要激活 | |
|---|---|---|---|---|
| BCELoss() | (pred:[*],label:[*],二者相同即可) | torch.float32 | 否 | 是 |
| BCEWithLogitsLoss() | (pred:[*],label:[*],二者相同即可) | torch.float32 | 否 | 否 |
| NLLLOSS() | (pred:[N,C],label:[N,]) | torch.int64 | 否 | 是 |
| CrossEntropyLoss() | (pred:[N,C],label:[N,])或 (pred:[N,C],label:[N,C]) | torch.int64/torch.float32,torch.float64 | 否 | 否 |
有两点需要额外解释一下:
- BCELoss可接受任意形状的输入,不止局限于我们所常用的(N,)或者(N,1),也就是并不只局限与认为batch维度上每个元素只存在单一样本,而是可以为(N,M)这种。
- 其中CrossEntropyLoss之所以会有label为[N,C]形状却不并不为onehot向量,这是因为这里的label描述的是一个样本属于多个类别的情况,可以认为是属于每一种类别的可能性,也可以认为是软化的onehot向量。
计算方式
(默认在batch上采用平均):
BCELoss()
,其中为实际标签,为网路预测其属于正样本的输出值(非概率)。
BCEWithLogitsLoss()
,其中为实际标签,为网路预测其属于正样本的概率。
NLL loss
,其中为指示函数,当第个样本的标签与当前类别c相同时取1,否则取0;为网络输出的第个样本属于第类的概率。
CrossEntropyLoss()
,其中为指示函数,当第个样本的标签与当前类别c相同时取1,否则取0;为网络输出的第个样本属于第类的输出值(非概率)。
代码验证
# BCELoss错误dytpe:
label=torch.randint(0,2,(32,16))
pred_prob=torch.rand((32,16))
loss=nn.functional.binary_cross_entropy(pred_prob,label)
#BCELoss正确dtype:
label=label.type(torch.float32)
loss=nn.functional.binary_cross_entropy(pred_prob,label)
#BCEWithLogitLloss错误dtype:
label=torch.randint(0,2,(32,16))
pred_prob=torch.rand((32,16))
loss=nn.functional.binary_cross_entropy_with_logits(pred_prob,label)
#BCEWithLogitLloss正确dtype:
label=label.type(torch.float32)
loss=nn.functional.binary_cross_entropy_with_logits(pred_prob,label)
# 验证BCEloss和BCEWithLogitLloss关系:加上激活函数Sigmoid
loss2=nn.functional.binary_cross_entropy(torch.sigmoid(pred_prob),label)
loss==loss2
# NLL loss正确dytpe:
pred_prob=torch.rand((32,5))
label=torch.randint(0,5,(32,))
loss=nn.functional.nll_loss(pred_prob,label)
# NLL loss错误dytpe:
label=label.type(torch.float32)
loss=nn.functional.nll_loss(pred_prob,label)
label=label.type(torch.int32)
loss=nn.functional.nll_loss(pred_prob,label)
# CrossEntropyLoss正确dytpe:
pred_prob=torch.rand((32,5))
label=torch.randint(0,5,(32,))
loss=nn.functional.cross_entropy(pred_prob,label)
# CrossEntropyLoss错误dytpe:
label=label.type(torch.float32)
loss=nn.functional.cross_entropy(pred_prob,label)
label=label.type(torch.int32)
loss=nn.functional.cross_entropy(pred_prob,label)
# 验证CrossEntropyLoss和NLL loss关系:加上激活函数softmax后取对数
loss2=nn.functional.nll_loss(torch.log(torch.softmax(pred_prob,dim=-1)),label)
loss=loss2