交叉熵损失和KL散度损失的关系

1,724 阅读3分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

写这篇博客的原因是看WeNet源码的Attention计算损失代码里面有CE loss和KL loss,所以对此进行总结,计算Attention loss的代码如下:

def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Compute loss between x and target.

    The model outputs and data labels tensors are flatten to
    (batch*seqlen, class) shape and a mask is applied to the
    padding part which should not be calculated for loss.

    Args:
        x (torch.Tensor): prediction (batch, seqlen, class)
        target (torch.Tensor):
            target signal masked with self.padding_id (batch, seqlen)
    Returns:
        loss (torch.Tensor) : The KL loss, scalar float value
    """
    assert x.size(2) == self.size
    batch_size = x.size(0)
    x = x.view(-1, self.size)
    target = target.view(-1)
    # use zeros_like instead of torch.no_grad() for true_dist,
    # since no_grad() can not be exported by JIT
    true_dist = torch.zeros_like(x)
    true_dist.fill_(self.smoothing / (self.size - 1))
    ignore = target == self.padding_idx  # (B,)
    total = len(target) - ignore.sum().item()
    target = target.masked_fill(ignore, 0)  # avoid -1 index
    true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
    kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
    denom = total if self.normalize_length else batch_size
    return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom

其实代码里调用了pytorch的KLDivLoss

因为Attention的输出是自回归的形式,所以和output天然可以用交叉熵损失,那么先来看看什么是交叉熵损失。

交叉熵损失 Cross Entropy Loss

这里我想说熵的含义是你能得到多少有用的信息量,它是一个很好的衡量指标,如果事情的不确定性越大,别人给你一个关于它的信息,熵的值就越大;如果这件事没什么不确定性,别人告诉你关于这个事的信息,信息量就几乎没有。

熵的公式定义如下:

Entropy,H(p)=p(i)log(p(i))Entropy, H(p)=-\sum p(i)*log(p(i))

交叉熵可以理解为真实的预测分布p和预测分布q之间的差距的函数,它的公式如下:

CrossEntropy,H(p,q)=p(i)log(q(i))Cross Entropy, H(p, q) = -\sum p(i)*log(q(i))

可以看出只有loglog里的部分变了,熵的部分是真实的分布,而交叉熵中是预测分布q,通过交叉熵可以计算出真实分布和预测分布的差距,如果预测的完全和真实分布一样,那么这里交叉熵等于普通的熵,而真实情况是几乎不可能完全预测的和真实分布完全一样,所以交叉熵会比熵大一些信息位。其中交叉熵超过熵的部分称为相对熵或者KL散度。公式如下:

(KL-Divergence是KL散度的意思,中间不是减号)

CrossEntopy=Entropy+KLDivergenceCrossEntopy=Entropy + KL-Divergence
DKL(pq)=H(p,q)H(p)=ipilog(qi)(ipilog(pi))D_{KL}(p||q)=H(p,q)-H(p)=-\sum_i p_ilog(q_i)-(-\sum_i p_ilog(p_i))
DKL(pq)=ipilog(qi)+ipilog(pi)=ipilogpiqiD_{KL}(p||q)=-\sum_ip_ilog(q_i)+\sum_ip_ilog(p_i)=\sum_i p_ilog\frac{p_i}{q_i}

WeNet中注意力部分的损失函数代码讲解

这里我们只分析KLDivLoss,至于WeNet的逐行分析我会另开一贴juejin.cn/post/709277…

KLDivLoss

对于包含N个样本的batch数据 D(x,y)D(x, y)xx是神经网络的输出,假如WeNet中是(16,13,4233 ),并且进行了归一化和对数化;yy是真实的标签(默认为概率),xxyy同维度。

nn个样本的损失值lnl_{n}计算如下:

ln=yn(logynxn)l_{n}=y_{n} \cdot\left(\log y_{n}-x_{n}\right)

class KLDivLoss(_Loss):
    __constants__ = ['reduction']
    def __init__(self, size_average=None, reduce=None, reduction='mean'):
        super(KLDivLoss, self).__init__(size_average, reduce, reduction)
    def forward(self, input, target):
        return F.kl_div(input, target, reduction=self.reduction)

pytorch中通过torch.nn.KLDivLoss类实现,也可以直接调用F.kl_div 函数,代码中的size_averagereduce已经弃用。reduction有四种取值mean,batchmean, sum, none,对应不同的返回(x,y)\ell(x, y)。 默认为mean

L={l1,,lN}L=\left\{l_{1}, \ldots, l_{N}\right\}

(x,y)={L, if reduction = ’none’ mean(L), if reduction = ’mean’ Nmean(L), if reduction = ’batchmean’ sum(L), if reduction = ’sum’ \ell(x, y)=\left\{\begin{array}{ll}L, & \text { if reduction }=\text { 'none' } \\ \operatorname{mean}(L), & \text { if reduction }=\text { 'mean' } \\ N*\operatorname {mean}(L), & \text { if reduction }=\text { 'batchmean' } \\ \operatorname{sum}(L), & \text { if reduction }=\text { 'sum' }\end{array} \right. 代码中 kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)会调用self.criterion = nn.KLDivLoss(reduction="none"),WeNet这里使用默认的,也就是每个损失都列出来。

总结

交叉熵等于熵加KL散度,交叉熵是计算真实的后验分布和预测分布之间的差距的一个函数,如果预测分布和真实分布越接近,交叉熵损失就会越小,而深度学习的目标就是降低这个损失。