PyG (PyTorch Geometric) Heterogeneous Graph Neural Networks: HGNN (6)

This article was first published on CSDN.

5.1.2 Using HeteroConv

5.1.2.1 GraphSAGE

Reference: github.com/pyg-team/py…

(For node types that have no features in the original dataset, a constant [1.] is used as the initial feature.)
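
As a rough sketch of what the Constant transform does here (using a hypothetical toy HeteroData object, not the real DBLP data), it amounts to attaching a single constant 1.0 feature to every node of the chosen type:

import torch
from torch_geometric.data import HeteroData

# Toy graph: 'conference' nodes start out without features.
toy = HeteroData()
toy['conference'].num_nodes = 20

# Roughly what T.Constant(node_types='conference') achieves:
# give each 'conference' node a single constant feature of value 1.0.
toy['conference'].x = torch.ones(toy['conference'].num_nodes, 1)
print(toy['conference'].x.shape)  # torch.Size([20, 1])

The full example on the DBLP dataset: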

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.datasets import DBLP
from torch_geometric.nn import HeteroConv, Linear, SAGEConv

# We initialize conference node features with a single one-vector as feature:
dataset = DBLP('/data/pyg_data/DBLP', transform=T.Constant(node_types='conference'))
data = dataset[0]
print(data)

class HeteroGNN(torch.nn.Module):
    def __init__(self, metadata, hidden_channels, out_channels, num_layers):
        super().__init__()

        # One SAGEConv per relation (edge type). in_channels=(-1, -1) lets the
        # input feature sizes be inferred lazily on the first forward pass,
        # since every node type has a different feature dimensionality.
        # HeteroConv then aggregates (sums, by default) the outputs each node
        # type receives from its different relations.
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv({
                edge_type: SAGEConv((-1, -1), hidden_channels)
                for edge_type in metadata[1]
            })
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()}
        # Only 'author' nodes are classified in this task.
        return self.lin(x_dict['author'])

model = HeteroGNN(data.metadata(), hidden_channels=64, out_channels=4,
                  num_layers=2)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data, model = data.to(device), model.to(device)

with torch.no_grad():  # Initialize lazy modules.
    out = model(data.x_dict, data.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x_dict, data.edge_index_dict)
    mask = data['author'].train_mask
    loss = F.cross_entropy(out[mask], data['author'].y[mask])
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test():
    model.eval()
    pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)

    accs = []
    for split in ['train_mask', 'val_mask', 'test_mask']:
        mask = data['author'][split]
        acc = (pred[mask] == data['author'].y[mask]).sum() / mask.sum()
        accs.append(float(acc))
    return accs

for epoch in range(1, 101):
    loss = train()
    train_acc, val_acc, test_acc = test()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
          f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')

Output:

HeteroData(
  author={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057]
  },
  paper={ x=[14328, 4231] },
  term={ x=[7723, 50] },
  conference={
    num_nodes=20,
    x=[20, 1]
  },
  (author, to, paper)={ edge_index=[2, 19645] },
  (paper, to, author)={ edge_index=[2, 19645] },
  (paper, to, term)={ edge_index=[2, 85810] },
  (paper, to, conference)={ edge_index=[2, 14328] },
  (term, to, paper)={ edge_index=[2, 85810] },
  (conference, to, paper)={ edge_index=[2, 14328] }
)
Epoch: 001, Loss: 1.3721, Train: 0.4550, Val: 0.3450, Test: 0.3819
Epoch: 002, Loss: 1.2867, Train: 0.6050, Val: 0.4800, Test: 0.5333
Epoch: 003, Loss: 1.1778, Train: 0.7175, Val: 0.5325, Test: 0.5941
Epoch: 004, Loss: 1.0368, Train: 0.8350, Val: 0.6050, Test: 0.6788
Epoch: 005, Loss: 0.8729, Train: 0.8950, Val: 0.6725, Test: 0.7228
Epoch: 006, Loss: 0.6991, Train: 0.9350, Val: 0.7025, Test: 0.7479
Epoch: 007, Loss: 0.5314, Train: 0.9600, Val: 0.7350, Test: 0.7765
Epoch: 008, Loss: 0.3831, Train: 0.9825, Val: 0.7525, Test: 0.8010
Epoch: 009, Loss: 0.2585, Train: 0.9900, Val: 0.7800, Test: 0.8189
Epoch: 010, Loss: 0.1632, Train: 0.9975, Val: 0.8025, Test: 0.8284
Epoch: 011, Loss: 0.0988, Train: 0.9975, Val: 0.8100, Test: 0.8293
Epoch: 012, Loss: 0.0578, Train: 1.0000, Val: 0.8025, Test: 0.8290
Epoch: 013, Loss: 0.0324, Train: 1.0000, Val: 0.8150, Test: 0.8296
Epoch: 014, Loss: 0.0178, Train: 1.0000, Val: 0.8075, Test: 0.8281
Epoch: 015, Loss: 0.0100, Train: 1.0000, Val: 0.8050, Test: 0.8281
Epoch: 016, Loss: 0.0060, Train: 1.0000, Val: 0.8050, Test: 0.8262
Epoch: 017, Loss: 0.0039, Train: 1.0000, Val: 0.8025, Test: 0.8235
Epoch: 018, Loss: 0.0027, Train: 1.0000, Val: 0.8100, Test: 0.8232
Epoch: 019, Loss: 0.0020, Train: 1.0000, Val: 0.8125, Test: 0.8235
Epoch: 020, Loss: 0.0017, Train: 1.0000, Val: 0.8125, Test: 0.8238
Epoch: 021, Loss: 0.0015, Train: 1.0000, Val: 0.8175, Test: 0.8268
Epoch: 022, Loss: 0.0014, Train: 1.0000, Val: 0.8175, Test: 0.8271
Epoch: 023, Loss: 0.0013, Train: 1.0000, Val: 0.8125, Test: 0.8274
Epoch: 024, Loss: 0.0014, Train: 1.0000, Val: 0.8150, Test: 0.8284
Epoch: 025, Loss: 0.0014, Train: 1.0000, Val: 0.8175, Test: 0.8281
Epoch: 026, Loss: 0.0015, Train: 1.0000, Val: 0.8175, Test: 0.8268
Epoch: 027, Loss: 0.0017, Train: 1.0000, Val: 0.8225, Test: 0.8225
Epoch: 028, Loss: 0.0019, Train: 1.0000, Val: 0.8250, Test: 0.8216
Epoch: 029, Loss: 0.0021, Train: 1.0000, Val: 0.8250, Test: 0.8198
Epoch: 030, Loss: 0.0024, Train: 1.0000, Val: 0.8225, Test: 0.8195
Epoch: 031, Loss: 0.0026, Train: 1.0000, Val: 0.8200, Test: 0.8192
Epoch: 032, Loss: 0.0029, Train: 1.0000, Val: 0.8225, Test: 0.8189
Epoch: 033, Loss: 0.0032, Train: 1.0000, Val: 0.8175, Test: 0.8185
Epoch: 034, Loss: 0.0035, Train: 1.0000, Val: 0.8200, Test: 0.8185
Epoch: 035, Loss: 0.0037, Train: 1.0000, Val: 0.8200, Test: 0.8176
Epoch: 036, Loss: 0.0038, Train: 1.0000, Val: 0.8225, Test: 0.8185
Epoch: 037, Loss: 0.0039, Train: 1.0000, Val: 0.8200, Test: 0.8176
Epoch: 038, Loss: 0.0041, Train: 1.0000, Val: 0.8175, Test: 0.8192
Epoch: 039, Loss: 0.0043, Train: 1.0000, Val: 0.8175, Test: 0.8204
Epoch: 040, Loss: 0.0044, Train: 1.0000, Val: 0.8150, Test: 0.8189
Epoch: 041, Loss: 0.0045, Train: 1.0000, Val: 0.8150, Test: 0.8173
Epoch: 042, Loss: 0.0046, Train: 1.0000, Val: 0.8175, Test: 0.8179
Epoch: 043, Loss: 0.0047, Train: 1.0000, Val: 0.8150, Test: 0.8170
Epoch: 044, Loss: 0.0047, Train: 1.0000, Val: 0.8175, Test: 0.8185
Epoch: 045, Loss: 0.0047, Train: 1.0000, Val: 0.8125, Test: 0.8195
Epoch: 046, Loss: 0.0047, Train: 1.0000, Val: 0.8150, Test: 0.8192
Epoch: 047, Loss: 0.0047, Train: 1.0000, Val: 0.8125, Test: 0.8182
Epoch: 048, Loss: 0.0047, Train: 1.0000, Val: 0.8075, Test: 0.8167
Epoch: 049, Loss: 0.0047, Train: 1.0000, Val: 0.8050, Test: 0.8158
Epoch: 050, Loss: 0.0047, Train: 1.0000, Val: 0.8000, Test: 0.8167
Epoch: 051, Loss: 0.0047, Train: 1.0000, Val: 0.8050, Test: 0.8170
Epoch: 052, Loss: 0.0047, Train: 1.0000, Val: 0.8100, Test: 0.8152
Epoch: 053, Loss: 0.0046, Train: 1.0000, Val: 0.8075, Test: 0.8149
Epoch: 054, Loss: 0.0046, Train: 1.0000, Val: 0.8075, Test: 0.8133
Epoch: 055, Loss: 0.0046, Train: 1.0000, Val: 0.8100, Test: 0.8139
Epoch: 056, Loss: 0.0046, Train: 1.0000, Val: 0.8100, Test: 0.8152
Epoch: 057, Loss: 0.0045, Train: 1.0000, Val: 0.8050, Test: 0.8149
Epoch: 058, Loss: 0.0045, Train: 1.0000, Val: 0.8025, Test: 0.8146
Epoch: 059, Loss: 0.0045, Train: 1.0000, Val: 0.8100, Test: 0.8139
Epoch: 060, Loss: 0.0044, Train: 1.0000, Val: 0.8125, Test: 0.8149
Epoch: 061, Loss: 0.0044, Train: 1.0000, Val: 0.8100, Test: 0.8149
Epoch: 062, Loss: 0.0043, Train: 1.0000, Val: 0.8025, Test: 0.8130
Epoch: 063, Loss: 0.0043, Train: 1.0000, Val: 0.8050, Test: 0.8124
Epoch: 064, Loss: 0.0042, Train: 1.0000, Val: 0.8050, Test: 0.8124
Epoch: 065, Loss: 0.0042, Train: 1.0000, Val: 0.8100, Test: 0.8127
Epoch: 066, Loss: 0.0041, Train: 1.0000, Val: 0.8100, Test: 0.8130
Epoch: 067, Loss: 0.0041, Train: 1.0000, Val: 0.8050, Test: 0.8121
Epoch: 068, Loss: 0.0040, Train: 1.0000, Val: 0.8050, Test: 0.8124
Epoch: 069, Loss: 0.0040, Train: 1.0000, Val: 0.8100, Test: 0.8118
Epoch: 070, Loss: 0.0039, Train: 1.0000, Val: 0.8075, Test: 0.8115
Epoch: 071, Loss: 0.0038, Train: 1.0000, Val: 0.8050, Test: 0.8109
Epoch: 072, Loss: 0.0038, Train: 1.0000, Val: 0.8075, Test: 0.8109
Epoch: 073, Loss: 0.0037, Train: 1.0000, Val: 0.8100, Test: 0.8099
Epoch: 074, Loss: 0.0037, Train: 1.0000, Val: 0.8100, Test: 0.8109
Epoch: 075, Loss: 0.0036, Train: 1.0000, Val: 0.8050, Test: 0.8106
Epoch: 076, Loss: 0.0036, Train: 1.0000, Val: 0.8075, Test: 0.8115
Epoch: 077, Loss: 0.0036, Train: 1.0000, Val: 0.8100, Test: 0.8112
Epoch: 078, Loss: 0.0035, Train: 1.0000, Val: 0.8075, Test: 0.8112
Epoch: 079, Loss: 0.0035, Train: 1.0000, Val: 0.8050, Test: 0.8112
Epoch: 080, Loss: 0.0035, Train: 1.0000, Val: 0.8050, Test: 0.8112
Epoch: 081, Loss: 0.0034, Train: 1.0000, Val: 0.8050, Test: 0.8115
Epoch: 082, Loss: 0.0034, Train: 1.0000, Val: 0.8050, Test: 0.8109
Epoch: 083, Loss: 0.0034, Train: 1.0000, Val: 0.8050, Test: 0.8109
Epoch: 084, Loss: 0.0034, Train: 1.0000, Val: 0.8050, Test: 0.8112
Epoch: 085, Loss: 0.0033, Train: 1.0000, Val: 0.8050, Test: 0.8106
Epoch: 086, Loss: 0.0033, Train: 1.0000, Val: 0.8025, Test: 0.8103
Epoch: 087, Loss: 0.0033, Train: 1.0000, Val: 0.8025, Test: 0.8103
Epoch: 088, Loss: 0.0032, Train: 1.0000, Val: 0.8025, Test: 0.8099
Epoch: 089, Loss: 0.0032, Train: 1.0000, Val: 0.8025, Test: 0.8106
Epoch: 090, Loss: 0.0032, Train: 1.0000, Val: 0.8025, Test: 0.8103
Epoch: 091, Loss: 0.0032, Train: 1.0000, Val: 0.8025, Test: 0.8106
Epoch: 092, Loss: 0.0031, Train: 1.0000, Val: 0.8025, Test: 0.8106
Epoch: 093, Loss: 0.0031, Train: 1.0000, Val: 0.8025, Test: 0.8109
Epoch: 094, Loss: 0.0031, Train: 1.0000, Val: 0.8025, Test: 0.8115
Epoch: 095, Loss: 0.0031, Train: 1.0000, Val: 0.8000, Test: 0.8121
Epoch: 096, Loss: 0.0030, Train: 1.0000, Val: 0.8025, Test: 0.8124
Epoch: 097, Loss: 0.0030, Train: 1.0000, Val: 0.8000, Test: 0.8130
Epoch: 098, Loss: 0.0030, Train: 1.0000, Val: 0.8025, Test: 0.8127
Epoch: 099, Loss: 0.0030, Train: 1.0000, Val: 0.7975, Test: 0.8124
Epoch: 100, Loss: 0.0030, Train: 1.0000, Val: 0.8000, Test: 0.8127