本文已参与「新人创作礼」活动,一起开启掘金创作之路。
写这篇博客的原因是看WeNet源码的Attention计算损失代码里面有CE loss和KL loss,所以对此进行总结,计算Attention loss的代码如下:
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Compute loss between x and target.
The model outputs and data labels tensors are flatten to
(batch*seqlen, class) shape and a mask is applied to the
padding part which should not be calculated for loss.
Args:
x (torch.Tensor): prediction (batch, seqlen, class)
target (torch.Tensor):
target signal masked with self.padding_id (batch, seqlen)
Returns:
loss (torch.Tensor) : The KL loss, scalar float value
"""
assert x.size(2) == self.size
batch_size = x.size(0)
x = x.view(-1, self.size)
target = target.view(-1)
# use zeros_like instead of torch.no_grad() for true_dist,
# since no_grad() can not be exported by JIT
true_dist = torch.zeros_like(x)
true_dist.fill_(self.smoothing / (self.size - 1))
ignore = target == self.padding_idx # (B,)
total = len(target) - ignore.sum().item()
target = target.masked_fill(ignore, 0) # avoid -1 index
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
denom = total if self.normalize_length else batch_size
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
其实代码里调用了pytorch的KLDivLoss
因为Attention的输出是自回归的形式,所以和output天然可以用交叉熵损失,那么先来看看什么是交叉熵损失。
交叉熵损失 Cross Entropy Loss
这里我想说熵的含义是你能得到多少有用的信息量,它是一个很好的衡量指标,如果事情的不确定性越大,别人给你一个关于它的信息,熵的值就越大;如果这件事没什么不确定性,别人告诉你关于这个事的信息,信息量就几乎没有。
熵的公式定义如下:
交叉熵可以理解为真实的预测分布p和预测分布q之间的差距的函数,它的公式如下:
可以看出只有里的部分变了,熵的部分是真实的分布,而交叉熵中是预测分布q,通过交叉熵可以计算出真实分布和预测分布的差距,如果预测的完全和真实分布一样,那么这里交叉熵等于普通的熵,而真实情况是几乎不可能完全预测的和真实分布完全一样,所以交叉熵会比熵大一些信息位。其中交叉熵超过熵的部分称为相对熵或者KL散度。公式如下:
(KL-Divergence是KL散度的意思,中间不是减号)
WeNet中注意力部分的损失函数代码讲解
这里我们只分析KLDivLoss,至于WeNet的逐行分析我会另开一贴juejin.cn/post/709277…
KLDivLoss
对于包含N个样本的batch数据 ,是神经网络的输出,假如WeNet中是(16,13,4233 ),并且进行了归一化和对数化;是真实的标签(默认为概率),与同维度。
第个样本的损失值计算如下:
class KLDivLoss(_Loss):
__constants__ = ['reduction']
def __init__(self, size_average=None, reduce=None, reduction='mean'):
super(KLDivLoss, self).__init__(size_average, reduce, reduction)
def forward(self, input, target):
return F.kl_div(input, target, reduction=self.reduction)
pytorch中通过torch.nn.KLDivLoss
类实现,也可以直接调用F.kl_div
函数,代码中的size_average
与reduce
已经弃用。reduction有四种取值mean
,batchmean
, sum
, none
,对应不同的返回。 默认为mean
代码中 kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
会调用self.criterion = nn.KLDivLoss(reduction="none")
,WeNet这里使用默认的,也就是每个损失都列出来。
总结
交叉熵等于熵加KL散度,交叉熵是计算真实的后验分布和预测分布之间的差距的一个函数,如果预测分布和真实分布越接近,交叉熵损失就会越小,而深度学习的目标就是降低这个损失。