【NLP】transformer 中的mask 计算

178 阅读1分钟

背景:

image.png sequence mask是为了使得decoder不能看见未来的信息。也就是对于一个序列,在time_step为t的时刻,我们的解码输出应该只能依赖于t时刻之前的输出,而不能依赖t之后的输出。

做法:

用三角形mask 将矩阵上半三角全部置0,包括对角线。表示这些位置的信息看不见。

代码:

import torch
import torch.nn as nn
import numpy as np

def mask_(matrices, maskval=0.0, mask_diagonal=True):
    """
    Masks out all values in the given batch of matrices where i <= j holds,
    i < j if mask_diagonal is false
    In place operation
    :param tns:
    :return:
    """
    b, h, w = matrices.size()

    indices = torch.triu_indices(h,w, offset=0 if mask_diagonal else 1)
    matrices[:, indices[0], indices[1]] = maskval
    print("hi")

matrices = torch.randn(3, 4, 4)
print(matrices)
mask_(matrices, maskval=0.0, mask_diagonal=True)
print(matrices)

注意:

mask_ 函数先读入matrices矩阵,获得大小。用triu_indices 获得需要置为0的位置。这里的indices为:

image.png

最后原始矩阵为:

image.png

mask 后的矩阵为:

image.png