小白学习GNN的过程中遇到的一些函数
torch_scatter官方文档:github.com/rusty1s/pyt…
1. torch_scatter.scatter
scatter方法——通过src和index两个张量来获得一个新的张量。
torch_scatter.scatter(src: torch.Tensor, index: torch.Tensor, dim: int = - 1, out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None, reduce: str = 'sum') → torch.Tensor
- 根据index,将index相同值对应的src元素进行对应定义的计算
- dim为在第几维进行相应的运算
具体在GIN.py中
#将dst相同值对应的edge_attr 边信息,在第0维进行sum运算
node_feat = scatter_sum(edge_attr, dst, dim=0, dim_size=x.shape[0])
e.g.scatter_sum即进行sum运算,scatter_mean即进行mean运算,scatter_max即进行max运算
关于 scatter_mean
同理scatter_max
2. torch.gather
torch.gather()是torch的一个函数
∴import torch即可
- torch.gather()函数作用:提取tensor张量中特定的元素
torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
常用参数(三个):
- input
- dim
- index
input: 你要输入的torch.tensor()
dim: 要处理的维度,一个[ ]表示一个维度,比如[ [ 2,3 ] ]中的2和3就是在第二维,dim可以取0,1,2等;
index: 必须为torch.LongTensor()的类型,且维度大小必须和input相同,index中每一个值表示input在dim维中的下标,下标从0开始
对应代码实现:
- 讨论源Tensor维度
- 一维Tensor
- 二维Tensor(dim=0 为行聚合; dim=1 为列聚合)
- 三维Tensor(针对dim=0,1,2 三种维度聚合)
以2×3×4 的Tensor为例🌰
-
dim=2(对×4而言)和dim=1(对×3而言)
-
dim=0(对2而言)