PyG (PyTorch Geometric) Heterogeneous Graph Neural Networks (HGNN) (8)



This article was first published on CSDN.

6. Link Prediction Tasks

6.1 Transductive

6.1.1 GraphSAGE Encoding + MLP Decoding + Predicting User Ratings

Reference: github.com/pyg-team/py…

  1. Use GraphSAGE, converted into a heterogeneous GNN via `to_hetero`, to encode the nodes.
  2. Decode each node pair's representations (i.e., compute the pair's link prediction score): concatenate the two node embeddings and pass them through a 2-layer MLP.
  3. This is not a standard link prediction task: the goal is to predict the ['user','rates','movie'] rating of known node pairs (discrete integers from 0 to 5), so it is cast as a regression task. The loss is MSE, optionally class-weighted (--use_weighted_loss), since the six rating values are imbalanced.
  4. At test time, predictions are clamped to the range [0, 5] before computing the RMSE, which is reported as the evaluation metric.
  5. The MovieLens dataset is used. The original dataset contains two node types:
    1. Movie nodes, which have only a text feature (the title). The code embeds titles with a SentenceTransformer model (see the sketch after this list). Note that pointing the model_name argument at a local model directory causes an error; the crude workaround is to run once with the local path, then rename the stored object to the model name, after which the model name can be used directly. It is hard to explain briefly, so see this issue (I submitted a PR, but for some reason the author never merged it, so the manual fix is still needed): Unable to process movie_lens dataset with local directory transformers model · Issue #5500 · pyg-team/pytorch_geometric
    2. User nodes, which have no features; the code uses one-hot encodings as their initial node features.
    3. The graph is converted to an undirected graph (generating reverse edges).
  6. Data split: edges are randomly split 8:1:1 while the node set stays fixed. The training and validation graphs use the same message-passing edges; the test graph adds the validation supervision edges on top of the training edges. Since this is regression over known node pairs, no negative edges are needed.
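
To make step 5.1 concrete, here is a minimal sketch of embedding movie titles with the sentence-transformers package (the sample titles are made up, and this is not the dataset's exact internal code). The resulting 384-dim embeddings are consistent with the 404-dim movie features in the output below, assuming a concatenated genre multi-hot encoding supplies the remaining 20 dimensions:

from sentence_transformers import SentenceTransformer

titles = ['Toy Story (1995)', 'Jumanji (1995)']  # made-up sample titles
st_model = SentenceTransformer('all-MiniLM-L6-v2')
title_emb = st_model.encode(titles, convert_to_tensor=True)
print(title_emb.shape)  # torch.Size([2, 384]) for all-MiniLM-L6-v2

The full script follows: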
import argparse

import torch
import torch.nn.functional as F
from torch.nn import Linear

import torch_geometric.transforms as T
from torch_geometric.datasets import MovieLens
from torch_geometric.nn import SAGEConv, to_hetero

parser = argparse.ArgumentParser()
parser.add_argument('--use_weighted_loss', action='store_true',
                    help='Whether to use weighted MSE loss.')
args = parser.parse_args()

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

dataset = MovieLens('/data/pyg_data/MovieLens', model_name='all-MiniLM-L6-v2')
data = dataset[0].to(device)

# Add user node features for message passing:
data['user'].x = torch.eye(data['user'].num_nodes, device=device)
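# num_nodes can now be inferred from 'x', so the explicit attribute is redundant: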
del data['user'].num_nodes

# Add a reverse ('movie', 'rev_rates', 'user') relation for message passing:
data = T.ToUndirected()(data)
del data['movie', 'rev_rates', 'user'].edge_label  # Remove "reverse" label.

# Perform a link-level split into training, validation, and test edges:
train_data, val_data, test_data = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    neg_sampling_ratio=0.0,
    edge_types=[('user', 'rates', 'movie')],
    rev_edge_types=[('movie', 'rev_rates', 'user')],
)(data)
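
# Optional sanity check (a sketch, not part of the original example): with the
# default disjoint_train_ratio=0.0, the training supervision edges double as
# message-passing edges, the validation graph reuses the training
# message-passing edges, and the test graph additionally absorbs the
# validation supervision edges:
assert (train_data['user', 'movie'].edge_index.size(1) ==
        val_data['user', 'movie'].edge_index.size(1))
assert (test_data['user', 'movie'].edge_index.size(1) ==
        train_data['user', 'movie'].edge_index.size(1) +
        val_data['user', 'movie'].edge_label_index.size(1))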

# We have an unbalanced dataset with many labels for rating 3 and 4, and very
# few for 0 and 1. Therefore we use a weighted MSE loss.
if args.use_weighted_loss:
    weight = torch.bincount(train_data['user', 'movie'].edge_label)
    weight = weight.max() / weight
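    # Each rating class c is weighted by max_c' count(c') / count(c), so rare
    # ratings contribute more per-sample loss. Note this assumes every rating
    # value 0..5 actually occurs in the training split; an absent value would
    # receive an infinite weight.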
else:
    weight = None


def weighted_mse_loss(pred, target, weight=None):
    weight = 1. if weight is None else weight[target].to(pred.dtype)
    return (weight * (pred - target.to(pred.dtype)).pow(2)).mean()


class GNNEncoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
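        # in_channels=(-1, -1) lazily infers the (source, target) feature
        # dimensions at the first forward pass, so the same module can serve
        # every node-type pair after the to_hetero() conversion below.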
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x


class EdgeDecoder(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.lin1 = Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        row, col = edge_label_index
        z = torch.cat([z_dict['user'][row], z_dict['movie'][col]], dim=-1)

        z = self.lin1(z).relu()
        z = self.lin2(z)
        return z.view(-1)


class Model(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.encoder = GNNEncoder(hidden_channels, hidden_channels)
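        # to_hetero() clones the homogeneous encoder for each edge type in
        # data.metadata() and aggregates (here: sums) the per-relation
        # messages arriving at each node type.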
        self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)


model = Model(hidden_channels=32).to(device)

# Due to lazy initialization, we need to run one model step so the number
# of parameters can be inferred:
with torch.no_grad():
    model.encoder(train_data.x_dict, train_data.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


def train():
    model.train()
    optimizer.zero_grad()
    pred = model(train_data.x_dict, train_data.edge_index_dict,
                 train_data['user', 'movie'].edge_label_index)
    target = train_data['user', 'movie'].edge_label
    loss = weighted_mse_loss(pred, target, weight)
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test(data):
    model.eval()
    pred = model(data.x_dict, data.edge_index_dict,
                 data['user', 'movie'].edge_label_index)
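    # Ratings live in [0, 5]; clamp out-of-range predictions before RMSE.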
    pred = pred.clamp(min=0, max=5)
    target = data['user', 'movie'].edge_label.float()
    rmse = F.mse_loss(pred, target).sqrt()
    return float(rmse)


for epoch in range(1, 301):
    loss = train()
    train_rmse = test(train_data)
    val_rmse = test(val_data)
    test_rmse = test(test_data)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, '
          f'Val: {val_rmse:.4f}, Test: {test_rmse:.4f}')

Output:

HeteroData(
  movie={ x=[9742, 404] },
  user={ x=[610, 610] },
  (user, rates, movie)={
    edge_index=[2, 100836],
    edge_label=[100836]
  },
  (movie, rev_rates, user)={ edge_index=[2, 100836] }
)
Epoch: 001, Loss: 11.1455, Train: 3.0880, Val: 3.0996, Test: 3.0917
Epoch: 002, Loss: 9.5358, Train: 2.6658, Val: 2.6792, Test: 2.6712
Epoch: 003, Loss: 7.1066, Train: 1.8713, Val: 1.8877, Test: 1.8804
Epoch: 004, Loss: 3.5019, Train: 1.1067, Val: 1.0977, Test: 1.1063
Epoch: 005, Loss: 1.2249, Train: 1.9740, Val: 1.9311, Test: 1.9472
Epoch: 006, Loss: 5.5210, Train: 1.6758, Val: 1.6408, Test: 1.6595
Epoch: 007, Loss: 2.8131, Train: 1.0975, Val: 1.0894, Test: 1.0977
Epoch: 008, Loss: 1.2045, Train: 1.2613, Val: 1.2734, Test: 1.2708
Epoch: 009, Loss: 1.5908, Train: 1.5404, Val: 1.5562, Test: 1.5505
Epoch: 010, Loss: 2.3730, Train: 1.6708, Val: 1.6869, Test: 1.6805
Epoch: 011, Loss: 2.7915, Train: 1.6564, Val: 1.6724, Test: 1.6661
Epoch: 012, Loss: 2.7436, Train: 1.5262, Val: 1.5418, Test: 1.5362
Epoch: 013, Loss: 2.3292, Train: 1.3166, Val: 1.3301, Test: 1.3266
Epoch: 014, Loss: 1.7333, Train: 1.1141, Val: 1.1205, Test: 1.1219
Epoch: 015, Loss: 1.2412, Train: 1.0907, Val: 1.0819, Test: 1.0912
Epoch: 016, Loss: 1.1896, Train: 1.2743, Val: 1.2524, Test: 1.2672
Epoch: 017, Loss: 1.6238, Train: 1.3905, Val: 1.3643, Test: 1.3808
Epoch: 018, Loss: 1.9334, Train: 1.3026, Val: 1.2797, Test: 1.2951
Epoch: 019, Loss: 1.6968, Train: 1.1325, Val: 1.1190, Test: 1.1307
Epoch: 020, Loss: 1.2826, Train: 1.0549, Val: 1.0534, Test: 1.0595
Epoch: 021, Loss: 1.1128, Train: 1.1005, Val: 1.1076, Test: 1.1089
Epoch: 022, Loss: 1.2111, Train: 1.1792, Val: 1.1902, Test: 1.1889
Epoch: 023, Loss: 1.3905, Train: 1.2222, Val: 1.2345, Test: 1.2323
Epoch: 024, Loss: 1.4938, Train: 1.2101, Val: 1.2222, Test: 1.2202
Epoch: 025, Loss: 1.4643, Train: 1.1525, Val: 1.1631, Test: 1.1623
Epoch: 026, Loss: 1.3283, Train: 1.0802, Val: 1.0871, Test: 1.0889
Epoch: 027, Loss: 1.1668, Train: 1.0385, Val: 1.0392, Test: 1.0446
Epoch: 028, Loss: 1.0784, Train: 1.0565, Val: 1.0501, Test: 1.0592
Epoch: 029, Loss: 1.1163, Train: 1.1062, Val: 1.0946, Test: 1.1060
Epoch: 030, Loss: 1.2236, Train: 1.1266, Val: 1.1135, Test: 1.1257
Epoch: 031, Loss: 1.2693, Train: 1.0953, Val: 1.0846, Test: 1.0959
Epoch: 032, Loss: 1.1998, Train: 1.0460, Val: 1.0404, Test: 1.0495
Epoch: 033, Loss: 1.0942, Train: 1.0229, Val: 1.0234, Test: 1.0295
Epoch: 034, Loss: 1.0464, Train: 1.0346, Val: 1.0399, Test: 1.0434
Epoch: 035, Loss: 1.0705, Train: 1.0578, Val: 1.0659, Test: 1.0677
Epoch: 036, Loss: 1.1189, Train: 1.0684, Val: 1.0775, Test: 1.0787
Epoch: 037, Loss: 1.1414, Train: 1.0578, Val: 1.0665, Test: 1.0681
Epoch: 038, Loss: 1.1190, Train: 1.0331, Val: 1.0400, Test: 1.0428
Epoch: 039, Loss: 1.0672, Train: 1.0112, Val: 1.0150, Test: 1.0196
Epoch: 040, Loss: 1.0224, Train: 1.0071, Val: 1.0070, Test: 1.0139
Epoch: 041, Loss: 1.0143, Train: 1.0195, Val: 1.0159, Test: 1.0246
Epoch: 042, Loss: 1.0395, Train: 1.0297, Val: 1.0244, Test: 1.0339
Epoch: 043, Loss: 1.0602, Train: 1.0227, Val: 1.0182, Test: 1.0274
Epoch: 044, Loss: 1.0460, Train: 1.0051, Val: 1.0033, Test: 1.0113
Epoch: 045, Loss: 1.0103, Train: 0.9932, Val: 0.9948, Test: 1.0011
Epoch: 046, Loss: 0.9864, Train: 0.9937, Val: 0.9984, Test: 1.0032
Epoch: 047, Loss: 0.9875, Train: 1.0002, Val: 1.0069, Test: 1.0105
Epoch: 048, Loss: 1.0004, Train: 1.0024, Val: 1.0099, Test: 1.0131
Epoch: 049, Loss: 1.0048, Train: 0.9962, Val: 1.0033, Test: 1.0069
Epoch: 050, Loss: 0.9925, Train: 0.9855, Val: 0.9912, Test: 0.9956
Epoch: 051, Loss: 0.9712, Train: 0.9780, Val: 0.9815, Test: 0.9871
Epoch: 052, Loss: 0.9565, Train: 0.9782, Val: 0.9792, Test: 0.9862
Epoch: 053, Loss: 0.9568, Train: 0.9818, Val: 0.9813, Test: 0.9891
Epoch: 054, Loss: 0.9640, Train: 0.9807, Val: 0.9801, Test: 0.9880
Epoch: 055, Loss: 0.9619, Train: 0.9737, Val: 0.9744, Test: 0.9818
Epoch: 056, Loss: 0.9481, Train: 0.9670, Val: 0.9697, Test: 0.9762
Epoch: 057, Loss: 0.9351, Train: 0.9652, Val: 0.9699, Test: 0.9754
Epoch: 058, Loss: 0.9315, Train: 0.9664, Val: 0.9725, Test: 0.9773
Epoch: 059, Loss: 0.9338, Train: 0.9662, Val: 0.9729, Test: 0.9775
Epoch: 060, Loss: 0.9335, Train: 0.9626, Val: 0.9690, Test: 0.9739
Epoch: 061, Loss: 0.9265, Train: 0.9575, Val: 0.9629, Test: 0.9684
Epoch: 062, Loss: 0.9167, Train: 0.9542, Val: 0.9584, Test: 0.9646
Epoch: 063, Loss: 0.9106, Train: 0.9539, Val: 0.9568, Test: 0.9637
Epoch: 064, Loss: 0.9099, Train: 0.9539, Val: 0.9561, Test: 0.9634
Epoch: 065, Loss: 0.9099, Train: 0.9516, Val: 0.9541, Test: 0.9614
Epoch: 066, Loss: 0.9056, Train: 0.9479, Val: 0.9514, Test: 0.9582
Epoch: 067, Loss: 0.8985, Train: 0.9453, Val: 0.9500, Test: 0.9563
Epoch: 068, Loss: 0.8935, Train: 0.9446, Val: 0.9504, Test: 0.9563
Epoch: 069, Loss: 0.8922, Train: 0.9442, Val: 0.9507, Test: 0.9563
Epoch: 070, Loss: 0.8914, Train: 0.9424, Val: 0.9490, Test: 0.9546
Epoch: 071, Loss: 0.8882, Train: 0.9397, Val: 0.9459, Test: 0.9517
Epoch: 072, Loss: 0.8831, Train: 0.9376, Val: 0.9432, Test: 0.9491
Epoch: 073, Loss: 0.8791, Train: 0.9367, Val: 0.9417, Test: 0.9480
Epoch: 074, Loss: 0.8775, Train: 0.9359, Val: 0.9408, Test: 0.9471
Epoch: 075, Loss: 0.8760, Train: 0.9341, Val: 0.9395, Test: 0.9456
Epoch: 076, Loss: 0.8726, Train: 0.9319, Val: 0.9383, Test: 0.9441
Epoch: 077, Loss: 0.8685, Train: 0.9307, Val: 0.9382, Test: 0.9435
Epoch: 078, Loss: 0.8662, Train: 0.9302, Val: 0.9386, Test: 0.9434
Epoch: 079, Loss: 0.8653, Train: 0.9289, Val: 0.9376, Test: 0.9422
Epoch: 080, Loss: 0.8630, Train: 0.9270, Val: 0.9353, Test: 0.9399
Epoch: 081, Loss: 0.8593, Train: 0.9257, Val: 0.9335, Test: 0.9381
Epoch: 082, Loss: 0.8570, Train: 0.9251, Val: 0.9326, Test: 0.9372
Epoch: 083, Loss: 0.8560, Train: 0.9239, Val: 0.9317, Test: 0.9361
Epoch: 084, Loss: 0.8538, Train: 0.9223, Val: 0.9308, Test: 0.9349
Epoch: 085, Loss: 0.8507, Train: 0.9211, Val: 0.9307, Test: 0.9344
Epoch: 086, Loss: 0.8486, Train: 0.9205, Val: 0.9307, Test: 0.9341
Epoch: 087, Loss: 0.8473, Train: 0.9192, Val: 0.9296, Test: 0.9329
Epoch: 088, Loss: 0.8450, Train: 0.9178, Val: 0.9280, Test: 0.9312
Epoch: 089, Loss: 0.8425, Train: 0.9169, Val: 0.9269, Test: 0.9300
Epoch: 090, Loss: 0.8408, Train: 0.9160, Val: 0.9261, Test: 0.9291
Epoch: 091, Loss: 0.8393, Train: 0.9148, Val: 0.9255, Test: 0.9282
Epoch: 092, Loss: 0.8370, Train: 0.9136, Val: 0.9253, Test: 0.9275
Epoch: 093, Loss: 0.8348, Train: 0.9127, Val: 0.9253, Test: 0.9272
Epoch: 094, Loss: 0.8333, Train: 0.9118, Val: 0.9249, Test: 0.9266
Epoch: 095, Loss: 0.8315, Train: 0.9106, Val: 0.9240, Test: 0.9254
Epoch: 096, Loss: 0.8294, Train: 0.9097, Val: 0.9231, Test: 0.9243
Epoch: 097, Loss: 0.8277, Train: 0.9089, Val: 0.9226, Test: 0.9237
Epoch: 098, Loss: 0.8263, Train: 0.9079, Val: 0.9224, Test: 0.9232
Epoch: 099, Loss: 0.8246, Train: 0.9070, Val: 0.9225, Test: 0.9230
Epoch: 100, Loss: 0.8230, Train: 0.9063, Val: 0.9227, Test: 0.9230
Epoch: 101, Loss: 0.8217, Train: 0.9056, Val: 0.9227, Test: 0.9227
Epoch: 102, Loss: 0.8204, Train: 0.9048, Val: 0.9223, Test: 0.9222
Epoch: 103, Loss: 0.8190, Train: 0.9042, Val: 0.9219, Test: 0.9216
Epoch: 104, Loss: 0.8178, Train: 0.9036, Val: 0.9217, Test: 0.9213
Epoch: 105, Loss: 0.8168, Train: 0.9030, Val: 0.9218, Test: 0.9212
Epoch: 106, Loss: 0.8157, Train: 0.9024, Val: 0.9220, Test: 0.9212
Epoch: 107, Loss: 0.8146, Train: 0.9019, Val: 0.9223, Test: 0.9212
Epoch: 108, Loss: 0.8137, Train: 0.9013, Val: 0.9224, Test: 0.9210
Epoch: 109, Loss: 0.8127, Train: 0.9008, Val: 0.9222, Test: 0.9205
Epoch: 110, Loss: 0.8117, Train: 0.9003, Val: 0.9220, Test: 0.9201
Epoch: 111, Loss: 0.8108, Train: 0.8998, Val: 0.9220, Test: 0.9198
Epoch: 112, Loss: 0.8100, Train: 0.8993, Val: 0.9221, Test: 0.9196
Epoch: 113, Loss: 0.8091, Train: 0.8989, Val: 0.9223, Test: 0.9196
Epoch: 114, Loss: 0.8083, Train: 0.8985, Val: 0.9225, Test: 0.9195
Epoch: 115, Loss: 0.8076, Train: 0.8981, Val: 0.9224, Test: 0.9193
Epoch: 116, Loss: 0.8069, Train: 0.8977, Val: 0.9222, Test: 0.9189
Epoch: 117, Loss: 0.8063, Train: 0.8974, Val: 0.9220, Test: 0.9186
Epoch: 118, Loss: 0.8057, Train: 0.8971, Val: 0.9219, Test: 0.9184
Epoch: 119, Loss: 0.8051, Train: 0.8967, Val: 0.9219, Test: 0.9184
Epoch: 120, Loss: 0.8045, Train: 0.8965, Val: 0.9220, Test: 0.9184
Epoch: 121, Loss: 0.8040, Train: 0.8962, Val: 0.9220, Test: 0.9183
Epoch: 122, Loss: 0.8035, Train: 0.8959, Val: 0.9218, Test: 0.9181
Epoch: 123, Loss: 0.8030, Train: 0.8956, Val: 0.9216, Test: 0.9178
Epoch: 124, Loss: 0.8026, Train: 0.8954, Val: 0.9214, Test: 0.9176
Epoch: 125, Loss: 0.8022, Train: 0.8952, Val: 0.9214, Test: 0.9176
Epoch: 126, Loss: 0.8017, Train: 0.8950, Val: 0.9214, Test: 0.9176
Epoch: 127, Loss: 0.8013, Train: 0.8947, Val: 0.9214, Test: 0.9175
Epoch: 128, Loss: 0.8009, Train: 0.8945, Val: 0.9213, Test: 0.9174
Epoch: 129, Loss: 0.8006, Train: 0.8943, Val: 0.9211, Test: 0.9172
Epoch: 130, Loss: 0.8002, Train: 0.8941, Val: 0.9208, Test: 0.9171
Epoch: 131, Loss: 0.7999, Train: 0.8939, Val: 0.9207, Test: 0.9169
Epoch: 132, Loss: 0.7995, Train: 0.8938, Val: 0.9207, Test: 0.9169
Epoch: 133, Loss: 0.7992, Train: 0.8936, Val: 0.9207, Test: 0.9169
Epoch: 134, Loss: 0.7989, Train: 0.8934, Val: 0.9206, Test: 0.9168
Epoch: 135, Loss: 0.7986, Train: 0.8932, Val: 0.9204, Test: 0.9166
Epoch: 136, Loss: 0.7983, Train: 0.8931, Val: 0.9202, Test: 0.9165
Epoch: 137, Loss: 0.7980, Train: 0.8929, Val: 0.9201, Test: 0.9164
Epoch: 138, Loss: 0.7977, Train: 0.8928, Val: 0.9201, Test: 0.9164
Epoch: 139, Loss: 0.7974, Train: 0.8926, Val: 0.9201, Test: 0.9164
Epoch: 140, Loss: 0.7972, Train: 0.8925, Val: 0.9200, Test: 0.9163
Epoch: 141, Loss: 0.7969, Train: 0.8923, Val: 0.9199, Test: 0.9162
Epoch: 142, Loss: 0.7967, Train: 0.8922, Val: 0.9197, Test: 0.9161
Epoch: 143, Loss: 0.7965, Train: 0.8921, Val: 0.9196, Test: 0.9159
Epoch: 144, Loss: 0.7962, Train: 0.8919, Val: 0.9198, Test: 0.9161
Epoch: 145, Loss: 0.7960, Train: 0.8918, Val: 0.9194, Test: 0.9158
Epoch: 146, Loss: 0.7958, Train: 0.8917, Val: 0.9193, Test: 0.9157
Epoch: 147, Loss: 0.7956, Train: 0.8916, Val: 0.9193, Test: 0.9157
Epoch: 148, Loss: 0.7954, Train: 0.8915, Val: 0.9192, Test: 0.9156
Epoch: 149, Loss: 0.7952, Train: 0.8914, Val: 0.9191, Test: 0.9156
Epoch: 150, Loss: 0.7950, Train: 0.8913, Val: 0.9190, Test: 0.9155
Epoch: 151, Loss: 0.7949, Train: 0.8912, Val: 0.9190, Test: 0.9155
Epoch: 152, Loss: 0.7947, Train: 0.8911, Val: 0.9189, Test: 0.9155
Epoch: 153, Loss: 0.7945, Train: 0.8910, Val: 0.9189, Test: 0.9154
Epoch: 154, Loss: 0.7944, Train: 0.8909, Val: 0.9189, Test: 0.9154
Epoch: 155, Loss: 0.7942, Train: 0.8909, Val: 0.9188, Test: 0.9153
Epoch: 156, Loss: 0.7941, Train: 0.8908, Val: 0.9188, Test: 0.9153
Epoch: 157, Loss: 0.7939, Train: 0.8907, Val: 0.9188, Test: 0.9152
Epoch: 158, Loss: 0.7938, Train: 0.8906, Val: 0.9188, Test: 0.9152
Epoch: 159, Loss: 0.7936, Train: 0.8905, Val: 0.9188, Test: 0.9151
Epoch: 160, Loss: 0.7935, Train: 0.8905, Val: 0.9188, Test: 0.9151
Epoch: 161, Loss: 0.7934, Train: 0.8904, Val: 0.9188, Test: 0.9151
Epoch: 162, Loss: 0.7932, Train: 0.8903, Val: 0.9188, Test: 0.9151
Epoch: 163, Loss: 0.7931, Train: 0.8903, Val: 0.9188, Test: 0.9151
Epoch: 164, Loss: 0.7930, Train: 0.8902, Val: 0.9188, Test: 0.9150
Epoch: 165, Loss: 0.7929, Train: 0.8901, Val: 0.9189, Test: 0.9150
Epoch: 166, Loss: 0.7928, Train: 0.8901, Val: 0.9189, Test: 0.9150
Epoch: 167, Loss: 0.7926, Train: 0.8900, Val: 0.9189, Test: 0.9150
Epoch: 168, Loss: 0.7925, Train: 0.8899, Val: 0.9189, Test: 0.9150
Epoch: 169, Loss: 0.7924, Train: 0.8899, Val: 0.9189, Test: 0.9150
Epoch: 170, Loss: 0.7923, Train: 0.8898, Val: 0.9190, Test: 0.9150
Epoch: 171, Loss: 0.7922, Train: 0.8898, Val: 0.9190, Test: 0.9149
Epoch: 172, Loss: 0.7921, Train: 0.8897, Val: 0.9190, Test: 0.9149
Epoch: 173, Loss: 0.7920, Train: 0.8896, Val: 0.9190, Test: 0.9148
Epoch: 174, Loss: 0.7919, Train: 0.8896, Val: 0.9190, Test: 0.9148
Epoch: 175, Loss: 0.7918, Train: 0.8895, Val: 0.9190, Test: 0.9148
Epoch: 176, Loss: 0.7917, Train: 0.8895, Val: 0.9190, Test: 0.9147
Epoch: 177, Loss: 0.7916, Train: 0.8894, Val: 0.9190, Test: 0.9147
Epoch: 178, Loss: 0.7916, Train: 0.8894, Val: 0.9190, Test: 0.9147
Epoch: 179, Loss: 0.7915, Train: 0.8893, Val: 0.9191, Test: 0.9147
Epoch: 180, Loss: 0.7914, Train: 0.8893, Val: 0.9191, Test: 0.9147
Epoch: 181, Loss: 0.7913, Train: 0.8893, Val: 0.9191, Test: 0.9147
Epoch: 182, Loss: 0.7912, Train: 0.8892, Val: 0.9191, Test: 0.9146
Epoch: 183, Loss: 0.7911, Train: 0.8892, Val: 0.9191, Test: 0.9146
Epoch: 184, Loss: 0.7910, Train: 0.8891, Val: 0.9191, Test: 0.9146
Epoch: 185, Loss: 0.7910, Train: 0.8891, Val: 0.9192, Test: 0.9146
Epoch: 186, Loss: 0.7909, Train: 0.8890, Val: 0.9192, Test: 0.9146
Epoch: 187, Loss: 0.7908, Train: 0.8890, Val: 0.9192, Test: 0.9146
Epoch: 188, Loss: 0.7907, Train: 0.8889, Val: 0.9192, Test: 0.9145
Epoch: 189, Loss: 0.7907, Train: 0.8889, Val: 0.9192, Test: 0.9145
Epoch: 190, Loss: 0.7906, Train: 0.8889, Val: 0.9192, Test: 0.9145
Epoch: 191, Loss: 0.7905, Train: 0.8888, Val: 0.9192, Test: 0.9145
Epoch: 192, Loss: 0.7905, Train: 0.8888, Val: 0.9192, Test: 0.9144
Epoch: 193, Loss: 0.7904, Train: 0.8888, Val: 0.9192, Test: 0.9144
Epoch: 194, Loss: 0.7903, Train: 0.8887, Val: 0.9192, Test: 0.9144
Epoch: 195, Loss: 0.7903, Train: 0.8887, Val: 0.9192, Test: 0.9144
Epoch: 196, Loss: 0.7902, Train: 0.8886, Val: 0.9192, Test: 0.9144
Epoch: 197, Loss: 0.7901, Train: 0.8886, Val: 0.9193, Test: 0.9143
Epoch: 198, Loss: 0.7901, Train: 0.8886, Val: 0.9193, Test: 0.9143
Epoch: 199, Loss: 0.7900, Train: 0.8885, Val: 0.9193, Test: 0.9143
Epoch: 200, Loss: 0.7899, Train: 0.8885, Val: 0.9193, Test: 0.9143
Epoch: 201, Loss: 0.7899, Train: 0.8885, Val: 0.9192, Test: 0.9143
Epoch: 202, Loss: 0.7898, Train: 0.8884, Val: 0.9193, Test: 0.9143
Epoch: 203, Loss: 0.7898, Train: 0.8884, Val: 0.9193, Test: 0.9143
Epoch: 204, Loss: 0.7897, Train: 0.8884, Val: 0.9193, Test: 0.9143
Epoch: 205, Loss: 0.7896, Train: 0.8883, Val: 0.9193, Test: 0.9143
Epoch: 206, Loss: 0.7896, Train: 0.8883, Val: 0.9193, Test: 0.9142
Epoch: 207, Loss: 0.7895, Train: 0.8883, Val: 0.9193, Test: 0.9143
Epoch: 208, Loss: 0.7895, Train: 0.8882, Val: 0.9193, Test: 0.9142
Epoch: 209, Loss: 0.7894, Train: 0.8882, Val: 0.9193, Test: 0.9142
Epoch: 210, Loss: 0.7894, Train: 0.8882, Val: 0.9193, Test: 0.9142
Epoch: 211, Loss: 0.7893, Train: 0.8882, Val: 0.9193, Test: 0.9142
Epoch: 212, Loss: 0.7893, Train: 0.8881, Val: 0.9193, Test: 0.9142
Epoch: 213, Loss: 0.7892, Train: 0.8881, Val: 0.9193, Test: 0.9142
Epoch: 214, Loss: 0.7892, Train: 0.8881, Val: 0.9193, Test: 0.9142
Epoch: 215, Loss: 0.7891, Train: 0.8880, Val: 0.9194, Test: 0.9141
Epoch: 216, Loss: 0.7891, Train: 0.8880, Val: 0.9194, Test: 0.9141
Epoch: 217, Loss: 0.7890, Train: 0.8880, Val: 0.9194, Test: 0.9141
Epoch: 218, Loss: 0.7890, Train: 0.8880, Val: 0.9194, Test: 0.9141
Epoch: 219, Loss: 0.7889, Train: 0.8879, Val: 0.9194, Test: 0.9141
Epoch: 220, Loss: 0.7889, Train: 0.8879, Val: 0.9194, Test: 0.9141
Epoch: 221, Loss: 0.7888, Train: 0.8879, Val: 0.9194, Test: 0.9141
Epoch: 222, Loss: 0.7888, Train: 0.8879, Val: 0.9194, Test: 0.9141
Epoch: 223, Loss: 0.7887, Train: 0.8878, Val: 0.9194, Test: 0.9141
Epoch: 224, Loss: 0.7887, Train: 0.8878, Val: 0.9195, Test: 0.9141
Epoch: 225, Loss: 0.7887, Train: 0.8878, Val: 0.9194, Test: 0.9141
Epoch: 226, Loss: 0.7886, Train: 0.8878, Val: 0.9194, Test: 0.9141
Epoch: 227, Loss: 0.7886, Train: 0.8877, Val: 0.9195, Test: 0.9141
Epoch: 228, Loss: 0.7885, Train: 0.8877, Val: 0.9195, Test: 0.9141
Epoch: 229, Loss: 0.7885, Train: 0.8877, Val: 0.9195, Test: 0.9140
Epoch: 230, Loss: 0.7884, Train: 0.8877, Val: 0.9195, Test: 0.9140
Epoch: 231, Loss: 0.7884, Train: 0.8876, Val: 0.9195, Test: 0.9140
Epoch: 232, Loss: 0.7884, Train: 0.8876, Val: 0.9195, Test: 0.9140
Epoch: 233, Loss: 0.7883, Train: 0.8876, Val: 0.9195, Test: 0.9140
Epoch: 234, Loss: 0.7883, Train: 0.8876, Val: 0.9195, Test: 0.9140
Epoch: 235, Loss: 0.7882, Train: 0.8876, Val: 0.9195, Test: 0.9140
Epoch: 236, Loss: 0.7882, Train: 0.8875, Val: 0.9196, Test: 0.9140
Epoch: 237, Loss: 0.7882, Train: 0.8875, Val: 0.9196, Test: 0.9140
Epoch: 238, Loss: 0.7881, Train: 0.8875, Val: 0.9195, Test: 0.9140
Epoch: 239, Loss: 0.7881, Train: 0.8875, Val: 0.9196, Test: 0.9140
Epoch: 240, Loss: 0.7880, Train: 0.8874, Val: 0.9196, Test: 0.9140
Epoch: 241, Loss: 0.7880, Train: 0.8874, Val: 0.9196, Test: 0.9140
Epoch: 242, Loss: 0.7880, Train: 0.8874, Val: 0.9196, Test: 0.9140
Epoch: 243, Loss: 0.7879, Train: 0.8874, Val: 0.9196, Test: 0.9140
Epoch: 244, Loss: 0.7879, Train: 0.8874, Val: 0.9196, Test: 0.9140
Epoch: 245, Loss: 0.7879, Train: 0.8873, Val: 0.9196, Test: 0.9140
Epoch: 246, Loss: 0.7878, Train: 0.8873, Val: 0.9196, Test: 0.9140
Epoch: 247, Loss: 0.7878, Train: 0.8873, Val: 0.9196, Test: 0.9140
Epoch: 248, Loss: 0.7877, Train: 0.8873, Val: 0.9196, Test: 0.9140
Epoch: 249, Loss: 0.7877, Train: 0.8873, Val: 0.9196, Test: 0.9139
Epoch: 250, Loss: 0.7877, Train: 0.8872, Val: 0.9196, Test: 0.9140
Epoch: 251, Loss: 0.7876, Train: 0.8872, Val: 0.9196, Test: 0.9140
Epoch: 252, Loss: 0.7876, Train: 0.8872, Val: 0.9196, Test: 0.9140
Epoch: 253, Loss: 0.7876, Train: 0.8872, Val: 0.9196, Test: 0.9140
Epoch: 254, Loss: 0.7875, Train: 0.8872, Val: 0.9196, Test: 0.9140
Epoch: 255, Loss: 0.7875, Train: 0.8871, Val: 0.9196, Test: 0.9139
Epoch: 256, Loss: 0.7875, Train: 0.8871, Val: 0.9196, Test: 0.9140
Epoch: 257, Loss: 0.7874, Train: 0.8871, Val: 0.9196, Test: 0.9140
Epoch: 258, Loss: 0.7874, Train: 0.8871, Val: 0.9196, Test: 0.9139
Epoch: 259, Loss: 0.7874, Train: 0.8871, Val: 0.9196, Test: 0.9139
Epoch: 260, Loss: 0.7873, Train: 0.8870, Val: 0.9196, Test: 0.9139
Epoch: 261, Loss: 0.7873, Train: 0.8870, Val: 0.9196, Test: 0.9140
Epoch: 262, Loss: 0.7873, Train: 0.8870, Val: 0.9196, Test: 0.9139
Epoch: 263, Loss: 0.7872, Train: 0.8870, Val: 0.9196, Test: 0.9139
Epoch: 264, Loss: 0.7872, Train: 0.8870, Val: 0.9196, Test: 0.9139
Epoch: 265, Loss: 0.7872, Train: 0.8870, Val: 0.9196, Test: 0.9139
Epoch: 266, Loss: 0.7871, Train: 0.8869, Val: 0.9197, Test: 0.9139
Epoch: 267, Loss: 0.7871, Train: 0.8869, Val: 0.9197, Test: 0.9139
Epoch: 268, Loss: 0.7871, Train: 0.8869, Val: 0.9196, Test: 0.9139
Epoch: 269, Loss: 0.7871, Train: 0.8869, Val: 0.9196, Test: 0.9139
Epoch: 270, Loss: 0.7870, Train: 0.8869, Val: 0.9197, Test: 0.9139
Epoch: 271, Loss: 0.7870, Train: 0.8869, Val: 0.9197, Test: 0.9139
Epoch: 272, Loss: 0.7870, Train: 0.8868, Val: 0.9197, Test: 0.9139
Epoch: 273, Loss: 0.7869, Train: 0.8868, Val: 0.9197, Test: 0.9139
Epoch: 274, Loss: 0.7869, Train: 0.8868, Val: 0.9197, Test: 0.9139
Epoch: 275, Loss: 0.7869, Train: 0.8868, Val: 0.9197, Test: 0.9139
Epoch: 276, Loss: 0.7868, Train: 0.8868, Val: 0.9197, Test: 0.9139
Epoch: 277, Loss: 0.7868, Train: 0.8868, Val: 0.9197, Test: 0.9139
Epoch: 278, Loss: 0.7868, Train: 0.8867, Val: 0.9197, Test: 0.9139
Epoch: 279, Loss: 0.7867, Train: 0.8867, Val: 0.9197, Test: 0.9139
Epoch: 280, Loss: 0.7867, Train: 0.8867, Val: 0.9197, Test: 0.9139
Epoch: 281, Loss: 0.7867, Train: 0.8867, Val: 0.9197, Test: 0.9139
Epoch: 282, Loss: 0.7867, Train: 0.8867, Val: 0.9197, Test: 0.9139
Epoch: 283, Loss: 0.7866, Train: 0.8867, Val: 0.9197, Test: 0.9139
Epoch: 284, Loss: 0.7866, Train: 0.8866, Val: 0.9196, Test: 0.9139
Epoch: 285, Loss: 0.7866, Train: 0.8866, Val: 0.9196, Test: 0.9139
Epoch: 286, Loss: 0.7865, Train: 0.8866, Val: 0.9197, Test: 0.9139
Epoch: 287, Loss: 0.7865, Train: 0.8866, Val: 0.9196, Test: 0.9138
Epoch: 288, Loss: 0.7865, Train: 0.8866, Val: 0.9197, Test: 0.9139
Epoch: 289, Loss: 0.7865, Train: 0.8866, Val: 0.9196, Test: 0.9139
Epoch: 290, Loss: 0.7864, Train: 0.8865, Val: 0.9196, Test: 0.9138
Epoch: 291, Loss: 0.7864, Train: 0.8865, Val: 0.9196, Test: 0.9138
Epoch: 292, Loss: 0.7864, Train: 0.8865, Val: 0.9196, Test: 0.9138
Epoch: 293, Loss: 0.7863, Train: 0.8865, Val: 0.9196, Test: 0.9138
Epoch: 294, Loss: 0.7863, Train: 0.8865, Val: 0.9196, Test: 0.9138
Epoch: 295, Loss: 0.7863, Train: 0.8865, Val: 0.9196, Test: 0.9138
Epoch: 296, Loss: 0.7863, Train: 0.8864, Val: 0.9196, Test: 0.9138
Epoch: 297, Loss: 0.7862, Train: 0.8864, Val: 0.9196, Test: 0.9138
Epoch: 298, Loss: 0.7862, Train: 0.8864, Val: 0.9196, Test: 0.9138
Epoch: 299, Loss: 0.7862, Train: 0.8864, Val: 0.9196, Test: 0.9138
Epoch: 300, Loss: 0.7861, Train: 0.8864, Val: 0.9195, Test: 0.9138