PyG (PyTorch Geometric) Heterogeneous Graph Neural Networks (HGNN) (3)


Starting my Juejin growth journey! This is day 14 of my participation in the Juejin Daily New Plan · December Writing Challenge.

This article was first published on CSDN.

4. Node representation models

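Lazy initialization: layers created with -1 as the number of input channels (as in the models below) only create their parameters during the first forward pass. So run one forward pass before training, typically before creating the optimizer: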

with torch.no_grad():  # Initialize lazy modules.
    out = model(data.x_dict, data.edge_index_dict)

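Things to watch out for: first, check the tensor dtypes (edge_index must be torch.long, and all x tensors should share one dtype, usually torch.float); second, check that edge_index does not go out of bounds (a mistake I often make is having a node index equal to the number of nodes).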

For the bug that the first issue causes, see: AssertionError when implementing heterogenous GNN · Discussion #5175 · pyg-team/pytorch_geometric

To check for the second issue yourself, you can use:

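for edge_type in data.edge_types:
    src, _, dst = edge_type
    edge_index = data[edge_type].edge_index
    if edge_index.numel() == 0:  # .max() raises on an empty tensor
        continue
    assert edge_index.min() >= 0, edge_type  # negative indices are invalid too
    assert edge_index[0].max() < data[src].num_nodes, edge_type
    assert edge_index[1].max() < data[dst].num_nodes, edge_type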
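When the out-of-bounds index comes from 1-based or non-contiguous raw IDs (an assumption about where your IDs come from), one common fix is to remap them to 0..N-1. Below is a sketch with torch.unique(return_inverse=True) on made-up IDs. It only covers nodes that appear in at least one edge, so if you have a full node table, build the mapping from that table instead.

```python
import torch

# Made-up raw IDs, e.g. 1-based author IDs and sparse paper IDs from a CSV file
raw_src = torch.tensor([1, 2, 3, 3])
raw_dst = torch.tensor([5, 7, 5, 9])

# unique() returns the sorted distinct IDs; the inverse maps each raw ID to its position
src_ids, src_index = torch.unique(raw_src, return_inverse=True)
dst_ids, dst_index = torch.unique(raw_dst, return_inverse=True)

edge_index = torch.stack([src_index, dst_index], dim=0)  # torch.long, 0-based
num_src, num_dst = src_ids.numel(), dst_ids.numel()

# The same bound checks as in the snippet above now pass
assert edge_index[0].max() < num_src
assert edge_index[1].max() < num_dst
print(edge_index)
```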

For how to fix it, see my earlier blog post: RuntimeError: CUDA error: device-side assert triggered

4.1 Converting a homogeneous GNN directly into a heterogeneous GNN

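That is, you define the GNN model as usual (note that some homogeneous GNN layers cannot be applied to heterogeneous graphs). Converting it into a heterogeneous GNN then amounts to running a separate instance of the homogeneous model's layers on each edge type and aggregating the results that arrive at the same destination node type.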

Relevant functions: torch_geometric.nn.to_hetero() and torch_geometric.nn.to_hetero_with_bases().
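to_hetero_with_bases() is documented as using the basis-decomposition technique from the R-GCN paper ("Modeling Relational Data with Graph Convolutional Networks"). Conceptually, each layer does not learn an independent weight matrix for every edge type r. Instead it learns B shared basis matrices plus per-edge-type scalar coefficients, where B corresponds to the num_bases argument. This is a sketch of the idea, not of PyG's exact internal computation:

```latex
\mathbf{W}_r = \sum_{b=1}^{B} a_{rb}\, \mathbf{V}_b, \qquad r \in \mathcal{R}
```

Here R is the set of edge types, the V_b are shared by all edge types, and the a_rb are learned coefficients. The number of full weight matrices therefore depends on B instead of on the number of edge types, which helps when there are many edge types.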

Example code:

import torch

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import SAGEConv, to_hetero

dataset = OGB_MAG(root='/data/pyg_data', preprocess='metapath2vec', transform=T.ToUndirected())
data = dataset[0]

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

model = GNN(hidden_channels=64, out_channels=dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')


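in_channels is passed as a tuple to support message passing on bipartite graphs (to be honest, I didn't fully understand this at first). In a heterogeneous graph, the source and destination node types of an edge type can have different feature sizes, so the tuple (-1, -1) lets the layer infer the source and destination input sizes separately. In this example, passing an int also works, because SAGEConv internally expands an int in_channels into a (source, destination) tuple: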

import torch

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import SAGEConv, to_hetero

dataset = OGB_MAG(root='/data/pyg_data', preprocess='metapath2vec', transform=T.ToUndirected())
data = dataset[0]

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(-1, hidden_channels)
        self.conv2 = SAGEConv(-1, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

model = GNN(hidden_channels=64, out_channels=dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')

If the graph is not converted to an undirected graph (the transform=T.ToUndirected() argument above), the author nodes have no incoming edges, which leads to a NotImplementedError. The warning emitted is:

env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py:145: UserWarning: There exist node types ({'author'}) whose representations do not get updated during message passing as they do not occur as destination type in any edge type. This may lead to unexpected behaviour.
  warnings.warn(

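A version with learnable skip connections (i.e., each layer's convolution output plus a linear transformation of that layer's input). Self-loops are disabled with add_self_loops=False because they are not well-defined for edge types whose source and destination node types differ: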

import torch
from torch_geometric.nn import GATConv, Linear, to_hetero

class GAT(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GATConv((-1, -1), hidden_channels, add_self_loops=False)
        self.lin1 = Linear(-1, hidden_channels)
        self.conv2 = GATConv((-1, -1), out_channels, add_self_loops=False)
        self.lin2 = Linear(-1, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index) + self.lin1(x)
        x = x.relu()
        x = self.conv2(x, edge_index) + self.lin2(x)
        return x

model = GAT(hidden_channels=64, out_channels=dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')

Training code for reference:

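import torch
import torch.nn.functional as F

with torch.no_grad():  # lazy initialization, so that model.parameters() are materialized
    model(data.x_dict, data.edge_index_dict)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # lr is an example value

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x_dict, data.edge_index_dict)
    mask = data['paper'].train_mask  # only labelled training papers contribute to the loss
    loss = F.cross_entropy(out['paper'][mask], data['paper'].y[mask])
    loss.backward()
    optimizer.step()
    return float(loss)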