秒懂 torch.scatter_ 函数的意义(详解)

232 阅读2分钟

Tensor.scatter_(dim, index, src, reduce=None) → Tensor

Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

官网的说明言简意赅。读起来好像理解了,但是又好像没理解。对照着例子看,又是一头雾水。等我们理解函数的语义后,再来理解这段官网的说明。

下面这个官网语义说明,要比上面的文字说明更容易理解些:

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

我们先引用官网的一个二维张量例子

import torch

src = torch.arange(1, 11).reshape((2, 5))
# src 张量为
# tensor([[ 1,  2,  3,  4,  5],
#        [ 6,  7,  8,  9, 10]])
index = torch.tensor([[0, 1, 2, 0]])
dst = torch.zeros(3, 5, dtype=src.dtype)
dst.scatter_(0, index, src)
print(dst)
# 结果张量为
# tensor([[1, 0, 0, 4, 0],
#        [0, 2, 0, 0, 0],
#        [0, 0, 3, 0, 0]])

可以通过循环遍历的方式把语义说明的更简单清晰,如下:

import torch

src = torch.arange(1, 11).reshape((2, 5))
index = torch.tensor([[0, 1, 2, 0]])
dst = torch.zeros(3, 5, dtype=src.dtype)
# index.shape 是 torch.Size([1, 4]), index.shape[0] 为 1,index.shape[1]为 4
# 遍历张量
for i in range(index.shape[0]):
    for j in range(index.shape[1]):
        idx = index[i][j]
        print("%d" % idx.item())
        # scatter_ 函数第一个参数 dim 指定维度的索引用 idx 代替,这里是 dim=0,如果是 dim=1,那么应该是dst[i][idx] = src[i][j]
        dst[idx][j] = src[i][j]

print(dst)
# tensor([[1, 0, 0, 4, 0],
#        [0, 2, 0, 0, 0],
#        [0, 0, 3, 0, 0]])

看到这里,我们对scatter_ 的语义应该就完全理解了。我们来翻译一下官网的说明:

将所有 src 张量的所有值写到目标张量(self) 中,写入的位置由 index 张量指定。对于src中的每个值,输出索引由两部分组成:

  • 不等于 dim 的维度,取src维度索引;
  • 等于 dim 的维度由 index 中对应的值指定。

理解了吗?还是觉得很拗口。