Pytorch中交叉熵损失函数分析

467 阅读2分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

引言

本文旨在对pytorch中常用于分类问题的损失函数BinaryCrossEntropy(), CrossEntropy()用法进行一个简要的介绍。常见的文章主要是对这些损失函数的原理进行了数学推导,而本文主要介绍了其输入输出的shape和格式要求,作为一个工具存在。

损失函数

 本文涉及到的损失函数有BCELoss()、BCEWithLogitsLoss()、NLLLOSS()、CrossEntropyLoss(),前两者是二分类问题常用的损失函数,后两者是多分类问题常用的损失函数。列出格式表如下:

输入格式label的dtype是否为独热向量网络输出是否需要激活
BCELoss()(pred:[*],label:[*],二者相同即可)torch.float32
BCEWithLogitsLoss()(pred:[*],label:[*],二者相同即可)torch.float32
NLLLOSS()(pred:[N,C],label:[N,])torch.int64
CrossEntropyLoss()(pred:[N,C],label:[N,])或 (pred:[N,C],label:[N,C])torch.int64/torch.float32,torch.float64

 有两点需要额外解释一下:

  1. BCELoss可接受任意形状的输入,不止局限于我们所常用的(N,)或者(N,1),也就是并不只局限与认为batch维度上每个元素只存在单一样本,而是可以为(N,M)这种。
  2. 其中CrossEntropyLoss之所以会有label为[N,C]形状却不并不为onehot向量,这是因为这里的label描述的是一个样本属于多个类别的情况,可以认为是属于每一种类别的可能性,也可以认为是软化的onehot向量。

计算方式

(默认在batch上采用平均):

BCELoss()

loss=1NiN[yilog(pi)+(1yi)log(1pi)]loss=-\frac{1}{N}\sum_i^{N}[y_i\cdot log(p_i)+ (1-y_i)\cdot log(1-p_i)],其中yiy_i为实际标签,pip_i为网路预测其属于正样本的输出值(非概率)。

BCEWithLogitsLoss()

loss=1NiN[yilogσ(pi)+(1yi)log(1σ(pi))]loss=-\frac{1}{N}\sum_i^{N}[y_i\cdot log \sigma(p_i)+(1-y_i)\cdot log (1-\sigma (p_i))],其中yiy_i为实际标签,pip_i为网路预测其属于正样本的概率。

NLL loss

loss=1NiNjCI(yi=c)picloss=-\frac{1}{N}\sum_i^{N} \sum_j^CI (y_{i}=c)p_{ic},其中II为指示函数,当第ii个样本的标签yiy_{i}与当前类别c相同时取1,否则取0;picp_{ic}为网络输出的第ii个样本属于第cc类的概率。

CrossEntropyLoss()

loss=1NiNjclogexp(pc)jcexp(pc)I(yi=c)loss=-\frac{1}{N}\sum_i^{N}\sum_j^clog\frac{exp(p_c)}{\sum_j^cexp(p_c)}I (y_{i}=c),其中II为指示函数,当第ii个样本的标签yiy_{i}与当前类别c相同时取1,否则取0;picp_{ic}为网络输出的第ii个样本属于第cc类的输出值(非概率)。

代码验证

# BCELoss错误dytpe:
label=torch.randint(0,2,(32,16))
pred_prob=torch.rand((32,16))
loss=nn.functional.binary_cross_entropy(pred_prob,label)
#BCELoss正确dtype:
label=label.type(torch.float32)
loss=nn.functional.binary_cross_entropy(pred_prob,label)

#BCEWithLogitLloss错误dtype:
label=torch.randint(0,2,(32,16))
pred_prob=torch.rand((32,16))
loss=nn.functional.binary_cross_entropy_with_logits(pred_prob,label)
#BCEWithLogitLloss正确dtype:
label=label.type(torch.float32)
loss=nn.functional.binary_cross_entropy_with_logits(pred_prob,label)

# 验证BCEloss和BCEWithLogitLloss关系:加上激活函数Sigmoid
loss2=nn.functional.binary_cross_entropy(torch.sigmoid(pred_prob),label)
loss==loss2

# NLL loss正确dytpe:
pred_prob=torch.rand((32,5))
label=torch.randint(0,5,(32,))
loss=nn.functional.nll_loss(pred_prob,label)
# NLL loss错误dytpe:
label=label.type(torch.float32)
loss=nn.functional.nll_loss(pred_prob,label)
label=label.type(torch.int32)
loss=nn.functional.nll_loss(pred_prob,label)

# CrossEntropyLoss正确dytpe:
pred_prob=torch.rand((32,5))
label=torch.randint(0,5,(32,))
loss=nn.functional.cross_entropy(pred_prob,label)
# CrossEntropyLoss错误dytpe:
label=label.type(torch.float32)
loss=nn.functional.cross_entropy(pred_prob,label)
label=label.type(torch.int32)
loss=nn.functional.cross_entropy(pred_prob,label)

# 验证CrossEntropyLoss和NLL loss关系:加上激活函数softmax后取对数
loss2=nn.functional.nll_loss(torch.log(torch.softmax(pred_prob,dim=-1)),label)
loss=loss2