transformer中三种mask的使用

1,847 阅读1分钟

transformer中共使用三个mask:

  • src_mask: 用于encoder中当句子长度不一时,需要将所有的句子填充至相同的长度。因此在求Q和K的相关性时,由于Q和K在encoder中相等,所以src_mask最后表现为右边和下边填充1的矩阵;
  • target_mask: 用于使decoder中前面词无法使用到后面词的信息,并且也需要考虑padding。在训练decoder时,为了保证并行,transformer是一次性输入正确的target的,而并非auto-regressive那种时序的。因此在求Q和K的相关性时,由于Q与K相等,前面的query不应该和后面的key产生相关性,因此需要把当前query对于它之后的key的相关性置为0,因此就需要mask掉,最后产生一个类似上三角矩阵的mask。考虑到decoder长度也不一致,因此target_mask表现为padding mask与之前mask的一个并集(或运算)。
  • memory_mask:用于decoder中的cross attention中,主要是为了综合encoder和decoder中的padding。cross attention中的Q来自decoder,需要和encoder中的key-value sets求相关性矩阵,这里就不涉及未来信息的问题了,只需要考虑padding。因此最后所产生memory_mask下方有多少行1取决于decoder的padding,右方有多少列1取决于encoder的padding。

具体mask可视化参考:【深度学习】Transformer中的mask机制超详细讲解_Articoder的博客-CSDN博客