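Starting my Juejin growth journey! This is day 14 of my participation in the Juejin Daily New Plan · December Writing Challenge.
This article was first published on CSDN.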
4. Node representation
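Lazy initialization (layers declared with an input size of -1 only get their parameter shapes after one forward pass, so run a dummy pass first):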
with torch.no_grad():  # Initialize lazy modules.
    out = model(data.x_dict, data.edge_index_dict)
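The same mechanism can be seen with plain PyTorch's torch.nn.LazyLinear, used here only as an analogy (PyG's Linear(-1, ...) and SAGEConv((-1, -1), ...) implement their lazy sizing separately). A minimal sketch, assuming PyTorch 1.8 or later: the weight has no shape until one batch has gone through, which is also why the optimizer is best created after this dummy pass.

```python
import torch
from torch.nn.parameter import UninitializedParameter

# LazyLinear: in_features is inferred from the first input it sees.
lin = torch.nn.LazyLinear(out_features=4)
before_is_lazy = isinstance(lin.weight, UninitializedParameter)

x = torch.randn(10, 7)  # 10 samples with 7 features (illustrative sizes)
with torch.no_grad():  # dummy forward pass only to materialize the parameters
    lin(x)

after_is_lazy = isinstance(lin.weight, UninitializedParameter)
print(before_is_lazy, after_is_lazy, tuple(lin.weight.shape))
```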
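Points to watch: first, check the tensor dtypes (edge_index must be torch.long, and the x of every node type must share one dtype, usually torch.float); second, check that edge_index is not out of bounds (a mistake I often make is ending up with a node index equal to the number of nodes, whereas valid indices only run from 0 to num_nodes - 1).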
For the bug caused by problem one, see: AssertionError when implementing heterogenous GNN · Discussion #5175 · pyg-team/pytorch_geometric
To self-check for problem two:
for edge_type in data.edge_types:
    src, _, dst = edge_type
    assert data[edge_type].edge_index[0].max() < data[src].num_nodes
    assert data[edge_type].edge_index[1].max() < data[dst].num_nodes
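A slightly fuller self-check that covers both points at once, written as a hypothetical helper check_hetero_inputs (not a PyG function) that works on plain dicts: one shared floating-point dtype for node features (requiring a float dtype is my assumption, matching the usual torch.float), torch.long edge indices, and indices within 0 .. num_nodes - 1. With a HeteroData object, one could pass data.x_dict, data.edge_index_dict and {nt: data[nt].num_nodes for nt in data.node_types}.

```python
import torch

def check_hetero_inputs(x_dict, edge_index_dict, num_nodes_dict):
    """Hypothetical helper: raise on the dtype / out-of-bounds mistakes described above.

    x_dict:          {node_type: feature tensor}
    edge_index_dict: {(src, rel, dst): tensor of shape [2, num_edges]}
    num_nodes_dict:  {node_type: int}
    """
    # 1) All node feature matrices should share one floating-point dtype.
    dtypes = {x.dtype for x in x_dict.values()}
    if len(dtypes) != 1 or not next(iter(dtypes)).is_floating_point:
        raise TypeError(f"node features must share one float dtype, got {dtypes}")
    for edge_type, edge_index in edge_index_dict.items():
        src, _, dst = edge_type
        # 2) edge_index must be int64 (torch.long).
        if edge_index.dtype != torch.long:
            raise TypeError(f"{edge_type}: edge_index dtype is {edge_index.dtype}")
        if edge_index.numel() == 0:
            continue
        # 3) Valid indices are 0 .. num_nodes - 1; an index equal to num_nodes is out of bounds.
        if edge_index.min() < 0:
            raise IndexError(f"{edge_type}: negative node index")
        if edge_index[0].max() >= num_nodes_dict[src]:
            raise IndexError(f"{edge_type}: source index out of range for '{src}'")
        if edge_index[1].max() >= num_nodes_dict[dst]:
            raise IndexError(f"{edge_type}: destination index out of range for '{dst}'")

# Toy graph: 3 authors, 2 papers (illustrative).
x_dict = {"author": torch.randn(3, 8), "paper": torch.randn(2, 8)}
num_nodes = {"author": 3, "paper": 2}
good = {("author", "writes", "paper"): torch.tensor([[0, 2], [1, 0]])}
check_hetero_inputs(x_dict, good, num_nodes)  # passes silently

bad = {("author", "writes", "paper"): torch.tensor([[0, 3], [1, 0]])}  # 3 == number of authors
try:
    check_hetero_inputs(x_dict, bad, num_nodes)
except IndexError as err:
    print(err)
```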
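For how to fix it, see my earlier post: RuntimeError: CUDA error: device-side assert triggered
4.1 Converting a homogeneous GNN directly into a heterogeneous GNN
That is, define the GNN model as you normally would (note that some homogeneous GNNs cannot be applied to heterogeneous graphs). Converting it into a heterogeneous GNN means running a separate instance of the homogeneous model on each edge type and then aggregating the results that arrive at the same destination node type (aggr='sum' in the examples below).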
torch_geometric.nn.to_hetero()
torch_geometric.nn.to_hetero_with_bases()
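To make the "one instance per edge type" idea concrete, here is a conceptual sketch in plain PyTorch. It is not what to_hetero() actually generates (that function rewrites the model's computation graph with torch.fx); the layer name ToyHeteroLayer, the mean aggregation and the toy graph are illustrative assumptions. Each edge type gets its own weights, messages are averaged onto the destination nodes, and outputs for the same destination type are summed, mirroring aggr='sum'.

```python
import torch

class ToyHeteroLayer(torch.nn.Module):
    """Hypothetical sketch: one small 'homogeneous' layer per edge type."""

    def __init__(self, edge_types, in_dims, out_dim):
        super().__init__()
        self.edge_types = edge_types
        self.out_dim = out_dim
        # ModuleDict keys must be strings, so join the (src, rel, dst) triple.
        self.lins = torch.nn.ModuleDict({
            "__".join(et): torch.nn.Linear(in_dims[et[0]], out_dim) for et in edge_types
        })

    def forward(self, x_dict, edge_index_dict):
        out = {}
        for et in self.edge_types:
            src, _, dst = et
            row, col = edge_index_dict[et]  # row: source ids, col: destination ids
            msg = self.lins["__".join(et)](x_dict[src])[row]
            num_dst = x_dict[dst].size(0)
            agg = torch.zeros(num_dst, self.out_dim).index_add_(0, col, msg)
            deg = torch.zeros(num_dst).index_add_(0, col, torch.ones(col.numel()))
            agg = agg / deg.clamp(min=1).unsqueeze(-1)  # mean over incoming edges
            # Sum over edge types that share this destination type (like aggr='sum').
            out[dst] = out.get(dst, 0) + agg
        return out

x_dict = {"author": torch.randn(3, 5), "paper": torch.randn(2, 4)}
edge_types = [("author", "writes", "paper"), ("paper", "rev_writes", "author")]
edge_index_dict = {
    ("author", "writes", "paper"): torch.tensor([[0, 1, 2], [0, 0, 1]]),
    ("paper", "rev_writes", "author"): torch.tensor([[0, 0, 1], [0, 1, 2]]),
}
layer = ToyHeteroLayer(edge_types, {"author": 5, "paper": 4}, out_dim=8)
out = layer(x_dict, edge_index_dict)
print({k: tuple(v.shape) for k, v in out.items()})  # {'paper': (2, 8), 'author': (3, 8)}
```

Note that a node type which is never a destination simply gets no output in this sketch; dropping the reverse edge type leaves "author" out of the result.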
Example code:
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import SAGEConv, to_hetero

dataset = OGB_MAG(root='/data/pyg_data', preprocess='metapath2vec', transform=T.ToUndirected())
data = dataset[0]

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

model = GNN(hidden_channels=64, out_channels=dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')
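in_channels is given as a tuple to support message passing on bipartite graphs (honestly, I didn't really understand what that meant). In fact, an int input also works in this example: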
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import SAGEConv, to_hetero

dataset = OGB_MAG(root='/data/pyg_data', preprocess='metapath2vec', transform=T.ToUndirected())
data = dataset[0]

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(-1, hidden_channels)
        self.conv2 = SAGEConv(-1, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

model = GNN(hidden_channels=64, out_channels=dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')
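What the tuple is for: on an edge type such as author -> paper, the source and destination nodes are different sets and may have different feature sizes, so the layer needs one weight for the aggregated neighbours (source side) and one for the node itself (destination side). (-1, -1) lets both be inferred lazily. As far as I can tell from PyG's SAGEConv, an int is just expanded to (in, in), and since -1 means "infer", both sides are still inferred independently, which would explain why -1 works here too. The plain-PyTorch sketch below of a GraphSAGE-style layer on a bipartite graph (mean aggregation; the class name BipartiteSAGE is hypothetical, not PyG's implementation) shows the two separate input sizes:

```python
import torch

class BipartiteSAGE(torch.nn.Module):
    """Hypothetical GraphSAGE-style layer for a bipartite graph (mean aggregation)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        if isinstance(in_channels, int):  # an int means the same size on both sides
            in_channels = (in_channels, in_channels)
        src_dim, dst_dim = in_channels
        # Applied to the mean of the neighbouring source features.
        self.lin_neigh = torch.nn.Linear(src_dim, out_channels)
        # Applied to the destination node's own features.
        self.lin_self = torch.nn.Linear(dst_dim, out_channels, bias=False)

    def forward(self, x, edge_index):
        x_src, x_dst = x  # two different node sets
        row, col = edge_index  # row indexes x_src, col indexes x_dst
        agg = torch.zeros(x_dst.size(0), x_src.size(1)).index_add_(0, col, x_src[row])
        deg = torch.zeros(x_dst.size(0)).index_add_(0, col, torch.ones(col.numel()))
        agg = agg / deg.clamp(min=1).unsqueeze(-1)  # mean of neighbouring source features
        return self.lin_neigh(agg) + self.lin_self(x_dst)

x_author = torch.randn(3, 16)  # 3 authors with 16 features (illustrative)
x_paper = torch.randn(2, 128)  # 2 papers with 128 features (illustrative)
edge_index = torch.tensor([[0, 1, 2], [0, 0, 1]])  # author -> paper
conv = BipartiteSAGE((16, 128), 64)
print(conv((x_author, x_paper), edge_index).shape)  # torch.Size([2, 64])
```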
If the graph is not converted to an undirected graph, the author nodes have no incoming edges, which leads to a NotImplementedError. The warning printed is:
env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py:145: UserWarning: There exist node types ({'author'}) whose representations do not get updated during message passing as they do not occur as destination type in any edge type. This may lead to unexpected behaviour. warnings.warn(
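This situation can be spotted in advance from data.metadata(), which gives the node types and the (src, rel, dst) edge types: any node type that never appears as dst is not updated. A plain-Python sketch; the helpers non_destination_types and add_reverse_edge_types are hypothetical, the latter only imitates, to my understanding, the "rev_" reverse edge types that T.ToUndirected() adds for heterogeneous graphs, and the OGB_MAG edge types are written out by hand.

```python
def non_destination_types(metadata):
    """Node types that never occur as the destination of any edge type."""
    node_types, edge_types = metadata
    dst_types = {dst for _, _, dst in edge_types}
    return {nt for nt in node_types if nt not in dst_types}

def add_reverse_edge_types(metadata):
    """Imitates T.ToUndirected() at the metadata level: add (dst, 'rev_' + rel, src)."""
    node_types, edge_types = metadata
    new_edge_types = list(edge_types)
    for src, rel, dst in edge_types:
        if src != dst:  # same-type relations (paper cites paper) are assumed to be symmetrized in place
            new_edge_types.append((dst, f"rev_{rel}", src))
    return node_types, new_edge_types

# OGB_MAG's original (directed) schema, written out by hand.
mag = (
    ["paper", "author", "institution", "field_of_study"],
    [("author", "affiliated_with", "institution"),
     ("author", "writes", "paper"),
     ("paper", "cites", "paper"),
     ("paper", "has_topic", "field_of_study")],
)
print(non_destination_types(mag))                          # {'author'}
print(non_destination_types(add_reverse_edge_types(mag)))  # set()
```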
A version with learnable skip connections (that is, the output of each convolution layer plus a linear transformation of that layer's input):
import torch
from torch_geometric.nn import GATConv, Linear, to_hetero

class GAT(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # add_self_loops=False: self-loops are not well-defined between two different node types
        self.conv1 = GATConv((-1, -1), hidden_channels, add_self_loops=False)
        self.lin1 = Linear(-1, hidden_channels)  # learnable skip connection, layer 1
        self.conv2 = GATConv((-1, -1), out_channels, add_self_loops=False)
        self.lin2 = Linear(-1, out_channels)  # learnable skip connection, layer 2

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index) + self.lin1(x)
        x = x.relu()
        x = self.conv2(x, edge_index) + self.lin2(x)
        return x

model = GAT(hidden_channels=64, out_channels=dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')
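Written as a formula, my reading of what each layer of the converted model computes for a node type $t$ is the following, where $\mathcal{R}_t$ is the set of edge types whose destination type is $t$, $\mathrm{GAT}_r^{(\ell)}$ is the copy of the GAT layer for edge type $r$, and $\mathbf{W}_t^{(\ell)}, \mathbf{b}_t^{(\ell)}$ are the weight and bias of the copy of lin1/lin2 for node type $t$ (ReLU is applied after the first layer only):

$$
\mathbf{h}_t^{(\ell+1)} = \sum_{r \in \mathcal{R}_t} \mathrm{GAT}_r^{(\ell)}\left(\mathbf{h}^{(\ell)}_{\mathrm{src}(r)}, \mathbf{h}^{(\ell)}_t\right) + \mathbf{W}_t^{(\ell)} \mathbf{h}_t^{(\ell)} + \mathbf{b}_t^{(\ell)}
$$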
Training code for reference:
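import torch
import torch.nn.functional as F

# Create the optimizer after the lazy-initialization forward pass so every parameter has a shape
# (lr and weight_decay are example values, not tuned).
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x_dict, data.edge_index_dict)
    mask = data['paper'].train_mask  # only paper nodes are labeled in OGB_MAG
    loss = F.cross_entropy(out['paper'][mask], data['paper'].y[mask])
    loss.backward()
    optimizer.step()
    return float(loss)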
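For evaluation, a minimal sketch of a masked-accuracy helper (the name masked_accuracy is hypothetical). It assumes, as in OGB_MAG, that data['paper'] also carries val_mask and test_mask next to train_mask, and could be called after model.eval() as masked_accuracy(out['paper'], data['paper'].y, data['paper'].val_mask).

```python
import torch

@torch.no_grad()
def masked_accuracy(logits, y, mask):
    """Fraction of nodes selected by `mask` whose argmax prediction equals the label."""
    pred = logits.argmax(dim=-1)
    correct = (pred[mask] == y[mask]).sum()
    return float(correct) / int(mask.sum())

# Toy check with 4 nodes and 3 classes (illustrative numbers).
logits = torch.tensor([[2.0, 0.1, 0.0],
                       [0.0, 1.0, 0.5],
                       [0.3, 0.2, 0.9],
                       [1.0, 0.0, 0.0]])
y = torch.tensor([0, 1, 0, 2])
mask = torch.tensor([True, True, True, False])
print(masked_accuracy(logits, y, mask))  # 2 of the 3 masked nodes are correct
```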