Pytorch之torch_scatter.scatter,torch.gather等函数的学习

1,191 阅读1分钟

小白学习GNN的过程中遇到的一些函数

torch_scatter官方文档:github.com/rusty1s/pyt…

1. torch_scatter.scatter

scatter方法——通过src和index两个张量来获得一个新的张量。

torch_scatter.scatter(src: torch.Tensor, index: torch.Tensor, dim: int = - 1, out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None, reduce: str = 'sum') → torch.Tensor
  • 根据index,将index相同值对应的src元素进行对应定义的计算
  • dim为在第几维进行相应的运算

image.png

image.png

具体在GIN.py中

#将dst相同值对应的edge_attr 边信息,在第0维进行sum运算
node_feat = scatter_sum(edge_attr, dst, dim=0, dim_size=x.shape[0])

e.g.scatter_sum即进行sum运算,scatter_mean即进行mean运算,scatter_max即进行max运算

关于 scatter_mean image.png

image.png


同理scatter_max

image.png

2. torch.gather

torch.gather()是torch的一个函数

∴import torch即可

  • torch.gather()函数作用:提取tensor张量中特定的元素
torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor

常用参数(三个):

  • input
  • dim
  • index
  • input: 你要输入的torch.tensor()

  • dim: 要处理的维度,一个[ ]表示一个维度,比如[ [ 2,3 ] ]中的2和3就是在第二维,dim可以取0,1,2等;

  • index: 必须为torch.LongTensor()的类型,且维度大小必须和input相同,index中每一个值表示input在dim维中的下标,下标从0开始

image.png

对应代码实现:

  • 讨论源Tensor维度
  1. 一维Tensor

image.png

  1. 二维Tensor(dim=0 为行聚合; dim=1 为列聚合

image.png

image.png

  1. 三维Tensor(针对dim=0,1,2 三种维度聚合)

以2×3×4 的Tensor为例🌰

  • dim=2(对×4而言)和dim=1(对×3而言) image.png

  • dim=0(对2而言) image.png