开启掘金成长之旅!这是我参与「掘金日新计划 · 12 月更文挑战」的第14天
本文首发于CSDN。
5.1.2 使用HeteroConv
5.1.2.1 GraphSAGE
(原数据集中的 conference 节点没有特征,这里通过 T.Constant 变换用 [1.] 作为这类节点的初始特征)
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.datasets import DBLP
from torch_geometric.nn import HeteroConv, Linear, SAGEConv
# Load DBLP. The T.Constant transform is restricted to the 'conference' node
# type so that those nodes receive a constant feature to start from.
dataset = DBLP(
    '/data/pyg_data/DBLP',
    transform=T.Constant(node_types='conference'),
)
# The script works on the first (index 0) graph object of the dataset.
data = dataset[0]
print(data)
class HeteroGNN(torch.nn.Module):
    """Heterogeneous GraphSAGE network that scores ``'author'`` nodes.

    The model stacks ``num_layers`` ``HeteroConv`` layers. Each layer holds
    one ``SAGEConv`` per edge type listed in ``metadata[1]``. A final
    ``Linear`` layer maps the hidden representation of the ``'author'``
    nodes to ``out_channels`` outputs.

    Args:
        metadata: Indexable object whose element ``[1]`` is iterated as the
            collection of edge types (the script passes ``data.metadata()``).
        hidden_channels: Output width of every ``SAGEConv``.
        out_channels: Output width of the final ``Linear`` layer.
        num_layers: Number of stacked ``HeteroConv`` layers.
    """

    def __init__(self, metadata, hidden_channels, out_channels, num_layers):
        super().__init__()

        edge_types = metadata[1]
        layers = []
        for _ in range(num_layers):
            # One SAGEConv per relation. The (-1, -1) input sizes are left
            # unspecified here (the script later runs a forward pass to
            # "initialize lazy modules").
            per_relation = {}
            for edge_type in edge_types:
                per_relation[edge_type] = SAGEConv((-1, -1), hidden_channels)
            layers.append(HeteroConv(per_relation))

        self.convs = torch.nn.ModuleList(layers)
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        """Return the output of the final linear layer for ``'author'`` nodes.

        Args:
            x_dict: Mapping from node type to that type's feature tensor.
            edge_index_dict: Mapping passed unchanged to every conv layer.
        """
        features = x_dict
        for layer in self.convs:
            updated = layer(features, edge_index_dict)
            # Apply the non-linearity to every node type's representation.
            features = {
                node_type: F.leaky_relu(hidden)
                for node_type, hidden in updated.items()
            }

        author_hidden = features['author']
        return self.lin(author_hidden)
# Build the network: 64 hidden channels, 4 outputs, 2 conv layers.
model = HeteroGNN(
    data.metadata(),
    hidden_channels=64,
    out_channels=4,
    num_layers=2,
)

# Prefer the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
data = data.to(device)
model = model.to(device)

# Initialize lazy modules: run one forward pass without gradient tracking
# before the optimizer is handed model.parameters().
with torch.no_grad():
    out = model(data.x_dict, data.edge_index_dict)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.005,
    weight_decay=0.001,
)
def train():
    """Run one optimization step and return the training loss as a float.

    The model is evaluated on the whole ``data`` object; the cross-entropy
    loss is computed only on the ``'author'`` nodes selected by
    ``train_mask``.
    """
    model.train()
    optimizer.zero_grad()

    logits = model(data.x_dict, data.edge_index_dict)
    authors = data['author']
    selected = authors.train_mask
    loss = F.cross_entropy(logits[selected], authors.y[selected])

    loss.backward()
    optimizer.step()
    return float(loss)
@torch.no_grad()
def test():
    """Return ``[train, val, test]`` accuracies for the ``'author'`` nodes.

    Predictions are the arg-max over the last dimension of the model output.
    Each accuracy is the number of correct predictions inside a split's mask
    divided by the number of nodes selected by that mask.
    """
    model.eval()
    logits = model(data.x_dict, data.edge_index_dict)
    predicted = logits.argmax(dim=-1)
    authors = data['author']

    def split_accuracy(split):
        # Fraction of correctly classified nodes within one mask.
        selected = authors[split]
        correct = (predicted[selected] == authors.y[selected]).sum()
        return float(correct / selected.sum())

    return [
        split_accuracy(split)
        for split in ('train_mask', 'val_mask', 'test_mask')
    ]
