PyG (PyTorch Geometric) Heterogeneous Graph Neural Networks (HGNN) (4)



This article was first published on CSDN.

4.2 Defining a GNN with HeteroConv

HeteroConv lets you assign a different GNN operator to each edge type.

torch_geometric.nn.conv.HeteroConv documentation: pytorch-geometric.readthedocs.io/en/latest/m…

Example code:

import torch

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv, Linear

# ToUndirected() adds reverse edge types such as ('paper', 'rev_writes', 'author').
dataset = OGB_MAG(root='/data/pyg_data', preprocess='metapath2vec',
                  transform=T.ToUndirected())
data = dataset[0]

class HeteroGNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, num_layers):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            # One operator per edge type; in_channels=-1 means the input
            # dimension is inferred lazily on the first forward pass.
            conv = HeteroConv({
                ('paper', 'cites', 'paper'): GCNConv(-1, hidden_channels),
                ('author', 'writes', 'paper'): SAGEConv((-1, -1), hidden_channels),
                # Self-loops are not defined on a bipartite edge type.
                ('paper', 'rev_writes', 'author'): GATConv((-1, -1), hidden_channels,
                                                           add_self_loops=False),
            }, aggr='sum')  # Sum the outputs that target the same node type.
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}
        return self.lin(x_dict['author'])

model = HeteroGNN(hidden_channels=64, out_channels=dataset.num_classes,
                  num_layers=2)
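
Because every operator above is built with lazy input dimensions (-1), the parameters only materialize on the first forward pass. A minimal sketch, assuming the `data` and `model` objects defined above, that initializes the lazy modules before training:

with torch.no_grad():  # One dummy forward pass initializes the lazy modules.
    out = model(data.x_dict, data.edge_index_dict)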

4.3 Using Existing or Hand-Written Heterogeneous Graph Operators

Take the HGT model as an example:

import torch

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import HGTConv, Linear

dataset = OGB_MAG(root='/data/pyg_data', preprocess='metapath2vec',
                  transform=T.ToUndirected())
data = dataset[0]

class HGT(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()

        # Project each node type into a common hidden space; Linear(-1, ...)
        # infers the per-type input dimension lazily.
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in data.node_types:
            self.lin_dict[node_type] = Linear(-1, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, data.metadata(),
                           num_heads, group='sum')
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        for node_type, x in x_dict.items():
            x_dict[node_type] = self.lin_dict[node_type](x).relu_()

        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)

        return self.lin(x_dict['author'])

model = HGT(hidden_channels=64, out_channels=dataset.num_classes,
            num_heads=2, num_layers=2)
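
HGTConv takes `data.metadata()`, which is simply the pair `(node_types, edge_types)` of the heterogeneous graph, to build its per-type and per-relation parameters. A small sketch, assuming the OGB_MAG `data` object and the `model` defined above, that inspects the metadata and triggers lazy initialization:

# metadata() returns the node-type list (e.g. 'paper', 'author') and the
# edge-type list (e.g. ('author', 'writes', 'paper')).
node_types, edge_types = data.metadata()
print(node_types)
print(edge_types)

with torch.no_grad():  # Initialize the lazy Linear layers.
    out = model(data.x_dict, data.edge_index_dict)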

5. Node Classification

5.1 whole-batch

5.1.1 Using Existing Heterogeneous Graph Operators

5.1.1.1 HAN

Reference: github.com/pyg-team/py…

from typing import Dict, List, Union

import torch
import torch.nn.functional as F
from torch import nn

import torch_geometric.transforms as T
from torch_geometric.datasets import IMDB
from torch_geometric.nn import HANConv

# Build movie-movie meta-path edges (movie->actor->movie and
# movie->director->movie); the original edge types and any node types left
# unconnected are dropped.
metapaths = [[('movie', 'actor'), ('actor', 'movie')],
             [('movie', 'director'), ('director', 'movie')]]
transform = T.AddMetaPaths(metapaths=metapaths, drop_orig_edge_types=True,
                           drop_unconnected_node_types=True)
dataset = IMDB('/data/pyg_data/IMDB', transform=transform)
data = dataset[0]
print(data)

class HAN(nn.Module):
    def __init__(self, in_channels: Union[int, Dict[str, int]],
                 out_channels: int, hidden_channels=128, heads=8):
        super().__init__()
        self.han_conv = HANConv(in_channels, hidden_channels, heads=heads,
                                dropout=0.6, metadata=data.metadata())
        self.lin = nn.Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        out = self.han_conv(x_dict, edge_index_dict)
        out = self.lin(out['movie'])
        return out

model = HAN(in_channels=-1, out_channels=3)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
data, model = data.to(device), model.to(device)

with torch.no_grad():  # Initialize lazy modules.
    out = model(data.x_dict, data.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)

def train() -> float:
    model.train()
    optimizer.zero_grad()
    out = model(data.x_dict, data.edge_index_dict)
    mask = data['movie'].train_mask
    loss = F.cross_entropy(out[mask], data['movie'].y[mask])
    loss.backward()
    optimizer.step()
    return float(loss)

@torch.no_grad()
def test() -> List[float]:
    model.eval()
    pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)

    accs = []
    for split in ['train_mask', 'val_mask', 'test_mask']:
        mask = data['movie'][split]
        acc = (pred[mask] == data['movie'].y[mask]).sum() / mask.sum()
        accs.append(float(acc))
    return accs

# Early stopping: stop once validation accuracy has not improved for
# `start_patience` consecutive epochs.
best_val_acc = 0
start_patience = patience = 100
for epoch in range(1, 200):

    loss = train()
    train_acc, val_acc, test_acc = test()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
          f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')

    if best_val_acc <= val_acc:
        patience = start_patience
        best_val_acc = val_acc
    else:
        patience -= 1

    if patience <= 0:
        print('Stopping training as validation accuracy did not improve '
              f'for {start_patience} epochs')
        break

Output: the printed HeteroData contains only movie nodes (4,278 nodes with 3,066-dimensional features) and the two metapath-induced movie-movie edge types, followed by the training log:

HeteroData(
  metapath_dict={
    (movie, metapath_0, movie)=[2],
    (movie, metapath_1, movie)=[2]
  },
  movie={
    x=[4278, 3066],
    y=[4278],
    train_mask=[4278],
    val_mask=[4278],
    test_mask=[4278]
  },
  (movie, metapath_0, movie)={ edge_index=[2, 85358] },
  (movie, metapath_1, movie)={ edge_index=[2, 17446] }
)
Epoch: 001, Loss: 1.1020, Train: 0.5125, Val: 0.4100, Test: 0.3890
Epoch: 002, Loss: 1.0783, Train: 0.5575, Val: 0.4075, Test: 0.3813
Epoch: 003, Loss: 1.0498, Train: 0.6350, Val: 0.4325, Test: 0.4112
Epoch: 004, Loss: 1.0205, Train: 0.7075, Val: 0.4850, Test: 0.4448
Epoch: 005, Loss: 0.9788, Train: 0.7375, Val: 0.5050, Test: 0.4669
Epoch: 006, Loss: 0.9410, Train: 0.7600, Val: 0.5225, Test: 0.4796
Epoch: 007, Loss: 0.8921, Train: 0.7750, Val: 0.5375, Test: 0.4937
Epoch: 008, Loss: 0.8517, Train: 0.8000, Val: 0.5475, Test: 0.5003
Epoch: 009, Loss: 0.7975, Train: 0.8175, Val: 0.5475, Test: 0.5135
Epoch: 010, Loss: 0.7488, Train: 0.8475, Val: 0.5525, Test: 0.5216
Epoch: 011, Loss: 0.7133, Train: 0.8625, Val: 0.5575, Test: 0.5308
Epoch: 012, Loss: 0.6626, Train: 0.8875, Val: 0.5700, Test: 0.5443
Epoch: 013, Loss: 0.6171, Train: 0.9050, Val: 0.5900, Test: 0.5552
Epoch: 014, Loss: 0.5769, Train: 0.9225, Val: 0.5925, Test: 0.5710
Epoch: 015, Loss: 0.5236, Train: 0.9375, Val: 0.5900, Test: 0.5785
Epoch: 016, Loss: 0.4929, Train: 0.9425, Val: 0.5925, Test: 0.5851
Epoch: 017, Loss: 0.4456, Train: 0.9375, Val: 0.5925, Test: 0.5868
Epoch: 018, Loss: 0.4266, Train: 0.9375, Val: 0.5825, Test: 0.5909
Epoch: 019, Loss: 0.3856, Train: 0.9425, Val: 0.5900, Test: 0.5926
Epoch: 020, Loss: 0.3525, Train: 0.9425, Val: 0.5900, Test: 0.5909
Epoch: 021, Loss: 0.3250, Train: 0.9450, Val: 0.5975, Test: 0.5897
Epoch: 022, Loss: 0.2900, Train: 0.9500, Val: 0.6050, Test: 0.5831
Epoch: 023, Loss: 0.2754, Train: 0.9525, Val: 0.6075, Test: 0.5825
Epoch: 024, Loss: 0.2603, Train: 0.9500, Val: 0.6075, Test: 0.5802
Epoch: 025, Loss: 0.2436, Train: 0.9500, Val: 0.6050, Test: 0.5739
Epoch: 026, Loss: 0.2251, Train: 0.9525, Val: 0.6000, Test: 0.5722
Epoch: 027, Loss: 0.2156, Train: 0.9500, Val: 0.6000, Test: 0.5733
Epoch: 028, Loss: 0.2077, Train: 0.9525, Val: 0.5950, Test: 0.5702
Epoch: 029, Loss: 0.1806, Train: 0.9550, Val: 0.5900, Test: 0.5699
Epoch: 030, Loss: 0.1942, Train: 0.9675, Val: 0.5975, Test: 0.5707
Epoch: 031, Loss: 0.1899, Train: 0.9750, Val: 0.6050, Test: 0.5693
Epoch: 032, Loss: 0.1879, Train: 0.9800, Val: 0.6050, Test: 0.5687
Epoch: 033, Loss: 0.1759, Train: 0.9825, Val: 0.6000, Test: 0.5684
Epoch: 034, Loss: 0.1706, Train: 0.9825, Val: 0.5950, Test: 0.5670
Epoch: 035, Loss: 0.1678, Train: 0.9800, Val: 0.5925, Test: 0.5656
Epoch: 036, Loss: 0.1655, Train: 0.9750, Val: 0.5950, Test: 0.5647
Epoch: 037, Loss: 0.1561, Train: 0.9750, Val: 0.6025, Test: 0.5656
Epoch: 038, Loss: 0.1588, Train: 0.9775, Val: 0.6025, Test: 0.5644
Epoch: 039, Loss: 0.1502, Train: 0.9750, Val: 0.6025, Test: 0.5644
Epoch: 040, Loss: 0.1535, Train: 0.9775, Val: 0.6000, Test: 0.5638
Epoch: 041, Loss: 0.1502, Train: 0.9800, Val: 0.6000, Test: 0.5633
Epoch: 042, Loss: 0.1638, Train: 0.9800, Val: 0.6000, Test: 0.5621
Epoch: 043, Loss: 0.1530, Train: 0.9800, Val: 0.6000, Test: 0.5624
Epoch: 044, Loss: 0.1566, Train: 0.9800, Val: 0.5975, Test: 0.5624
Epoch: 045, Loss: 0.1578, Train: 0.9800, Val: 0.6150, Test: 0.5610
Epoch: 046, Loss: 0.1441, Train: 0.9800, Val: 0.6150, Test: 0.5615
Epoch: 047, Loss: 0.1430, Train: 0.9825, Val: 0.6175, Test: 0.5604
Epoch: 048, Loss: 0.1389, Train: 0.9875, Val: 0.6150, Test: 0.5578
Epoch: 049, Loss: 0.1396, Train: 0.9875, Val: 0.6200, Test: 0.5566
Epoch: 050, Loss: 0.1547, Train: 0.9875, Val: 0.6150, Test: 0.5610
Epoch: 051, Loss: 0.1471, Train: 0.9875, Val: 0.6125, Test: 0.5644
Epoch: 052, Loss: 0.1398, Train: 0.9900, Val: 0.6150, Test: 0.5647
Epoch: 053, Loss: 0.1393, Train: 0.9875, Val: 0.6125, Test: 0.5644
Epoch: 054, Loss: 0.1542, Train: 0.9850, Val: 0.6075, Test: 0.5638
Epoch: 055, Loss: 0.1435, Train: 0.9875, Val: 0.6150, Test: 0.5627
Epoch: 056, Loss: 0.1338, Train: 0.9850, Val: 0.6225, Test: 0.5633
Epoch: 057, Loss: 0.1311, Train: 0.9875, Val: 0.6125, Test: 0.5618
Epoch: 058, Loss: 0.1353, Train: 0.9900, Val: 0.6150, Test: 0.5592
Epoch: 059, Loss: 0.1308, Train: 0.9900, Val: 0.6050, Test: 0.5581
Epoch: 060, Loss: 0.1369, Train: 0.9900, Val: 0.6100, Test: 0.5584
Epoch: 061, Loss: 0.1303, Train: 0.9900, Val: 0.6075, Test: 0.5581
Epoch: 062, Loss: 0.1279, Train: 0.9900, Val: 0.6025, Test: 0.5604
Epoch: 063, Loss: 0.1355, Train: 0.9875, Val: 0.6025, Test: 0.5621
Epoch: 064, Loss: 0.1184, Train: 0.9925, Val: 0.6075, Test: 0.5664
Epoch: 065, Loss: 0.1291, Train: 0.9925, Val: 0.6025, Test: 0.5690
Epoch: 066, Loss: 0.1242, Train: 0.9900, Val: 0.6000, Test: 0.5676
Epoch: 067, Loss: 0.1238, Train: 0.9900, Val: 0.6025, Test: 0.5670
Epoch: 068, Loss: 0.1121, Train: 0.9900, Val: 0.6025, Test: 0.5656
Epoch: 069, Loss: 0.1126, Train: 0.9900, Val: 0.6050, Test: 0.5635
Epoch: 070, Loss: 0.1208, Train: 0.9900, Val: 0.6050, Test: 0.5612
Epoch: 071, Loss: 0.1059, Train: 0.9900, Val: 0.6075, Test: 0.5589
Epoch: 072, Loss: 0.1098, Train: 0.9900, Val: 0.6025, Test: 0.5581
Epoch: 073, Loss: 0.1198, Train: 0.9950, Val: 0.5950, Test: 0.5598
Epoch: 074, Loss: 0.1214, Train: 0.9925, Val: 0.5925, Test: 0.5621
Epoch: 075, Loss: 0.1016, Train: 0.9925, Val: 0.5950, Test: 0.5601
Epoch: 076, Loss: 0.1145, Train: 0.9950, Val: 0.6000, Test: 0.5621
Epoch: 077, Loss: 0.1148, Train: 0.9950, Val: 0.6000, Test: 0.5615
Epoch: 078, Loss: 0.1135, Train: 0.9925, Val: 0.5975, Test: 0.5612
Epoch: 079, Loss: 0.1104, Train: 0.9925, Val: 0.6000, Test: 0.5624
Epoch: 080, Loss: 0.1108, Train: 0.9900, Val: 0.6050, Test: 0.5572
Epoch: 081, Loss: 0.0916, Train: 0.9900, Val: 0.6050, Test: 0.5561
Epoch: 082, Loss: 0.1275, Train: 0.9900, Val: 0.6025, Test: 0.5581
Epoch: 083, Loss: 0.0970, Train: 1.0000, Val: 0.6025, Test: 0.5607
Epoch: 084, Loss: 0.0923, Train: 1.0000, Val: 0.6025, Test: 0.5592
Epoch: 085, Loss: 0.1089, Train: 1.0000, Val: 0.6025, Test: 0.5598
Epoch: 086, Loss: 0.1032, Train: 1.0000, Val: 0.6025, Test: 0.5598
Epoch: 087, Loss: 0.0983, Train: 1.0000, Val: 0.6000, Test: 0.5615
Epoch: 088, Loss: 0.0982, Train: 1.0000, Val: 0.5950, Test: 0.5615
Epoch: 089, Loss: 0.0849, Train: 1.0000, Val: 0.5925, Test: 0.5607
Epoch: 090, Loss: 0.0982, Train: 0.9975, Val: 0.5900, Test: 0.5610
Epoch: 091, Loss: 0.1133, Train: 1.0000, Val: 0.5950, Test: 0.5650
Epoch: 092, Loss: 0.0890, Train: 1.0000, Val: 0.5950, Test: 0.5664
Epoch: 093, Loss: 0.0935, Train: 1.0000, Val: 0.6000, Test: 0.5658
Epoch: 094, Loss: 0.0935, Train: 1.0000, Val: 0.6050, Test: 0.5673
Epoch: 095, Loss: 0.1027, Train: 1.0000, Val: 0.6075, Test: 0.5681
Epoch: 096, Loss: 0.0914, Train: 0.9975, Val: 0.6000, Test: 0.5679
Epoch: 097, Loss: 0.0908, Train: 0.9975, Val: 0.5900, Test: 0.5690
Epoch: 098, Loss: 0.1003, Train: 1.0000, Val: 0.5900, Test: 0.5667
Epoch: 099, Loss: 0.0835, Train: 1.0000, Val: 0.5875, Test: 0.5670
Epoch: 100, Loss: 0.0968, Train: 1.0000, Val: 0.5900, Test: 0.5670
Epoch: 101, Loss: 0.0868, Train: 1.0000, Val: 0.5900, Test: 0.5679
Epoch: 102, Loss: 0.0906, Train: 1.0000, Val: 0.6000, Test: 0.5681
Epoch: 103, Loss: 0.0967, Train: 1.0000, Val: 0.5975, Test: 0.5681
Epoch: 104, Loss: 0.0983, Train: 1.0000, Val: 0.5925, Test: 0.5699
Epoch: 105, Loss: 0.0775, Train: 1.0000, Val: 0.5975, Test: 0.5681
Epoch: 106, Loss: 0.0840, Train: 1.0000, Val: 0.5950, Test: 0.5664
Epoch: 107, Loss: 0.0962, Train: 1.0000, Val: 0.5950, Test: 0.5633
Epoch: 108, Loss: 0.0900, Train: 1.0000, Val: 0.5950, Test: 0.5621
Epoch: 109, Loss: 0.0831, Train: 1.0000, Val: 0.5975, Test: 0.5644
Epoch: 110, Loss: 0.0844, Train: 1.0000, Val: 0.5950, Test: 0.5653
Epoch: 111, Loss: 0.1017, Train: 0.9975, Val: 0.5925, Test: 0.5667
Epoch: 112, Loss: 0.0833, Train: 0.9975, Val: 0.5950, Test: 0.5661
Epoch: 113, Loss: 0.0840, Train: 0.9975, Val: 0.5875, Test: 0.5670
Epoch: 114, Loss: 0.0809, Train: 0.9975, Val: 0.5900, Test: 0.5664
Epoch: 115, Loss: 0.0854, Train: 0.9975, Val: 0.5950, Test: 0.5673
Epoch: 116, Loss: 0.0896, Train: 0.9975, Val: 0.5975, Test: 0.5687
Epoch: 117, Loss: 0.0999, Train: 1.0000, Val: 0.5975, Test: 0.5664
Epoch: 118, Loss: 0.0890, Train: 1.0000, Val: 0.5950, Test: 0.5667
Epoch: 119, Loss: 0.0780, Train: 1.0000, Val: 0.5900, Test: 0.5658
Epoch: 120, Loss: 0.0751, Train: 1.0000, Val: 0.5875, Test: 0.5670
Epoch: 121, Loss: 0.0693, Train: 1.0000, Val: 0.5950, Test: 0.5661
Epoch: 122, Loss: 0.0822, Train: 1.0000, Val: 0.5975, Test: 0.5664
Epoch: 123, Loss: 0.0782, Train: 1.0000, Val: 0.5925, Test: 0.5635
Epoch: 124, Loss: 0.0791, Train: 1.0000, Val: 0.5950, Test: 0.5627
Epoch: 125, Loss: 0.0958, Train: 1.0000, Val: 0.6000, Test: 0.5644
Epoch: 126, Loss: 0.0764, Train: 1.0000, Val: 0.5950, Test: 0.5650
Epoch: 127, Loss: 0.0878, Train: 1.0000, Val: 0.5900, Test: 0.5650
Epoch: 128, Loss: 0.0679, Train: 1.0000, Val: 0.5900, Test: 0.5641
Epoch: 129, Loss: 0.0791, Train: 1.0000, Val: 0.5900, Test: 0.5647
Epoch: 130, Loss: 0.0809, Train: 1.0000, Val: 0.5900, Test: 0.5647
Epoch: 131, Loss: 0.0740, Train: 1.0000, Val: 0.5850, Test: 0.5661
Epoch: 132, Loss: 0.0694, Train: 1.0000, Val: 0.5825, Test: 0.5647
Epoch: 133, Loss: 0.0859, Train: 1.0000, Val: 0.5875, Test: 0.5633
Epoch: 134, Loss: 0.0833, Train: 0.9975, Val: 0.5875, Test: 0.5638
Epoch: 135, Loss: 0.0797, Train: 1.0000, Val: 0.5900, Test: 0.5656
Epoch: 136, Loss: 0.0867, Train: 1.0000, Val: 0.5950, Test: 0.5696
Epoch: 137, Loss: 0.0811, Train: 1.0000, Val: 0.5975, Test: 0.5696
Epoch: 138, Loss: 0.0710, Train: 1.0000, Val: 0.5925, Test: 0.5713
Epoch: 139, Loss: 0.0603, Train: 1.0000, Val: 0.5950, Test: 0.5722
Epoch: 140, Loss: 0.0776, Train: 1.0000, Val: 0.5925, Test: 0.5719
Epoch: 141, Loss: 0.0705, Train: 1.0000, Val: 0.5975, Test: 0.5679
Epoch: 142, Loss: 0.0775, Train: 1.0000, Val: 0.5950, Test: 0.5679
Epoch: 143, Loss: 0.0700, Train: 1.0000, Val: 0.5975, Test: 0.5696
Epoch: 144, Loss: 0.0829, Train: 1.0000, Val: 0.5975, Test: 0.5727
Epoch: 145, Loss: 0.0697, Train: 1.0000, Val: 0.6000, Test: 0.5727
Epoch: 146, Loss: 0.0697, Train: 1.0000, Val: 0.6025, Test: 0.5750
Epoch: 147, Loss: 0.0706, Train: 1.0000, Val: 0.6075, Test: 0.5727
Epoch: 148, Loss: 0.0723, Train: 1.0000, Val: 0.5975, Test: 0.5690
Epoch: 149, Loss: 0.0771, Train: 1.0000, Val: 0.5950, Test: 0.5696
Epoch: 150, Loss: 0.0650, Train: 1.0000, Val: 0.6025, Test: 0.5699
Epoch: 151, Loss: 0.0802, Train: 1.0000, Val: 0.5950, Test: 0.5676
Epoch: 152, Loss: 0.0687, Train: 1.0000, Val: 0.5925, Test: 0.5710
Epoch: 153, Loss: 0.0705, Train: 1.0000, Val: 0.5925, Test: 0.5704
Epoch: 154, Loss: 0.0831, Train: 1.0000, Val: 0.5925, Test: 0.5696
Epoch: 155, Loss: 0.0714, Train: 1.0000, Val: 0.5900, Test: 0.5690
Epoch: 156, Loss: 0.0662, Train: 1.0000, Val: 0.5850, Test: 0.5635
Stopping training as validation accuracy did not improve for 100 epochs