5.2 mini-batch
Available DataLoaders: pytorch-geometric.readthedocs.io/en/latest/m… pytorch-geometric.readthedocs.io/en/latest/m…
As with homogeneous graphs, these loaders are quite convenient to use: they directly return HeteroData objects.
Code template for building a DataLoader:
import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import NeighborLoader
transform = T.ToUndirected() # Add reverse edge types.
data = OGB_MAG(root='./data', preprocess='metapath2vec', transform=transform)[0]
train_loader = NeighborLoader(
    data,
    # Sample 15 neighbors for each node and each edge type for 2 iterations:
    num_neighbors=[15] * 2,
    # Use a batch size of 128 for sampling training nodes of type "paper":
    batch_size=128,
    input_nodes=('paper', data['paper'].train_mask),
)
batch = next(iter(train_loader))
You can also control the number of sampled neighbors at a finer granularity, per edge type: num_neighbors = {key: [15] * 2 for key in data.edge_types}
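For example, a minimal sketch (assuming the data object and imports from the template above) that samples more neighbors along the 'cites' relation than along the other edge types:

# Two hops everywhere, but sample 15 then 10 neighbors along 'cites'
# and only 5 per hop along every other edge type:
num_neighbors = {key: [5] * 2 for key in data.edge_types}
num_neighbors[('paper', 'cites', 'paper')] = [15, 10]

train_loader = NeighborLoader(
    data,
    num_neighbors=num_neighbors,
    batch_size=128,
    input_nodes=('paper', data['paper'].train_mask),
)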
Note that batch_size here is the number of seed nodes whose embeddings we want to compute in each step; the whole sampled mini-batch is needed to compute them, and the first batch_size entries of the output are exactly the embeddings of those seed nodes.
Code template for training:
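A quick way to see this (a sketch; the exact node count varies with the random sampling):

batch = next(iter(train_loader))

# The first `batch_size` 'paper' nodes are the seed nodes we care about;
# the remaining nodes are sampled neighbors needed to compute their embeddings:
print(batch['paper'].batch_size)  # 128
print(batch['paper'].num_nodes)   # typically much larger than 128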
def train():
    model.train()

    total_examples = total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        batch = batch.to('cuda:0')
        batch_size = batch['paper'].batch_size
        out = model(batch.x_dict, batch.edge_index_dict)
        loss = F.cross_entropy(out['paper'][:batch_size],
                               batch['paper'].y[:batch_size])
        loss.backward()
        optimizer.step()

        total_examples += batch_size
        total_loss += float(loss) * batch_size

    return total_loss / total_examples
Using NeighborLoader with multiple worker processes can trigger a strange hang, so single-process loading is recommended: Heterogenous graph, use NeighborLoader with num_workers>0, and stucks after many epochs · Issue #5348 · pyg-team/pytorch_geometric
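If you do hit this, a minimal workaround sketch is to keep sampling in the main process (slower, but stable):

train_loader = NeighborLoader(
    data,
    num_neighbors=[15] * 2,
    batch_size=128,
    input_nodes=('paper', data['paper'].train_mask),
    num_workers=0,  # sample in the main process to avoid the multi-worker hang
)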
If you want to map mini-batch nodes back to their indices in the original graph, see this discussion: I wonder how to use NeighborLoader correctly? · Discussion #3409 · pyg-team/pytorch_geometric
Example code (adapted from github.com/pyg-team/py…):
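As a hedged sketch: recent PyG versions attach the original node indices to the sampled batch as n_id (older versions may need the workaround from that discussion instead):

batch = next(iter(train_loader))

# n_id maps each node in the mini-batch back to its index in the full graph;
# the first `batch_size` entries correspond to the seed nodes:
paper_n_id = batch['paper'].n_id
seed_ids = paper_n_id[:batch['paper'].batch_size]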
import argparse
import os.path as osp
import torch
import torch.nn.functional as F
from torch.nn import ReLU
from tqdm import tqdm
import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import HGTLoader, NeighborLoader
from torch_geometric.nn import Linear, SAGEConv, Sequential, to_hetero
parser = argparse.ArgumentParser()
parser.add_argument('--use_hgt_loader', action='store_true')
args = parser.parse_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = T.ToUndirected(merge=True)
dataset = OGB_MAG('/data/pyg_data', preprocess='metapath2vec', transform=transform)
# Already send node features/labels to GPU for faster access during sampling:
data = dataset[0].to(device, 'x', 'y')
train_input_nodes = ('paper', data['paper'].train_mask)
val_input_nodes = ('paper', data['paper'].val_mask)
kwargs = {'batch_size': 1024, 'num_workers': 6, 'persistent_workers': True}
if not args.use_hgt_loader:
    train_loader = NeighborLoader(data, num_neighbors=[10] * 2, shuffle=True,
                                  input_nodes=train_input_nodes, **kwargs)
    val_loader = NeighborLoader(data, num_neighbors=[10] * 2,
                                input_nodes=val_input_nodes, **kwargs)
else:
    train_loader = HGTLoader(data, num_samples=[1024] * 4, shuffle=True,
                             input_nodes=train_input_nodes, **kwargs)
    val_loader = HGTLoader(data, num_samples=[1024] * 4,
                           input_nodes=val_input_nodes, **kwargs)
model = Sequential('x, edge_index', [
    (SAGEConv((-1, -1), 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (SAGEConv((-1, -1), 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (Linear(-1, dataset.num_classes), 'x -> x'),
])
model = to_hetero(model, data.metadata(), aggr='sum').to(device)
@torch.no_grad()
def init_params():
    # Initialize lazy parameters via forwarding a single batch to the model:
    batch = next(iter(train_loader))
    batch = batch.to(device, 'edge_index')
    model(batch.x_dict, batch.edge_index_dict)
def train():
    model.train()

    total_examples = total_loss = 0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        batch = batch.to(device, 'edge_index')
        batch_size = batch['paper'].batch_size
        out = model(batch.x_dict, batch.edge_index_dict)['paper'][:batch_size]
        loss = F.cross_entropy(out, batch['paper'].y[:batch_size])
        loss.backward()
        optimizer.step()

        total_examples += batch_size
        total_loss += float(loss) * batch_size

    return total_loss / total_examples
@torch.no_grad()
def test(loader):
    model.eval()

    total_examples = total_correct = 0
    for batch in tqdm(loader):
        batch = batch.to(device, 'edge_index')
        batch_size = batch['paper'].batch_size
        out = model(batch.x_dict, batch.edge_index_dict)['paper'][:batch_size]
        pred = out.argmax(dim=-1)

        total_examples += batch_size
        total_correct += int((pred == batch['paper'].y[:batch_size]).sum())

    return total_correct / total_examples
init_params() # Initialize parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(1, 21):
    loss = train()
    val_acc = test(val_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_acc:.4f}')
Last-epoch output with NeighborLoader: Epoch: 20, Loss: 1.9040, Val: 0.4445
Last-epoch output with HGTLoader: Epoch: 20, Loss: 2.0077, Val: 0.4271
(The full output is omitted here because the tqdm progress bars make it very long.)
A note on the example code: the to method of PyG data objects can move just a selected subset of attributes to a device.
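For example (mirroring the calls used above):

# Move only node features and labels to the GPU, keeping the (large)
# edge_index tensors on the CPU for neighbor sampling:
data = dataset[0].to(device, 'x', 'y')

# Per mini-batch, move only the sampled edge_index tensors to the GPU
# (x and y were already placed there):
batch = batch.to(device, 'edge_index')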