PyG (PyTorch Geometric) Heterogeneous Graph Neural Networks (HGNN) (5)

This article was first published on CSDN.

5.1.1.2 HGT (Heterogeneous Graph Transformer)

Reference: github.com/pyg-team/py…
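
HGT (Heterogeneous Graph Transformer, Hu et al., WWW 2020) brings Transformer-style attention to heterogeneous graphs: the attention and message parameters are specialized per node type and edge type. The example below trains an HGT model on the DBLP dataset to classify author nodes into four research areas.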

Example code:

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.datasets import DBLP
from torch_geometric.nn import HGTConv, Linear

# DBLP conference nodes carry no features; T.Constant assigns them a constant feature (x = 1.0).
dataset = DBLP('/data/pyg_data/DBLP', transform=T.Constant(node_types='conference'))
data = dataset[0]
print(data)

class HGT(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()

        # Project each node type's features into a shared hidden space.
        # in_channels=-1 lets the Linear layer infer its input size lazily.
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in data.node_types:
            self.lin_dict[node_type] = Linear(-1, hidden_channels)

        # Stack HGTConv layers; data.metadata() supplies the node and edge types.
        # group='sum' matches the PyG release this example was written against;
        # newer HGTConv versions may no longer accept a `group` argument.
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, data.metadata(),
                           num_heads, group='sum')
            self.convs.append(conv)

        # Classification head applied to the target node type.
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        # Per-type input projection with ReLU.
        x_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }

        # Heterogeneous message passing across all edge types.
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)

        # Only author nodes are classified in this task.
        return self.lin(x_dict['author'])

model = HGT(hidden_channels=64, out_channels=4, num_heads=2, num_layers=1)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
data, model = data.to(device), model.to(device)

with torch.no_grad():  # One forward pass initializes the lazily sized (in_channels=-1) modules.
    out = model(data.x_dict, data.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x_dict, data.edge_index_dict)
    # The loss is computed only on the training split of author nodes.
    mask = data['author'].train_mask
    loss = F.cross_entropy(out[mask], data['author'].y[mask])
    loss.backward()
    optimizer.step()
    return float(loss)

@torch.no_grad()
def test():
    model.eval()
    pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)

    # Accuracy on each split of author nodes.
    accs = []
    for split in ['train_mask', 'val_mask', 'test_mask']:
        mask = data['author'][split]
        acc = (pred[mask] == data['author'].y[mask]).sum() / mask.sum()
        accs.append(float(acc))
    return accs

for epoch in range(1, 101):
    loss = train()
    train_acc, val_acc, test_acc = test()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
          f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')
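
HGTConv is parameterized by the graph's metadata, which HeteroData.metadata() returns as a tuple (node_types, edge_types). For the DBLP graph above it looks like this (a sketch; compare with the HeteroData printout below):

node_types, edge_types = data.metadata()
# node_types: ['author', 'paper', 'term', 'conference']
# edge_types: [('author', 'to', 'paper'), ('paper', 'to', 'author'),
#              ('paper', 'to', 'term'), ('paper', 'to', 'conference'),
#              ('term', 'to', 'paper'), ('conference', 'to', 'paper')]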

Output:

HeteroData(
  author={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057]
  },
  paper={ x=[14328, 4231] },
  term={ x=[7723, 50] },
  conference={
    num_nodes=20,
    x=[20, 1]
  },
  (author, to, paper)={ edge_index=[2, 19645] },
  (paper, to, author)={ edge_index=[2, 19645] },
  (paper, to, term)={ edge_index=[2, 85810] },
  (paper, to, conference)={ edge_index=[2, 14328] },
  (term, to, paper)={ edge_index=[2, 85810] },
  (conference, to, paper)={ edge_index=[2, 14328] }
)
Epoch: 001, Loss: 1.3967, Train: 0.2550, Val: 0.2700, Test: 0.2539
Epoch: 002, Loss: 1.3708, Train: 0.6750, Val: 0.4825, Test: 0.5272
Epoch: 003, Loss: 1.3459, Train: 0.6200, Val: 0.4525, Test: 0.5302
Epoch: 004, Loss: 1.3173, Train: 0.5675, Val: 0.4300, Test: 0.4992
Epoch: 005, Loss: 1.2809, Train: 0.5550, Val: 0.4150, Test: 0.4836
Epoch: 006, Loss: 1.2323, Train: 0.5825, Val: 0.4175, Test: 0.4814
Epoch: 007, Loss: 1.1662, Train: 0.6575, Val: 0.4325, Test: 0.5026
Epoch: 008, Loss: 1.0761, Train: 0.7475, Val: 0.4775, Test: 0.5511
Epoch: 009, Loss: 0.9564, Train: 0.8425, Val: 0.5575, Test: 0.6098
Epoch: 010, Loss: 0.8064, Train: 0.9400, Val: 0.6275, Test: 0.6831
Epoch: 011, Loss: 0.6348, Train: 0.9750, Val: 0.7125, Test: 0.7446
Epoch: 012, Loss: 0.4627, Train: 0.9950, Val: 0.7225, Test: 0.7768
Epoch: 013, Loss: 0.3151, Train: 0.9975, Val: 0.7275, Test: 0.7842
Epoch: 014, Loss: 0.2002, Train: 0.9975, Val: 0.7200, Test: 0.7835
Epoch: 015, Loss: 0.1142, Train: 0.9950, Val: 0.7200, Test: 0.7706
Epoch: 016, Loss: 0.0614, Train: 0.9950, Val: 0.7250, Test: 0.7596
Epoch: 017, Loss: 0.0336, Train: 1.0000, Val: 0.7125, Test: 0.7633
Epoch: 018, Loss: 0.0163, Train: 1.0000, Val: 0.6950, Test: 0.7676
Epoch: 019, Loss: 0.0068, Train: 1.0000, Val: 0.7150, Test: 0.7688
Epoch: 020, Loss: 0.0033, Train: 1.0000, Val: 0.7175, Test: 0.7630
Epoch: 021, Loss: 0.0022, Train: 1.0000, Val: 0.7200, Test: 0.7630
Epoch: 022, Loss: 0.0016, Train: 1.0000, Val: 0.7175, Test: 0.7587
Epoch: 023, Loss: 0.0010, Train: 1.0000, Val: 0.7175, Test: 0.7624
Epoch: 024, Loss: 0.0006, Train: 1.0000, Val: 0.7300, Test: 0.7621
Epoch: 025, Loss: 0.0003, Train: 1.0000, Val: 0.7300, Test: 0.7599
Epoch: 026, Loss: 0.0002, Train: 1.0000, Val: 0.7325, Test: 0.7571
Epoch: 027, Loss: 0.0002, Train: 1.0000, Val: 0.7425, Test: 0.7608
Epoch: 028, Loss: 0.0002, Train: 1.0000, Val: 0.7425, Test: 0.7633
Epoch: 029, Loss: 0.0002, Train: 1.0000, Val: 0.7475, Test: 0.7624
Epoch: 030, Loss: 0.0003, Train: 1.0000, Val: 0.7525, Test: 0.7636
Epoch: 031, Loss: 0.0005, Train: 1.0000, Val: 0.7475, Test: 0.7667
Epoch: 032, Loss: 0.0006, Train: 1.0000, Val: 0.7475, Test: 0.7663
Epoch: 033, Loss: 0.0007, Train: 1.0000, Val: 0.7550, Test: 0.7673
Epoch: 034, Loss: 0.0008, Train: 1.0000, Val: 0.7600, Test: 0.7682
Epoch: 035, Loss: 0.0008, Train: 1.0000, Val: 0.7625, Test: 0.7703
Epoch: 036, Loss: 0.0008, Train: 1.0000, Val: 0.7650, Test: 0.7722
Epoch: 037, Loss: 0.0008, Train: 1.0000, Val: 0.7650, Test: 0.7759
Epoch: 038, Loss: 0.0008, Train: 1.0000, Val: 0.7625, Test: 0.7722
Epoch: 039, Loss: 0.0009, Train: 1.0000, Val: 0.7550, Test: 0.7756
Epoch: 040, Loss: 0.0010, Train: 1.0000, Val: 0.7550, Test: 0.7734
Epoch: 041, Loss: 0.0010, Train: 1.0000, Val: 0.7525, Test: 0.7749
Epoch: 042, Loss: 0.0011, Train: 1.0000, Val: 0.7475, Test: 0.7743
Epoch: 043, Loss: 0.0011, Train: 1.0000, Val: 0.7500, Test: 0.7753
Epoch: 044, Loss: 0.0011, Train: 1.0000, Val: 0.7525, Test: 0.7746
Epoch: 045, Loss: 0.0012, Train: 1.0000, Val: 0.7500, Test: 0.7749
Epoch: 046, Loss: 0.0012, Train: 1.0000, Val: 0.7550, Test: 0.7762
Epoch: 047, Loss: 0.0013, Train: 1.0000, Val: 0.7575, Test: 0.7792
Epoch: 048, Loss: 0.0015, Train: 1.0000, Val: 0.7550, Test: 0.7808
Epoch: 049, Loss: 0.0016, Train: 1.0000, Val: 0.7525, Test: 0.7783
Epoch: 050, Loss: 0.0016, Train: 1.0000, Val: 0.7575, Test: 0.7808
Epoch: 051, Loss: 0.0016, Train: 1.0000, Val: 0.7600, Test: 0.7811
Epoch: 052, Loss: 0.0016, Train: 1.0000, Val: 0.7625, Test: 0.7842
Epoch: 053, Loss: 0.0017, Train: 1.0000, Val: 0.7600, Test: 0.7823
Epoch: 054, Loss: 0.0018, Train: 1.0000, Val: 0.7600, Test: 0.7835
Epoch: 055, Loss: 0.0019, Train: 1.0000, Val: 0.7600, Test: 0.7808
Epoch: 056, Loss: 0.0019, Train: 1.0000, Val: 0.7575, Test: 0.7820
Epoch: 057, Loss: 0.0019, Train: 1.0000, Val: 0.7600, Test: 0.7832
Epoch: 058, Loss: 0.0020, Train: 1.0000, Val: 0.7625, Test: 0.7848
Epoch: 059, Loss: 0.0021, Train: 1.0000, Val: 0.7625, Test: 0.7845
Epoch: 060, Loss: 0.0021, Train: 1.0000, Val: 0.7625, Test: 0.7839
Epoch: 061, Loss: 0.0022, Train: 1.0000, Val: 0.7650, Test: 0.7826
Epoch: 062, Loss: 0.0023, Train: 1.0000, Val: 0.7700, Test: 0.7826
Epoch: 063, Loss: 0.0023, Train: 1.0000, Val: 0.7700, Test: 0.7848
Epoch: 064, Loss: 0.0024, Train: 1.0000, Val: 0.7700, Test: 0.7820
Epoch: 065, Loss: 0.0025, Train: 1.0000, Val: 0.7700, Test: 0.7839
Epoch: 066, Loss: 0.0025, Train: 1.0000, Val: 0.7675, Test: 0.7826
Epoch: 067, Loss: 0.0026, Train: 1.0000, Val: 0.7675, Test: 0.7832
Epoch: 068, Loss: 0.0026, Train: 1.0000, Val: 0.7650, Test: 0.7854
Epoch: 069, Loss: 0.0027, Train: 1.0000, Val: 0.7650, Test: 0.7863
Epoch: 070, Loss: 0.0027, Train: 1.0000, Val: 0.7625, Test: 0.7866
Epoch: 071, Loss: 0.0028, Train: 1.0000, Val: 0.7625, Test: 0.7860
Epoch: 072, Loss: 0.0028, Train: 1.0000, Val: 0.7625, Test: 0.7872
Epoch: 073, Loss: 0.0028, Train: 1.0000, Val: 0.7625, Test: 0.7872
Epoch: 074, Loss: 0.0028, Train: 1.0000, Val: 0.7625, Test: 0.7860
Epoch: 075, Loss: 0.0029, Train: 1.0000, Val: 0.7625, Test: 0.7854
Epoch: 076, Loss: 0.0029, Train: 1.0000, Val: 0.7650, Test: 0.7863
Epoch: 077, Loss: 0.0029, Train: 1.0000, Val: 0.7600, Test: 0.7866
Epoch: 078, Loss: 0.0029, Train: 1.0000, Val: 0.7625, Test: 0.7875
Epoch: 079, Loss: 0.0030, Train: 1.0000, Val: 0.7625, Test: 0.7872
Epoch: 080, Loss: 0.0030, Train: 1.0000, Val: 0.7625, Test: 0.7885
Epoch: 081, Loss: 0.0030, Train: 1.0000, Val: 0.7625, Test: 0.7897
Epoch: 082, Loss: 0.0030, Train: 1.0000, Val: 0.7600, Test: 0.7894
Epoch: 083, Loss: 0.0030, Train: 1.0000, Val: 0.7600, Test: 0.7897
Epoch: 084, Loss: 0.0030, Train: 1.0000, Val: 0.7625, Test: 0.7900
Epoch: 085, Loss: 0.0030, Train: 1.0000, Val: 0.7625, Test: 0.7897
Epoch: 086, Loss: 0.0031, Train: 1.0000, Val: 0.7650, Test: 0.7903
Epoch: 087, Loss: 0.0031, Train: 1.0000, Val: 0.7650, Test: 0.7906
Epoch: 088, Loss: 0.0031, Train: 1.0000, Val: 0.7650, Test: 0.7909
Epoch: 089, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7915
Epoch: 090, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7915
Epoch: 091, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7915
Epoch: 092, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7912
Epoch: 093, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7909
Epoch: 094, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7909
Epoch: 095, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7909
Epoch: 096, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7909
Epoch: 097, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7906
Epoch: 098, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7912
Epoch: 099, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7912
Epoch: 100, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7915
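
As the log shows, training accuracy saturates at 1.0000 by epoch 17 while validation accuracy plateaus around 0.76–0.77, so the reported test figure depends on which epoch you read off. A common refinement (a sketch, not part of the PyG example) is to track the test accuracy at the epoch with the best validation accuracy, e.g. by replacing the training loop with:

best_val_acc = final_test_acc = 0
for epoch in range(1, 101):
    loss = train()
    train_acc, val_acc, test_acc = test()
    if val_acc > best_val_acc:  # model selection on the validation split
        best_val_acc = val_acc
        final_test_acc = test_acc
print(f'Best Val: {best_val_acc:.4f}, Test @ best Val: {final_test_acc:.4f}')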