图解WeNet求Attention损失(逐行分析)

1,649 阅读4分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

先看今天要分析的代码:

def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Compute loss between x and target.

    The model outputs and data labels tensors are flatten to
    (batch*seqlen, class) shape and a mask is applied to the
    padding part which should not be calculated for loss.

    Args:
        x (torch.Tensor): prediction (batch, seqlen, class)
        target (torch.Tensor):
            target signal masked with self.padding_id (batch, seqlen)
    Returns:
        loss (torch.Tensor) : The KL loss, scalar float value
    """
    assert x.size(2) == self.size
    batch_size = x.size(0)
    x = x.view(-1, self.size)
    target = target.view(-1)
    # use zeros_like instead of torch.no_grad() for true_dist,
    # since no_grad() can not be exported by JIT
    true_dist = torch.zeros_like(x)
    true_dist.fill_(self.smoothing / (self.size - 1))
    ignore = target == self.padding_idx  # (B,)
    total = len(target) - ignore.sum().item()
    target = target.masked_fill(ignore, 0)  # avoid -1 index
    true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
    kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
    denom = total if self.normalize_length else batch_size
    return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom

这段代码乍一看很吓人,其实分析后你会觉得豁然开朗,其实WeNet的代码写的真的非常好,又干净又易懂(得仔细看才会体会到易懂,我初次毛毛略略的看确实感觉很难,坚持,干就完了!)

逐行分析

(以下所有变量维度都是假设,为了方便读者理解,实际运行中不一定是这个值,本文后面不再赘述)

首先看传入的参数,一个是x,一个是target,这个很好理解,就是要求编码器的输出和最终label的差异嘛!x是经过编码器的输出,维度假设是[16, 13, 4233],其中16是batch size,13是解码帧的序列长度,一段语音时间越长,这个值就会越大,因为每段语音长短不一,所以这个值不固定,13是在当前batch中最长的值。当前batch的长度假设是[11,13,12,9,5,2,3,1,4,11,12,11,12,7,6,7],对于不足13的部分进行padding补-1.下图是一段语音的示例图,当然x是16个这么大的二维向量,这里sequence length就是例子中的13,图中黄色二维向量的维度是[13, 4233]

x.jpg target是标签向量,维度是[16, 13],其中保存了对应词典的索引。

target.jpg

下面开始看详细的代码:

  1. assert x.size(2) == self.size

    这是为了确保x第二个维度等于词表的个数,也就是4233(对于Aishell-1数据集来说)

  2. batch_size = x.size(0)

    取出batch size的长度,这里默认是16

  3. x = x.view(-1, self.size)

    这个是将x的维度从[16, 13, 4233]变成 [16*13,4233], 目的当然是为了方便计算,可以理解为一次处理batch size * sequence length这么多帧数据

  4. target = target.view(-1)

    与上同理

  5. true_dist = torch.zeros_like(x)

    这里是要一个和x同维度的label向量,别忘了参数中x的维度是[16, 13, 4233],而target的维度是[16, 13],所以后续的步骤肯定是将target变成独热编码为了x同纬度(这一步只是全部设置为0),如下图:

true_dist.jpg 到这一步,true_dist全都是0,且和x同纬度

  1. true_dist.fill_(self.smoothing / (self.size - 1))

独热编码是将label对应的位置设置为1,其他所有位置都设置为0,如上图target部分所示,而WeNet这里使用了label smoothing,将0替换为了smoothingsize1\frac{smoothing}{size-1},其中smoothing是0.1,size-1=4232.简单说一下什么是label smoothing,因为WeNet注解写的太好,我就直接copy了:

Label-smoothing loss.

在标准交叉熵损失中, 标签的数据分布为:
[0,1,2] ->
[
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
]
而在smoothing版本的交叉熵损失中,一些概率取自真实标签probb(1.0),并在其他标签之间分配。
e.g.
smoothing=0.1
[0,1,2] ->
[
    [0.9, 0.05, 0.05],
    [0.05, 0.9, 0.05],
    [0.05, 0.05, 0.9],
]

这里填充的值是0.14232=2.36e5\frac{0.1}{4232}=2.36e-5这是一个非常小的值,所以下图用ϵ\epsilon表示2.36e52.36e-5

true_dist(1).jpg

  1. ignore = target == self.padding_idx # (B,)

    前面我们说了数据集中的句子长短不一,而我们取了当前batch中序列长度最长的那个值,其他短于这个值的帧用-1进行padding填充,所以这些-1的部分我们要忽略,不能加入求损失的过程。这里padding_idx是-1,判断target是否有值为-1的部分,如果有就忽略掉,将是否忽略保存到ignore变量中,ignore的维度这里是[16*13]

  2. total = len(target) - ignore.sum().item()

    这里是求刨开被占位的部分,真正需要计算损失的token总共有多少个,用target的总长度减去需要忽略的个数,这个例子就是146 = 16*13 - 62,也就是当前batch真正参与计算损失的字符也就146个,剩下62个都是被padding过的部分

  3. target = target.masked_fill(ignore, 0) # avoid -1 index

    这行是将忽略掉的地方(-1)全部用0来替换,因为要开始计算损失了,-1会影响结果

  4. true_dist.scatter_(1, target.unsqueeze(1), self.confidence)

    这一步是将true_dist中target对应的索引用置信度0.9替换,其中1 = confidence + smoothing

true_dist(2).jpg

  1. kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)

    这一步是计算kl损失,具体我写在了juejin.cn/post/709277…

  2. denom = total if self.normalize_length else batch_size

    这一步如果normalize_length为True则按照layer_norm的方式进行正则化,则denom为4233,如果normalize_length为False,则按照batch_norm的方式进行正则化,则denom为16.

  3. return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom

    208(16*13)个位置都计算了损失,将208个kl损失进行求和,然后除以denom,就得到了最终的损失