Tensor.scatter_(dim, index, src, reduce=None) → Tensor
Writes all values from the tensor src
into self
at the indices specified in the index
tensor. For each value in src
, its output index is specified by its index in src
for dimension != dim
and by the corresponding value in index
for dimension = dim
.
官网的说明言简意赅。读起来好像理解了,但是又好像没理解。对照着例子看,又是一头雾水。等我们理解函数的语义后,再来理解这段官网的说明。
下面这个官网语义说明,要比上面的文字说明更容易理解些:
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
我们先引用官网的一个二维张量例子
import torch
src = torch.arange(1, 11).reshape((2, 5))
# src 张量为
# tensor([[ 1, 2, 3, 4, 5],
# [ 6, 7, 8, 9, 10]])
index = torch.tensor([[0, 1, 2, 0]])
dst = torch.zeros(3, 5, dtype=src.dtype)
dst.scatter_(0, index, src)
print(dst)
# 结果张量为
# tensor([[1, 0, 0, 4, 0],
# [0, 2, 0, 0, 0],
# [0, 0, 3, 0, 0]])
可以通过循环遍历的方式把语义说明的更简单清晰,如下:
import torch
src = torch.arange(1, 11).reshape((2, 5))
index = torch.tensor([[0, 1, 2, 0]])
dst = torch.zeros(3, 5, dtype=src.dtype)
# index.shape 是 torch.Size([1, 4]), index.shape[0] 为 1,index.shape[1]为 4
# 遍历张量
for i in range(index.shape[0]):
for j in range(index.shape[1]):
idx = index[i][j]
print("%d" % idx.item())
# scatter_ 函数第一个参数 dim 指定维度的索引用 idx 代替,这里是 dim=0,如果是 dim=1,那么应该是dst[i][idx] = src[i][j]
dst[idx][j] = src[i][j]
print(dst)
# tensor([[1, 0, 0, 4, 0],
# [0, 2, 0, 0, 0],
# [0, 0, 3, 0, 0]])
看到这里,我们对scatter_ 的语义应该就完全理解了。我们来翻译一下官网的说明:
将所有 src
张量的所有值写到目标张量(self) 中,写入的位置由 index
张量指定。对于src
中的每个值,输出索引由两部分组成:
不等于 dim 的维度
,取src
维度索引;等于 dim 的维度
由 index 中对应的值指定。
理解了吗?还是觉得很拗口。