背景:
sequence mask是为了使得decoder不能看见未来的信息。也就是对于一个序列,在time_step为t的时刻,我们的解码输出应该只能依赖于t时刻之前的输出,而不能依赖t之后的输出。
做法:
用三角形mask 将矩阵上半三角全部置0,包括对角线。表示这些位置的信息看不见。
代码:
import torch
import torch.nn as nn
import numpy as np
def mask_(matrices, maskval=0.0, mask_diagonal=True):
"""
Masks out all values in the given batch of matrices where i <= j holds,
i < j if mask_diagonal is false
In place operation
:param tns:
:return:
"""
b, h, w = matrices.size()
indices = torch.triu_indices(h,w, offset=0 if mask_diagonal else 1)
matrices[:, indices[0], indices[1]] = maskval
print("hi")
matrices = torch.randn(3, 4, 4)
print(matrices)
mask_(matrices, maskval=0.0, mask_diagonal=True)
print(matrices)
注意:
mask_ 函数先读入matrices矩阵,获得大小。用triu_indices 获得需要置为0的位置。这里的indices为:
最后原始矩阵为:
mask 后的矩阵为: