开启掘金成长之旅!这是我参与「掘金日新计划 · 12 月更文挑战」的第14天
本文首发于CSDN。
5.1.1.2 HGT
示例代码（使用 PyG 的 HGTConv 在 DBLP 数据集上对 author 节点做分类）：
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.datasets import DBLP
from torch_geometric.nn import HGTConv, Linear
# Load the DBLP heterogeneous graph. The Constant transform is applied only to
# the 'conference' node type, which shows up in the printed summary with a
# single feature column (x=[20, 1]).
dataset = DBLP(
    root='/data/pyg_data/DBLP',
    transform=T.Constant(node_types='conference'),
)
# The dataset holds a single graph; take it and print its summary.
data = dataset[0]
print(data)
class HGT(torch.nn.Module):
    """Heterogeneous Graph Transformer producing outputs for 'author' nodes.

    Features of every node type are first projected to ``hidden_channels``
    by a per-type ``Linear`` layer followed by ReLU, then passed through
    ``num_layers`` stacked ``HGTConv`` layers. Finally the ``'author'``
    embeddings are mapped to ``out_channels`` values by a last ``Linear``.

    The node types and graph metadata are read from the module-level ``data``
    object, so ``data`` must exist before the model is constructed.

    Args:
        hidden_channels: Width of the per-type projections and of every
            ``HGTConv`` layer.
        out_channels: Number of output values per ``'author'`` node.
        num_heads: Number of attention heads handed to each ``HGTConv``.
        num_layers: Number of stacked ``HGTConv`` layers.
    """

    def __init__(self, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()

        # One input projection per node type. in_channels=-1 leaves the input
        # width unspecified here; the script runs one forward pass under
        # no_grad() before training to initialize these lazy modules.
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in data.node_types:
            self.lin_dict[node_type] = Linear(-1, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            # NOTE: `group='sum'` is deliberately not passed. It was the
            # default in PyG releases that had the argument, and newer
            # releases removed it from HGTConv, so passing it there fails.
            conv = HGTConv(hidden_channels, hidden_channels, data.metadata(),
                           num_heads)
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        """Compute the outputs for the ``'author'`` nodes.

        Args:
            x_dict: Mapping from node type to that type's feature tensor.
            edge_index_dict: Mapping handed unchanged to every ``HGTConv``
                layer together with the current node embeddings.

        Returns:
            The result of the final linear layer applied to the ``'author'``
            embeddings.
        """
        # Project each node type into the shared hidden width; relu_() works
        # in place on the freshly created projection output.
        x_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }

        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)

        return self.lin(x_dict['author'])
# Build the model: 64 hidden channels, 4 output channels, 2 heads, 1 layer.
model = HGT(hidden_channels=64, out_channels=4, num_heads=2, num_layers=1)

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
data = data.to(device)
model = model.to(device)

# A single forward pass without gradient tracking initializes the lazy
# modules, so that model.parameters() below sees materialized parameters.
with torch.no_grad():  # Initialize lazy modules.
    out = model(data.x_dict, data.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)
def train():
    """Run one full-graph optimization step and return the training loss.

    Uses the module-level ``model``, ``optimizer`` and ``data``. The loss is
    the cross entropy between the model output and the labels, restricted to
    the ``'author'`` nodes selected by ``train_mask``.

    Returns:
        The loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    authors = data['author']
    logits = model(data.x_dict, data.edge_index_dict)
    train_idx = authors.train_mask
    loss = F.cross_entropy(logits[train_idx], authors.y[train_idx])

    loss.backward()
    optimizer.step()
    return float(loss)
@torch.no_grad()
def test():
    """Evaluate the model on the train, validation and test splits.

    Uses the module-level ``model`` and ``data``. Predictions are the argmax
    of the model output over the last dimension; accuracy for a split is the
    number of correct predictions under that split's mask divided by the
    number of selected nodes.

    Returns:
        A list ``[train_acc, val_acc, test_acc]`` of Python floats.
    """
    model.eval()
    authors = data['author']
    pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)

    def accuracy(mask):
        # Fraction of masked nodes whose prediction matches the label.
        correct = (pred[mask] == authors.y[mask]).sum()
        return float(correct / mask.sum())

    return [
        accuracy(authors[split])
        for split in ('train_mask', 'val_mask', 'test_mask')
    ]
# Train for 100 epochs; after every step report the loss and the accuracy on
# all three splits.
for epoch in range(1, 101):
    loss = train()
    train_acc, val_acc, test_acc = test()
    print(
        f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
        f'Val: {val_acc:.4f}, Test: {test_acc:.4f}'
    )
输出:
HeteroData(
author={
x=[4057, 334],
y=[4057],
train_mask=[4057],
val_mask=[4057],
test_mask=[4057]
},
paper={ x=[14328, 4231] },
term={ x=[7723, 50] },
conference={
num_nodes=20,
x=[20, 1]
},
(author, to, paper)={ edge_index=[2, 19645] },
(paper, to, author)={ edge_index=[2, 19645] },
(paper, to, term)={ edge_index=[2, 85810] },
(paper, to, conference)={ edge_index=[2, 14328] },
(term, to, paper)={ edge_index=[2, 85810] },
(conference, to, paper)={ edge_index=[2, 14328] }
)
Epoch: 001, Loss: 1.3967, Train: 0.2550, Val: 0.2700, Test: 0.2539
Epoch: 002, Loss: 1.3708, Train: 0.6750, Val: 0.4825, Test: 0.5272
Epoch: 003, Loss: 1.3459, Train: 0.6200, Val: 0.4525, Test: 0.5302
Epoch: 004, Loss: 1.3173, Train: 0.5675, Val: 0.4300, Test: 0.4992
Epoch: 005, Loss: 1.2809, Train: 0.5550, Val: 0.4150, Test: 0.4836
Epoch: 006, Loss: 1.2323, Train: 0.5825, Val: 0.4175, Test: 0.4814
Epoch: 007, Loss: 1.1662, Train: 0.6575, Val: 0.4325, Test: 0.5026
Epoch: 008, Loss: 1.0761, Train: 0.7475, Val: 0.4775, Test: 0.5511
Epoch: 009, Loss: 0.9564, Train: 0.8425, Val: 0.5575, Test: 0.6098
Epoch: 010, Loss: 0.8064, Train: 0.9400, Val: 0.6275, Test: 0.6831
Epoch: 011, Loss: 0.6348, Train: 0.9750, Val: 0.7125, Test: 0.7446
Epoch: 012, Loss: 0.4627, Train: 0.9950, Val: 0.7225, Test: 0.7768
Epoch: 013, Loss: 0.3151, Train: 0.9975, Val: 0.7275, Test: 0.7842
Epoch: 014, Loss: 0.2002, Train: 0.9975, Val: 0.7200, Test: 0.7835
Epoch: 015, Loss: 0.1142, Train: 0.9950, Val: 0.7200, Test: 0.7706
Epoch: 016, Loss: 0.0614, Train: 0.9950, Val: 0.7250, Test: 0.7596
Epoch: 017, Loss: 0.0336, Train: 1.0000, Val: 0.7125, Test: 0.7633
Epoch: 018, Loss: 0.0163, Train: 1.0000, Val: 0.6950, Test: 0.7676
Epoch: 019, Loss: 0.0068, Train: 1.0000, Val: 0.7150, Test: 0.7688
Epoch: 020, Loss: 0.0033, Train: 1.0000, Val: 0.7175, Test: 0.7630
Epoch: 021, Loss: 0.0022, Train: 1.0000, Val: 0.7200, Test: 0.7630
Epoch: 022, Loss: 0.0016, Train: 1.0000, Val: 0.7175, Test: 0.7587
Epoch: 023, Loss: 0.0010, Train: 1.0000, Val: 0.7175, Test: 0.7624
Epoch: 024, Loss: 0.0006, Train: 1.0000, Val: 0.7300, Test: 0.7621
Epoch: 025, Loss: 0.0003, Train: 1.0000, Val: 0.7300, Test: 0.7599
Epoch: 026, Loss: 0.0002, Train: 1.0000, Val: 0.7325, Test: 0.7571
Epoch: 027, Loss: 0.0002, Train: 1.0000, Val: 0.7425, Test: 0.7608
Epoch: 028, Loss: 0.0002, Train: 1.0000, Val: 0.7425, Test: 0.7633
Epoch: 029, Loss: 0.0002, Train: 1.0000, Val: 0.7475, Test: 0.7624
Epoch: 030, Loss: 0.0003, Train: 1.0000, Val: 0.7525, Test: 0.7636
Epoch: 031, Loss: 0.0005, Train: 1.0000, Val: 0.7475, Test: 0.7667
Epoch: 032, Loss: 0.0006, Train: 1.0000, Val: 0.7475, Test: 0.7663
Epoch: 033, Loss: 0.0007, Train: 1.0000, Val: 0.7550, Test: 0.7673
Epoch: 034, Loss: 0.0008, Train: 1.0000, Val: 0.7600, Test: 0.7682
Epoch: 035, Loss: 0.0008, Train: 1.0000, Val: 0.7625, Test: 0.7703
Epoch: 036, Loss: 0.0008, Train: 1.0000, Val: 0.7650, Test: 0.7722
Epoch: 037, Loss: 0.0008, Train: 1.0000, Val: 0.7650, Test: 0.7759
Epoch: 038, Loss: 0.0008, Train: 1.0000, Val: 0.7625, Test: 0.7722
Epoch: 039, Loss: 0.0009, Train: 1.0000, Val: 0.7550, Test: 0.7756
Epoch: 040, Loss: 0.0010, Train: 1.0000, Val: 0.7550, Test: 0.7734
Epoch: 041, Loss: 0.0010, Train: 1.0000, Val: 0.7525, Test: 0.7749
Epoch: 042, Loss: 0.0011, Train: 1.0000, Val: 0.7475, Test: 0.7743
Epoch: 043, Loss: 0.0011, Train: 1.0000, Val: 0.7500, Test: 0.7753
Epoch: 044, Loss: 0.0011, Train: 1.0000, Val: 0.7525, Test: 0.7746
Epoch: 045, Loss: 0.0012, Train: 1.0000, Val: 0.7500, Test: 0.7749
Epoch: 046, Loss: 0.0012, Train: 1.0000, Val: 0.7550, Test: 0.7762
Epoch: 047, Loss: 0.0013, Train: 1.0000, Val: 0.7575, Test: 0.7792
Epoch: 048, Loss: 0.0015, Train: 1.0000, Val: 0.7550, Test: 0.7808
Epoch: 049, Loss: 0.0016, Train: 1.0000, Val: 0.7525, Test: 0.7783
Epoch: 050, Loss: 0.0016, Train: 1.0000, Val: 0.7575, Test: 0.7808
Epoch: 051, Loss: 0.0016, Train: 1.0000, Val: 0.7600, Test: 0.7811
Epoch: 052, Loss: 0.0016, Train: 1.0000, Val: 0.7625, Test: 0.7842
Epoch: 053, Loss: 0.0017, Train: 1.0000, Val: 0.7600, Test: 0.7823
Epoch: 054, Loss: 0.0018, Train: 1.0000, Val: 0.7600, Test: 0.7835
Epoch: 055, Loss: 0.0019, Train: 1.0000, Val: 0.7600, Test: 0.7808
Epoch: 056, Loss: 0.0019, Train: 1.0000, Val: 0.7575, Test: 0.7820
Epoch: 057, Loss: 0.0019, Train: 1.0000, Val: 0.7600, Test: 0.7832
Epoch: 058, Loss: 0.0020, Train: 1.0000, Val: 0.7625, Test: 0.7848
Epoch: 059, Loss: 0.0021, Train: 1.0000, Val: 0.7625, Test: 0.7845
Epoch: 060, Loss: 0.0021, Train: 1.0000, Val: 0.7625, Test: 0.7839
Epoch: 061, Loss: 0.0022, Train: 1.0000, Val: 0.7650, Test: 0.7826
Epoch: 062, Loss: 0.0023, Train: 1.0000, Val: 0.7700, Test: 0.7826
Epoch: 063, Loss: 0.0023, Train: 1.0000, Val: 0.7700, Test: 0.7848
Epoch: 064, Loss: 0.0024, Train: 1.0000, Val: 0.7700, Test: 0.7820
Epoch: 065, Loss: 0.0025, Train: 1.0000, Val: 0.7700, Test: 0.7839
Epoch: 066, Loss: 0.0025, Train: 1.0000, Val: 0.7675, Test: 0.7826
Epoch: 067, Loss: 0.0026, Train: 1.0000, Val: 0.7675, Test: 0.7832
Epoch: 068, Loss: 0.0026, Train: 1.0000, Val: 0.7650, Test: 0.7854
Epoch: 069, Loss: 0.0027, Train: 1.0000, Val: 0.7650, Test: 0.7863
Epoch: 070, Loss: 0.0027, Train: 1.0000, Val: 0.7625, Test: 0.7866
Epoch: 071, Loss: 0.0028, Train: 1.0000, Val: 0.7625, Test: 0.7860
Epoch: 072, Loss: 0.0028, Train: 1.0000, Val: 0.7625, Test: 0.7872
Epoch: 073, Loss: 0.0028, Train: 1.0000, Val: 0.7625, Test: 0.7872
Epoch: 074, Loss: 0.0028, Train: 1.0000, Val: 0.7625, Test: 0.7860
Epoch: 075, Loss: 0.0029, Train: 1.0000, Val: 0.7625, Test: 0.7854
Epoch: 076, Loss: 0.0029, Train: 1.0000, Val: 0.7650, Test: 0.7863
Epoch: 077, Loss: 0.0029, Train: 1.0000, Val: 0.7600, Test: 0.7866
Epoch: 078, Loss: 0.0029, Train: 1.0000, Val: 0.7625, Test: 0.7875
Epoch: 079, Loss: 0.0030, Train: 1.0000, Val: 0.7625, Test: 0.7872
Epoch: 080, Loss: 0.0030, Train: 1.0000, Val: 0.7625, Test: 0.7885
Epoch: 081, Loss: 0.0030, Train: 1.0000, Val: 0.7625, Test: 0.7897
Epoch: 082, Loss: 0.0030, Train: 1.0000, Val: 0.7600, Test: 0.7894
Epoch: 083, Loss: 0.0030, Train: 1.0000, Val: 0.7600, Test: 0.7897
Epoch: 084, Loss: 0.0030, Train: 1.0000, Val: 0.7625, Test: 0.7900
Epoch: 085, Loss: 0.0030, Train: 1.0000, Val: 0.7625, Test: 0.7897
Epoch: 086, Loss: 0.0031, Train: 1.0000, Val: 0.7650, Test: 0.7903
Epoch: 087, Loss: 0.0031, Train: 1.0000, Val: 0.7650, Test: 0.7906
Epoch: 088, Loss: 0.0031, Train: 1.0000, Val: 0.7650, Test: 0.7909
Epoch: 089, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7915
Epoch: 090, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7915
Epoch: 091, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7915
Epoch: 092, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7912
Epoch: 093, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7909
Epoch: 094, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7909
Epoch: 095, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7909
Epoch: 096, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7909
Epoch: 097, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7906
Epoch: 098, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7912
Epoch: 099, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7912
Epoch: 100, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7915