# Train for 100 epochs, evaluating and reporting all three splits after
# every optimization step.
for epoch in range(1, 101):
    loss = train()
    train_acc, val_acc, test_acc = test()
    print(
        f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
        f'Val: {val_acc:.4f}, Test: {test_acc:.4f}'
    )
输出:
HeteroData(
author={
x=[4057, 334],
y=[4057],
train_mask=[4057],
val_mask=[4057],
test_mask=[4057]
},
paper={ x=[14328, 4231] },
term={ x=[7723, 50] },
conference={
num_nodes=20,
x=[20, 1]
},
(author, to, paper)={ edge_index=[2, 19645] },
(paper, to, author)={ edge_index=[2, 19645] },
(paper, to, term)={ edge_index=[2, 85810] },
(paper, to, conference)={ edge_index=[2, 14328] },
(term, to, paper)={ edge_index=[2, 85810] },
(conference, to, paper)={ edge_index=[2, 14328] }
)
Epoch: 001, Loss: 1.3721, Train: 0.4550, Val: 0.3450, Test: 0.3819
Epoch: 002, Loss: 1.2867, Train: 0.6050, Val: 0.4800, Test: 0.5333
Epoch: 003, Loss: 1.1778, Train: 0.7175, Val: 0.5325, Test: 0.5941
Epoch: 004, Loss: 1.0368, Train: 0.8350, Val: 0.6050, Test: 0.6788
Epoch: 005, Loss: 0.8729, Train: 0.8950, Val: 0.6725, Test: 0.7228
Epoch: 006, Loss: 0.6991, Train: 0.9350, Val: 0.7025, Test: 0.7479
Epoch: 007, Loss: 0.5314, Train: 0.9600, Val: 0.7350, Test: 0.7765
Epoch: 008, Loss: 0.3831, Train: 0.9825, Val: 0.7525, Test: 0.8010
Epoch: 009, Loss: 0.2585, Train: 0.9900, Val: 0.7800, Test: 0.8189
Epoch: 010, Loss: 0.1632, Train: 0.9975, Val: 0.8025, Test: 0.8284
Epoch: 011, Loss: 0.0988, Train: 0.9975, Val: 0.8100, Test: 0.8293
Epoch: 012, Loss: 0.0578, Train: 1.0000, Val: 0.8025, Test: 0.8290
Epoch: 013, Loss: 0.0324, Train: 1.0000, Val: 0.8150, Test: 0.8296
Epoch: 014, Loss: 0.0178, Train: 1.0000, Val: 0.8075, Test: 0.8281
Epoch: 015, Loss: 0.0100, Train: 1.0000, Val: 0.8050, Test: 0.8281
Epoch: 016, Loss: 0.0060, Train: 1.0000, Val: 0.8050, Test: 0.8262
Epoch: 017, Loss: 0.0039, Train: 1.0000, Val: 0.8025, Test: 0.8235
Epoch: 018, Loss: 0.0027, Train: 1.0000, Val: 0.8100, Test: 0.8232
Epoch: 019, Loss: 0.0020, Train: 1.0000, Val: 0.8125, Test: 0.8235
Epoch: 020, Loss: 0.0017, Train: 1.0000, Val: 0.8125, Test: 0.8238
Epoch: 021, Loss: 0.0015, Train: 1.0000, Val: 0.8175, Test: 0.8268
Epoch: 022, Loss: 0.0014, Train: 1.0000, Val: 0.8175, Test: 0.8271
Epoch: 023, Loss: 0.0013, Train: 1.0000, Val: 0.8125, Test: 0.8274
Epoch: 024, Loss: 0.0014, Train: 1.0000, Val: 0.8150, Test: 0.8284
Epoch: 025, Loss: 0.0014, Train: 1.0000, Val: 0.8175, Test: 0.8281
Epoch: 026, Loss: 0.0015, Train: 1.0000, Val: 0.8175, Test: 0.8268
Epoch: 027, Loss: 0.0017, Train: 1.0000, Val: 0.8225, Test: 0.8225
Epoch: 028, Loss: 0.0019, Train: 1.0000, Val: 0.8250, Test: 0.8216
Epoch: 029, Loss: 0.0021, Train: 1.0000, Val: 0.8250, Test: 0.8198
Epoch: 030, Loss: 0.0024, Train: 1.0000, Val: 0.8225, Test: 0.8195
Epoch: 031, Loss: 0.0026, Train: 1.0000, Val: 0.8200, Test: 0.8192
Epoch: 032, Loss: 0.0029, Train: 1.0000, Val: 0.8225, Test: 0.8189
Epoch: 033, Loss: 0.0032, Train: 1.0000, Val: 0.8175, Test: 0.8185
Epoch: 034, Loss: 0.0035, Train: 1.0000, Val: 0.8200, Test: 0.8185
Epoch: 035, Loss: 0.0037, Train: 1.0000, Val: 0.8200, Test: 0.8176
Epoch: 036, Loss: 0.0038, Train: 1.0000, Val: 0.8225, Test: 0.8185
Epoch: 037, Loss: 0.0039, Train: 1.0000, Val: 0.8200, Test: 0.8176
Epoch: 038, Loss: 0.0041, Train: 1.0000, Val: 0.8175, Test: 0.8192
Epoch: 039, Loss: 0.0043, Train: 1.0000, Val: 0.8175, Test: 0.8204
Epoch: 040, Loss: 0.0044, Train: 1.0000, Val: 0.8150, Test: 0.8189
Epoch: 041, Loss: 0.0045, Train: 1.0000, Val: 0.8150, Test: 0.8173
Epoch: 042, Loss: 0.0046, Train: 1.0000, Val: 0.8175, Test: 0.8179
Epoch: 043, Loss: 0.0047, Train: 1.0000, Val: 0.8150, Test: 0.8170
Epoch: 044, Loss: 0.0047, Train: 1.0000, Val: 0.8175, Test: 0.8185
Epoch: 045, Loss: 0.0047, Train: 1.0000, Val: 0.8125, Test: 0.8195
Epoch: 046, Loss: 0.0047, Train: 1.0000, Val: 0.8150, Test: 0.8192
Epoch: 047, Loss: 0.0047, Train: 1.0000, Val: 0.8125, Test: 0.8182
Epoch: 048, Loss: 0.0047, Train: 1.0000, Val: 0.8075, Test: 0.8167
Epoch: 049, Loss: 0.0047, Train: 1.0000, Val: 0.8050, Test: 0.8158
Epoch: 050, Loss: 0.0047, Train: 1.0000, Val: 0.8000, Test: 0.8167
Epoch: 051, Loss: 0.0047, Train: 1.0000, Val: 0.8050, Test: 0.8170
Epoch: 052, Loss: 0.0047, Train: 1.0000, Val: 0.8100, Test: 0.8152
Epoch: 053, Loss: 0.0046, Train: 1.0000, Val: 0.8075, Test: 0.8149
Epoch: 054, Loss: 0.0046, Train: 1.0000, Val: 0.8075, Test: 0.8133
Epoch: 055, Loss: 0.0046, Train: 1.0000, Val: 0.8100, Test: 0.8139
Epoch: 056, Loss: 0.0046, Train: 1.0000, Val: 0.8100, Test: 0.8152
Epoch: 057, Loss: 0.0045, Train: 1.0000, Val: 0.8050, Test: 0.8149
Epoch: 058, Loss: 0.0045, Train: 1.0000, Val: 0.8025, Test: 0.8146
Epoch: 059, Loss: 0.0045, Train: 1.0000, Val: 0.8100, Test: 0.8139
Epoch: 060, Loss: 0.0044, Train: 1.0000, Val: 0.8125, Test: 0.8149
Epoch: 061, Loss: 0.0044, Train: 1.0000, Val: 0.8100, Test: 0.8149
Epoch: 062, Loss: 0.0043, Train: 1.0000, Val: 0.8025, Test: 0.8130
Epoch: 063, Loss: 0.0043, Train: 1.0000, Val: 0.8050, Test: 0.8124
Epoch: 064, Loss: 0.0042, Train: 1.0000, Val: 0.8050, Test: 0.8124
Epoch: 065, Loss: 0.0042, Train: 1.0000, Val: 0.8100, Test: 0.8127
Epoch: 066, Loss: 0.0041, Train: 1.0000, Val: 0.8100, Test: 0.8130
Epoch: 067, Loss: 0.0041, Train: 1.0000, Val: 0.8050, Test: 0.8121
Epoch: 068, Loss: 0.0040, Train: 1.0000, Val: 0.8050, Test: 0.8124
Epoch: 069, Loss: 0.0040, Train: 1.0000, Val: 0.8100, Test: 0.8118
Epoch: 070, Loss: 0.0039, Train: 1.0000, Val: 0.8075, Test: 0.8115
Epoch: 071, Loss: 0.0038, Train: 1.0000, Val: 0.8050, Test: 0.8109
Epoch: 072, Loss: 0.0038, Train: 1.0000, Val: 0.8075, Test: 0.8109
Epoch: 073, Loss: 0.0037, Train: 1.0000, Val: 0.8100, Test: 0.8099
Epoch: 074, Loss: 0.0037, Train: 1.0000, Val: 0.8100, Test: 0.8109
Epoch: 075, Loss: 0.0036, Train: 1.0000, Val: 0.8050, Test: 0.8106
Epoch: 076, Loss: 0.0036, Train: 1.0000, Val: 0.8075, Test: 0.8115
Epoch: 077, Loss: 0.0036, Train: 1.0000, Val: 0.8100, Test: 0.8112
Epoch: 078, Loss: 0.0035, Train: 1.0000, Val: 0.8075, Test: 0.8112
Epoch: 079, Loss: 0.0035, Train: 1.0000, Val: 0.8050, Test: 0.8112
Epoch: 080, Loss: 0.0035, Train: 1.0000, Val: 0.8050, Test: 0.8112
Epoch: 081, Loss: 0.0034, Train: 1.0000, Val: 0.8050, Test: 0.8115
Epoch: 082, Loss: 0.0034, Train: 1.0000, Val: 0.8050, Test: 0.8109
Epoch: 083, Loss: 0.0034, Train: 1.0000, Val: 0.8050, Test: 0.8109
Epoch: 084, Loss: 0.0034, Train: 1.0000, Val: 0.8050, Test: 0.8112
Epoch: 085, Loss: 0.0033, Train: 1.0000, Val: 0.8050, Test: 0.8106
Epoch: 086, Loss: 0.0033, Train: 1.0000, Val: 0.8025, Test: 0.8103
Epoch: 087, Loss: 0.0033, Train: 1.0000, Val: 0.8025, Test: 0.8103
Epoch: 088, Loss: 0.0032, Train: 1.0000, Val: 0.8025, Test: 0.8099
Epoch: 089, Loss: 0.0032, Train: 1.0000, Val: 0.8025, Test: 0.8106
Epoch: 090, Loss: 0.0032, Train: 1.0000, Val: 0.8025, Test: 0.8103
Epoch: 091, Loss: 0.0032, Train: 1.0000, Val: 0.8025, Test: 0.8106
Epoch: 092, Loss: 0.0031, Train: 1.0000, Val: 0.8025, Test: 0.8106
Epoch: 093, Loss: 0.0031, Train: 1.0000, Val: 0.8025, Test: 0.8109
Epoch: 094, Loss: 0.0031, Train: 1.0000, Val: 0.8025, Test: 0.8115
Epoch: 095, Loss: 0.0031, Train: 1.0000, Val: 0.8000, Test: 0.8121
Epoch: 096, Loss: 0.0030, Train: 1.0000, Val: 0.8025, Test: 0.8124
Epoch: 097, Loss: 0.0030, Train: 1.0000, Val: 0.8000, Test: 0.8130
Epoch: 098, Loss: 0.0030, Train: 1.0000, Val: 0.8025, Test: 0.8127
Epoch: 099, Loss: 0.0030, Train: 1.0000, Val: 0.7975, Test: 0.8124
Epoch: 100, Loss: 0.0030, Train: 1.0000, Val: 0.8000, Test: 0.8127