Starting my Juejin growth journey! This is day 14 of my participation in the Juejin Daily New Plan · December Writing Challenge.
This article was first published on CSDN.
诸神缄默不语 - personal CSDN blog index | Study notes on the PyTorch Geometric (PyG) documentation and official code examples (continuously updated...)
This article shows how to carry out the basic operations for heterogeneous graph neural networks (HGNNs) with PyG.
It mainly follows the heterogeneous graph section of the PyG documentation ("Heterogeneous Graph Learning" in the pytorch_geometric documentation) and the related official code examples: github.com/pyg-team/py…
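Notes:
① Many of these operations can be implemented directly, without PyG's HeteroData object.
② If some datasets cannot be downloaded directly from mainland China, see my earlier post: Solutions for Dropbox-hosted PyG (PyTorch Geometric) datasets that fail to download (AMiner, DBLP, IMDB, LastFM) (continuously updated...)
③ I use torch-geometric 2.2.0 installed via pip. Some earlier versions may not support the drop_orig_edge_types parameter of T.AddMetaPaths.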
@[toc]
1. The example dataset
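Schema of the ogbn-mag heterogeneous graph: four node types (paper, author, institution, field_of_study) and four edge types: (author, writes, paper), (paper, cites, paper), (author, affiliated_with, institution) and (paper, has_topic, field_of_study).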
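The graph has 1,939,743 nodes and 21,111,007 edges in total. The dataset's original task is node classification: predicting the venue (conference or journal) of each paper.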
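Loading it in PyG (the raw data only provides features for paper nodes; here the preprocess argument adds features for the other node types, derived from the graph structure):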
from torch_geometric.datasets import OGB_MAG
dataset = OGB_MAG(root='./data', preprocess='metapath2vec')
# preprocess can also be 'TransE'
data = dataset[0]
2. The HeteroData object
from torch_geometric.data import HeteroData
data = HeteroData()
data['paper'].x = ... # [num_papers, num_features_paper]
data['author'].x = ... # [num_authors, num_features_author]
data['institution'].x = ... # [num_institutions, num_features_institution]
data['field_of_study'].x = ... # [num_field, num_features_field]
data['paper', 'cites', 'paper'].edge_index = ... # [2, num_edges_cites]
data['author', 'writes', 'paper'].edge_index = ... # [2, num_edges_writes]
data['author', 'affiliated_with', 'institution'].edge_index = ... # [2, num_edges_affiliated]
data['paper', 'has_topic', 'field_of_study'].edge_index = ... # [2, num_edges_topic]
data['paper', 'cites', 'paper'].edge_attr = ... # [num_edges_cites, num_features_cites]
data['author', 'writes', 'paper'].edge_attr = ... # [num_edges_writes, num_features_writes]
data['author', 'affiliated_with', 'institution'].edge_attr = ... # [num_edges_affiliated, num_features_affiliated]
data['paper', 'has_topic', 'field_of_study'].edge_attr = ... # [num_edges_topic, num_features_topic]
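Node types are indexed with a string, and edge types with a triplet of strings (source node type, relation type, destination node type).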
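data.{attribute_name}_dict collects that attribute across all node or edge types into a dictionary keyed by type (e.g. data.x_dict maps each node type to its feature matrix). These dictionaries can be passed directly to a GNN model: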
model = HeteroGNN(...)  # placeholder for your own heterogeneous GNN model
output = model(data.x_dict, data.edge_index_dict, data.edge_attr_dict)
Taking the ogbn-mag data introduced in Section 1 as an example, printing the data object gives:
HeteroData(
paper={
x=[736389, 128],
year=[736389],
y=[736389],
train_mask=[736389],
val_mask=[736389],
test_mask=[736389]
},
author={ x=[1134649, 128] },
institution={ x=[8740, 128] },
field_of_study={ x=[59965, 128] },
(author, affiliated_with, institution)={ edge_index=[2, 1043998] },
(author, writes, paper)={ edge_index=[2, 7145660] },
(paper, cites, paper)={ edge_index=[2, 5416271] },
(paper, has_topic, field_of_study)={ edge_index=[2, 7505078] }
)
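As a quick consistency check, the per-type shapes in this printout account for the totals given in Section 1. The sketch below simply sums them in plain Python. On the loaded HeteroData, data.num_nodes and data.num_edges should give the same totals, since both sum over all node or edge types:

```python
# Per-type sizes copied from the ogbn-mag HeteroData printout above
num_nodes = {
    'paper': 736389,
    'author': 1134649,
    'institution': 8740,
    'field_of_study': 59965,
}
num_edges = {
    ('author', 'affiliated_with', 'institution'): 1043998,
    ('author', 'writes', 'paper'): 7145660,
    ('paper', 'cites', 'paper'): 5416271,
    ('paper', 'has_topic', 'field_of_study'): 7505078,
}

total_nodes = sum(num_nodes.values())
total_edges = sum(num_edges.values())
print(f'{total_nodes:,} nodes, {total_edges:,} edges')  # 1,939,743 nodes, 21,111,007 edges
```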