图神经网络实战——动态图:时空图神经网络(Spatiotemporal GNNs)

1,917 阅读1小时+

本章内容

  • 在深度学习模型中引入记忆
  • 理解使用图神经网络建模时间关系的不同方式
  • 实现动态图神经网络
  • 评估时序图神经网络模型

到目前为止,我们的所有模型和数据都只是时间中的单一快照。实际上,世界是动态的,处于不断变化之中。物体可以在我们的眼前物理移动,沿着一条轨迹行进,我们能够根据这些观察到的轨迹预测它们的未来位置。交通流、天气模式以及疾病在人群网络中的传播等都是可以通过时空图来建模,从而获得更多信息的例子,而不是使用静态图。

我们今天构建的模型,在实际应用中可能很快失去性能和准确性。这是任何深度学习(和机器学习)模型的固有问题,被称为分布外(OOD)泛化问题,即模型在面对完全未见过的数据时的泛化能力。

在本章中,我们将考虑如何构建适合动态事件的模型。虽然这并不意味着它们能够处理OOD数据,但我们的动态模型将能够利用近期的过去对未来的未见事件进行预测。

为了构建我们的动态图学习模型,我们将考虑姿势估计问题。姿势估计涉及那些预测身体(人类、动物或机器人)随时间变化的类问题。在本章中,我们将考虑一个行走的身体,并构建多个模型,从一系列视频帧中学习如何预测下一步。为了做到这一点,我们将首先更详细地解释问题,并讲解如何将其理解为一个关系问题,然后再深入探讨图学习方法如何处理这个问题。与本书的其他部分一样,进一步的技术细节将在本章末尾的第6.5节中讲解。

我们将使用本书中已经覆盖的大部分内容。如果你跳到了这一章,确保你对“基于已学内容构建”边栏中描述的概念有充分的理解。

:本章的代码可以在GitHub仓库中以笔记本形式找到(mng.bz/4a8D)。

基于已学内容构建

为了在我们的图神经网络(GNN)中引入时间更新,我们可以基于前几章中学到的一些概念进行扩展。为了快速回顾,我们总结了每章中的一些主要重要特性:

  • 消息传递—在第2章中,你学到的GNN学习关系数据的主要方法是通过将消息传递与人工神经网络结合使用。GNN的每一层可以理解为一次消息传递的步骤。
  • 图卷积网络(GCNs) —在第3章中,你看到消息传递本身可以理解为卷积操作的关系形式(就像卷积神经网络[CNNs]中的卷积操作一样),这也是GCNs的核心思想。消息还可以通过只采样最近邻的子集在邻域之间进行平均,这就是GraphSAGE的使用方式,可以显著减少计算量。
  • 注意力机制—在第4章中,我们展示了消息传递的聚合函数不需要仅限于求和、平均或最大值操作(尽管该操作必须是置换不变的)。注意力机制允许在训练过程中学习加权,从而提供更灵活的消息传递聚合函数。使用图注意力网络(GAT)是向消息传递中添加注意力的基本形式。
  • 生成模型—而判别模型旨在学习数据类别之间的分隔,生成模型则试图学习数据生成过程的潜在机制。自动编码器是设计生成模型的最流行框架之一,其中数据通过神经网络瓶颈进行传递,创建数据的低维表示,也称为潜在空间。这些通常实现为图自动编码器(GAEs)或变分图自动编码器(VGAEs),正如我们在第5章中讨论的那样。

6.1 时间模型:通过时间建立关系

几乎每个数据问题在某种程度上也是一个动态问题。在许多情况下,我们可以忽略时间变化,构建适合我们收集到的数据快照的模型。例如,图像分割方法通常不会考虑使用视频镜头来训练模型。

在第3章中,我们使用图卷积网络(GCN)来预测适合推荐给客户的产品,利用客户-购买者网络中的数据。我们使用了一个在几年的时间里收集的玩具数据集。然而,实际上,我们通常会有持续的数据流,并希望做出最新的预测,以考虑客户和文化习惯的变化。类似地,当我们将图注意力网络(GAT)应用于欺诈检测问题时,我们使用的数据是一个在几年的时间里收集的单一快照的财务记录。然而,我们在模型中并未考虑财务行为如何随时间变化。再一次,我们很可能希望利用这些信息来预测个人消费行为的突然变化,帮助我们检测欺诈活动。

这些只是我们每天面临的许多不同动态问题中的一小部分(见图6.1)。图神经网络(GNNs)独特之处在于它们可以同时建模动态变化和关系变化。这一点非常重要,因为我们周围运行的许多网络也在随时间变化。例如,社交网络。我们的友谊随着时间的推移发生变化、成熟,甚至(不幸或幸运地!)逐渐减弱。我们可能与同事或朋友的朋友成为更亲密的朋友,而与家乡的朋友见面的频率则减少。为社交网络做预测时需要考虑到这一点。

另一个例子是,我们常常根据对道路、交通模式以及自己有多匆忙的了解来预测应该走哪条路,何时能到达。动态图神经网络也可以用来利用这些数据,将道路网络视为图,并对该网络如何变化进行时间预测。最后,我们可以考虑预测两个或更多物体如何一起移动,也就是估计它们的未来轨迹。虽然这看起来不如交朋友或准时上班那么有用,但预测相互作用物体的轨迹,例如分子、细胞、物体甚至是星星,对于许多科学领域以及机器人规划至关重要。同样,动态图神经网络可以帮助我们预测这些轨迹,并推断出解释这些轨迹的新方程或规则。

image.png

这些例子仅仅是我们需要建模时间变化应用的冰山一角。事实上,我们相信你能够想到更多的例子。鉴于将关系学习与时间学习相结合的重要性,我们将介绍三种不同的构建动态模型的方法,其中两种使用GNN:递归神经网络(RNN)模型、图注意力网络(GAT)模型和神经关系推理(NRI)模型。我们将构建“学习走路”的机器学习模型,通过估计人体姿势随时间的变化来实现。这些模型常常应用于医疗咨询、远程家居安防服务、电影制作等领域。这些模型也是一个很好的玩具问题,可以让我们在“学会走路”之前先掌握基础。因此,首先让我们了解一下数据,并构建我们的第一个基准模型。

6.2 问题定义:姿势估计

在本章中,我们将通过一组数据解决一个“动态关系”问题:行走人体的预处理分割数据。这个数据集非常适合探索这些技术,因为一个移动的身体是一个典型的相互作用系统的例子:我们的脚因膝盖的运动而移动,膝盖因腿部的运动而移动,手臂和躯干也会一起移动。这意味着我们的模型问题中包含了时间因素。

简而言之,我们的姿势估计问题是关于路径预测的。更准确地说,我们希望知道,例如,脚会在经历一段时间的身体运动后,移动到哪里。这种物体追踪是我们每天都会做的事情,例如,当我们做运动、接住掉落的物体,或是观看电视节目时。我们在小时候就学会了这一技能,且常常理所当然。然而,正如你将看到的,在时空图神经网络出现之前,教机器执行这种物体追踪一直是一个重大挑战。

我们将在路径预测中使用的技能对于许多其他任务也非常重要。预测未来事件是有用的,尤其当我们想预测客户的下一个购买,或是根据地理空间数据了解天气模式将如何变化时。

我们将使用卡内基梅隆大学(CMU)的动作捕捉数据库(mocap.cs.cmu.edu/),该数据库包含了许多不同动态姿势的例子,包括行走、跑步、跳跃和运动动作的执行,以及多个人互动的情况[1]。在本章中,我们将使用该数据集中第35号受试者行走的同一数据集。在每个时间步,受试者有41个传感器,每个传感器跟踪一个关节,从脚趾到颈部。图6.2展示了来自这个数据库的数据示例。这些传感器跟踪身体一部分的运动快照。在本章中,我们不会跟踪整个运动,而只考虑运动的一小部分。我们将使用前49帧作为训练和验证数据集,99帧作为测试集。总共有31个不同的受试者行走的例子。我们将在下一节中讨论数据结构的更多细节。

image.png

6.2.1 问题设置

我们的目标是预测所有独立关节的动态。显然,我们可以将其构建为一个图,因为所有的关节通过边连接,如之前在图6.2中所示。因此,使用GNN来解决这个问题是有意义的。然而,我们首先将对比另一种方法,该方法不考虑图数据,用于基准测试我们的GNN模型。

下载数据

我们已经在代码库中提供了下载和预处理数据的步骤。数据包含在一个压缩文件中,其中每个不同的试验保存为高级系统格式(.asf)文件。这些.asf文件基本上只是包含每个传感器标签及其在每个时间步的xyz坐标的文本文件。在下面的代码示例中,我们展示了一个文本片段。

代码示例 6.1 示例传感器数据文本文件

   1
   root 4.40047 17.8934 -21.0986 -0.943965 -8.37963 -7.42612
   lowerback 11.505 1.60479 4.40928
   upperback 0.47251 2.84449 2.26157
   thorax -5.8636 1.30424 -0.569129
   lowerneck -15.9456 -3.55911 -2.36067
   upperneck 19.9076 -4.57025 1.03589

在这里,第一个数字是帧编号,root是特定于传感器的,可以忽略。lowerbackupperbackthoraxlowerneckupperneck 表示传感器的位置。总共有31个传感器映射了一个人行走的运动。为了将这些传感器数据转换为轨迹,我们需要计算每个传感器的位置变化。这是一项相当复杂的任务,因为我们需要考虑每帧之间的平移运动和角度旋转。这里,我们将使用与NRI论文[2]中相同的数据文件。我们可以使用这些文件绘制每个传感器在x、y和z轴的轨迹,或者查看传感器在二维空间中的运动,以帮助我们理解整个身体的运动。图6.3展示了这些例子,我们专注于一个脚部传感器在x、y和z轴上的运动,以及身体随时间的整体运动(传感器以实心黑星标出)。

image.png

除了空间数据,我们还可以计算速度数据。这些数据作为单独的文件提供,每个电影帧都有一个速度数据文件。图6.4展示了速度数据变化的示例。如你所见,速度数据的变化范围较小。空间数据和速度数据将作为我们机器学习问题中的特征。在这里,我们现在有六个特征,涵盖了50帧中的每一个传感器,并且跨越了33个不同的试验。我们可以将其理解为一个多变量时间序列问题。我们试图预测一个六维(三个空间维度和三个速度维度)物体(每个传感器)的未来演变。我们的第一种方法将这些特征视为独立的,试图基于过去的传感器数据预测未来的位置和速度。然后,我们将转而将其视为一个图,在这个图中,我们可以将所有传感器连接在一起。

image.png

目前,这是一个关系问题,但我们只考虑了节点数据,而没有考虑边数据。在没有边数据的情况下,必须小心不要做出过多的假设。例如,如果我们选择基于节点之间的距离来连接节点,那么最终可能会得到一个看起来非常奇怪的骨架,如图6.5所示。幸运的是,我们也有边数据,这些数据是使用CMU数据集构建的,并包含在提供的数据中。这提醒我们,GNN的强大程度取决于它们所训练的图结构,我们必须小心确保图结构的正确性。然而,如果完全缺乏边数据,我们可以尝试从节点数据本身推断出边数据。虽然我们在这里不会这么做,但需要注意,我们将使用的NRI模型具备这种能力。

image.png

现在我们已经加载了所有数据。总共,我们有三个数据集(训练、验证、测试),每个数据集包含31个单独的传感器位置。每个传感器包含六个特征(空间坐标),并通过一个在时间上保持不变的邻接矩阵连接。传感器图是无向图,边是无权重的。训练集和验证集包含49帧,而测试集包含99帧。

6.2.2 带有记忆的模型构建

现在问题已经定义,数据也已加载,我们可以考虑如何解决预测关节动态的问题。首先,我们需要思考其核心目标是什么。从本质上讲,我们将涉及序列预测,就像手机的自动补全或搜索工具一样。这类问题通常使用网络模型来处理,如Transformer,在第四章中我们介绍了其使用的注意力机制。然而,在注意力网络出现之前,许多深度学习从业者通过向模型中引入记忆来处理序列预测任务[3]。这具有直观的意义:如果我们想要预测未来,我们需要记住过去。

让我们构建一个简单的模型,使用过去的事件预测所有个体传感器的下一个位置。本质上,这意味着我们将构建一个预测节点位置的模型,而没有边数据。我们将尝试的示例如图6.6所示。在这里,我们将首先对数据进行预处理,并准备将数据传递给一个能够预测数据随时间演变的模型。这使我们能够在给定一些输入帧的情况下,预测姿态的变化。

image.png

为了将记忆引入到我们的神经网络中,我们将首先考虑递归神经网络(RNN)。与卷积神经网络和注意力神经网络类似,RNN 是一个广泛的架构类别,是研究人员和实践者的基础工具。有关 RNN 的更多信息,请参见《Machine Learning with TensorFlow》(Manning, 2020, mng.bz/VVOW)。RNN可以被视为多个独立的网络相互连接。这些重复的子网络允许将过去的信息“记住”,并使过去数据的影响影响到未来的预测。在初始化之后,每个子网络接受输入数据以及上一个子网络的输出,这些信息用来进行新的预测。换句话说,每个子网络使用来自最近过去的输入和信息来推断数据。然而,原始的 RNN 只能记住前一个步骤。它们具有非常短期的记忆。为了增强过去对未来的影响,我们需要更强的记忆机制。

长短期记忆(LSTM)网络是另一种非常流行的神经网络架构,用于建模和预测时间序列或顺序信息。这些网络是 RNN 的特例,同样将多个子网络连接在一起。不同之处在于,LSTM 在子网络结构中引入了更复杂的依赖关系。LSTM 特别适用于顺序数据,因为它们解决了 RNN 中观察到的梯度消失问题。简单来说,梯度消失是指我们用来训练神经网络的梯度在梯度下降过程中趋近于零。尤其当我们训练一个具有许多层的 RNN 时,这种情况尤为常见。(我们在此不深入探讨原因,但如果你感兴趣,可以阅读《Deep Learning with Python》(Manning, 2024, mng.bz/xKag)获取更多信息。)

门控循环单元(GRU)网络通过允许将新的信息添加到关于最近过去的记忆存储中,解决了梯度消失的问题。这是通过一个门控结构实现的,其中模型架构中的门控帮助控制信息流动。这些门控还为我们构建和调整神经网络的方式增加了新的设计元素。在这里我们不考虑 LSTM,因为它超出了本书的范围,但我们再次建议你查阅《Deep Learning with Python》(Manning, 2024, mng.bz/xKag)以获取更多信息。

构建递归神经网络

现在让我们来看一下如何使用 RNN 来预测身体传感器随时间变化的轨迹,这将作为我们未来性能提升的基准之一。我们不会深入探讨 RNN 和 GRU 架构的细节,但在本章的 6.5 节中提供了更多信息。

这个模型的想法是,我们的 RNN 将预测传感器的未来位置,而不考虑关系数据。当我们开始引入图模型时,我们将看到如何改进这一点。

我们将使用深度学习的标准训练循环,如图 6.7 所示。一旦我们定义了模型并定义了训练和测试循环,就可以使用这些来训练并测试模型。和往常一样,我们将保持训练数据和测试数据完全分开,并包括一个验证集来确保我们的模型在训练过程中不会过拟合。

image.png

这里使用的训练循环是相当标准的,所以我们首先对其进行描述。在列表 6.2 中显示的训练循环定义中,我们遵循了与前几章相同的约定,在固定的周期数内循环进行模型预测和损失更新。在这里,我们的损失将包含在我们的标准函数中,我们将其定义为简单的均方误差(MSE)损失。我们将使用学习率调度器,在验证损失开始平稳时降低学习率参数。我们初始化最佳损失为无穷大,当验证损失小于最佳损失时,我们将在 N 步之后降低学习率。

列表 6.2 训练循环

num_epochs = 200  
train_losses = []
valid_losses = []

pbar = tqdm(range(num_epochs))

for epoch in pbar:

    train_loss = 0.0  #1
    valid_loss = 0.0  #1

    modelRNN.train()  #2
    for i, (inputs, labels) in enumerate(trainloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

    optimizer.zero_grad()  #3

    outputs = modelRNN(inputs)  #4
    loss = criterion(outputs, labels)  #4
    loss.backward()  #4
    optimizer.step()  #4

    train_loss += loss.item() * inputs.size(0)  #5

    modelRNN.eval()   #6
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(validloader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = modelRNN(inputs) 
            loss = criterion(outputs, labels)
            valid_loss += loss.item() * inputs.size(0)

    if valid_loss < best_loss:  #7
        best_loss = valid_loss
        counter = 0
    else:
        counter += 1

    scheduler.step(best_loss)  #8

    if counter == early_stop:
        print(f"\n\nEarly stopping initiated, no change after {early_stop} steps")
        break

    train_loss = train_loss / len(trainloader.dataset)  #9
    valid_loss = valid_loss / len(validloader.dataset)  #9

    train_losses.append(train_loss)  #9
    valid_losses.append(valid_loss)  #9

注释:

  • #1 初始化损失和准确度变量
  • #2 开始训练循环
  • #3 将参数梯度归零
  • #4 前向 + 反向传播 + 优化
  • #5 更新训练损失,乘以当前小批量的样本数量
  • #6 开始验证循环
  • #7 检查是否提前停止
  • #8 更新学习率调度器
  • #9 计算并存储损失

两层网络(使用列表 6.3 中的训练循环)被训练用于特定任务。对于 RNN 和 GRU,数据的格式将是单独的实验或视频、帧时间戳、传感器数量和传感器的特征。通过提供按时间分解的数据,模型能够利用时间序列特征进行学习。在这里,我们使用 RNN 来预测每个传感器的未来位置,给定前 40 帧。对于所有计算,我们将使用最小-最大缩放对基于节点特征(位置和速度)进行数据归一化。

完成训练循环后,我们测试我们的网络。和往常一样,我们不想更新网络的参数,因此确保没有反向传播的梯度(通过选择 torch.no_grad())。注意,我们选择 40 的序列长度,以便测试循环能够看到前 40 帧,然后尝试推断最后 10 帧。

列表 6.3 测试循环

model.eval()   #1
predictions = [] 
test_losses = [] 
seq_len = 40 

with torch.no_grad():
    for i, (inputs, targets) in enumerate(testloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        preds = []
        for _ in range(seq_len):
            output = model(inputs)
            preds.append(output)

        inputs = torch.cat([inputs[:, 1:], output.unsqueeze(1)], dim=1)  #2

    preds = torch.cat(preds, dim=1)  #3
    loss = criterion(preds, targets)  #3
    test_losses.append(loss.item())  #3

    predictions.append(preds.detach().cpu().numpy())

predictions = np.concatenate(predictions, axis=0)  #4
test_loss = np.mean(test_losses)  #5

注释:

  • #1 设置模型为评估模式
  • #2 更新下一个预测的输入
  • #3 计算该序列的损失
  • #4 将预测转换为 NumPy 数组以便于操作
  • #5 计算平均测试损失

一旦我们的模型定义完成,我们就可以使用列表 6.3 中给出的训练循环来训练模型。在这一点上,你可能会想知道我们将如何修改训练循环以正确处理回传时的时间元素。好消息是,这由 PyTorch 自动处理。我们发现 RNN 模型能够以 70% 的准确率预测验证数据的未来位置,而测试数据的准确率为 60%。

我们还尝试了 GRU 模型来预测未来的步伐,发现这个模型能够在验证数据上获得 75% 的准确率。这虽然较低,但考虑到模型的简单性和我们提供给它的信息量,它的表现还算不错。然而,当我们在测试数据上评估模型性能时,结果显示准确率降到了 65%。我们在图 6.8 中展示了一些模型输出的示例。显然,模型很快退化,估计的姿势位置开始大幅波动。为了更好的准确性,我们需要在姿势数据中使用一些关系归纳偏置。

image.png

6.3 动态图神经网络

为了预测图的未来演变,我们需要重新组织数据,以考虑时间序列数据。具体来说,动态图神经网络(Dynamic GNNs)连接图演化的不同时间快照,并学习预测未来的演变[4–6]。一种方法是将它们合并成一个单一的图。这个时间图现在包含了每个时间步的数据,以及作为节点的时间边连接的时间关系。我们将首先通过一种简单的方法来处理姿态估计任务,考虑如何将我们的时间数据组合成一个大的图,并通过遮蔽感兴趣的节点来预测未来的演变。我们将使用第三章中介绍的相同GAT网络。然后,在第6.4节中,我们将展示通过编码每个时间快照的图,并使用变分自编码器(VAE)和RNN的组合来预测演变的另一种解决姿态估计问题的方法,即NRI方法[2]。

6.3.1 用于动态图的图注意力网络(GAT)

我们将探讨如何将我们的姿态估计问题转化为图形问题。为此,我们需要构建一个考虑时间信息的邻接矩阵。首先,我们需要将数据加载为PyTorch Geometric(PyG)数据对象。我们将使用与训练RNN时相同的位置信息和速度数据。不同之处在于,我们将构建一个包含所有数据的单一图。列表6.4中的代码片段展示了我们如何初始化数据集。我们传递了位置和速度数据的路径,以及边缘数据的位置。我们还传递了是否需要转换数据的标志,以及我们将要预测的掩码和窗口大小。

列表6.4 加载数据作为图形

class PoseDataset(Dataset):
    def __init__(self, loc_path, 
                      vel_path, 
                      edge_path, 
                      mask_path, 
                      mask_size, 
                      transform=True):
        self.locations = np.load(loc_path)  #1
        self.velocities = np.load(vel_path)  #1
        self.edges = np.load(edge_path)

        self.transform = transform
        self.mask_size = mask_size  #2
        self.window_size = self.locations.shape[1] - self.mask_size  #3
  • #1 从.npy文件加载数据
  • #2 确定掩码大小
  • #3 确定窗口大小

对于所有的数据集对象,我们需要在类中定义一个get方法,来描述如何获取这些数据,这在列表6.5中展示。该方法将位置和速度数据组合成节点特征。我们还提供了一个选项来使用normalize_array函数转换数据。

列表6.5 设置节点特征使用位置和速度数据

def __getitem__(self, idx):
    nodes = np.concatenate((self.locations[idx], self.velocities[idx]), axis=2)  #1
    nodes = nodes.reshape(-1, nodes.shape[-1])  #2

    if self.transform:  #3
        nodes, node_min, node_max = normalize_array(nodes) 

    total_timesteps = self.window_size + self.mask_size  #4
    edge_index = np.repeat(self.edges[None, :], total_timesteps, axis=0) 

    N_dims = self.locations.shape[2]
    shift = np.arange(total_timesteps)[:, None, None] * N_dims  #5
    edge_index += shift
    edge_index = edge_index.reshape(2, -1)   #6

    x = torch.tensor(nodes, dtype=torch.float)  #7
    edge_index = torch.tensor(edge_index, dtype=torch.long) 
    mask_indices = np.arange(self.window_size * self.locations.shape[2], total_timesteps * self.locations.shape[2])  #8
    mask_indices = torch.tensor(mask_indices, dtype=torch.long)

    if self.transform:
        trnsfm_data = [node_min, node_max]
        return Data(x=x, edge_index=edge_index, mask_indices=mask_indices, trnsfm=trnsfm_data)
    
    return Data(x=x, edge_index=edge_index, mask_indices=mask_indices)
  • #1 将位置和速度数据连接成每个节点的特征
  • #2 确定掩码大小
  • #3 如果transform为True,则应用归一化
  • #4 对每个时间步,重复边数据以涵盖所有时间步(过去+未来)
  • #5 对边缘索引应用位移
  • #6 将边缘索引展平为二维
  • #7 将所有数据转换为PyTorch张量
  • #8 计算掩码节点的索引

接下来,我们希望将不同时间步的所有节点组合成一个大的图,包含所有单独的帧。这将产生一个包含所有不同时间步的邻接矩阵。(有关时间邻接矩阵的进一步细节,请参见本章最后的6.5节。)为了处理我们的姿态估计数据,我们首先为每个时间步构建邻接矩阵,如列表6.6所示,并包括在列表6.5中。

如图6.9所示,过程开始时通过多时间步表示图数据,其中每个时间步被视为一个独立的层(步骤1)。所有节点都有节点特征数据(图中未显示)。对于我们的应用,节点特征数据包括位置和速度信息。

在一个时间步内的节点通过时间步内的边连接,即同一时间步层中的节点之间的连接(步骤2)。这些边确保在特定时间步的每个图是内部一致的。此时节点之间尚未跨时间步连接。

为了纳入时间关系,添加了跨时间步的边(即,连接不同时间步层的节点)(步骤3)。这些边允许在不同时间步之间流动信息,使得图数据能够进行时间建模。

为了预测未来的值,我们将最后一个时间步的节点进行掩蔽,表示未知数据(步骤4)。这些被掩蔽的节点被视为预测任务的目标。它们的值是未知的,但可以通过利用前面时间步中未掩蔽节点的特征和关系来推断。

推理过程(步骤5)使用前一个时间步(t = 0 和 t = 1)中未掩蔽节点的已知特征来预测t = 2中被掩蔽节点的特征。虚线箭头显示了如何从未掩蔽的节点流向掩蔽节点,表明预测依赖于先前图数据的关系。这将任务转化为一个节点预测问题,目标是基于未掩蔽节点的关系和特征,估计掩蔽节点的特征。

image.png

代码清单 6.6 构建邻接矩阵

total_timesteps = self.\
    window_size + self.mask_size  #1
edge_index = np.repeat(self.edges[None, :],\
    total_timesteps, axis=0)

shift = np.arange(total_timesteps)[:, None, \
    None] * num_nodes_per_timestep  #2
edge_index += shift  #3
edge_index = edge_index.reshape(2, -1)  #4
  • #1 将边重复,直到达到总的时间步数(过去 + 未来)
  • #2 为每个时间步创建一个偏移量
  • #3 将偏移量应用到边的索引
  • #4 将边的索引展平为二维

现在我们已经有了邻接矩阵,接下来的步骤是构建一个可以预测未来时间步的模型。这里,我们将使用在第 4 章中介绍的 GAT 模型 [7]。我们选择这个 GNN,因为它比其他 GNN 更具表达力,且能够考虑不同的时间和空间信息。模型架构如代码清单 6.7 所示。

代码清单 6.7 定义 GAT 模型

class GAT(torch.nn.Module):
    def __init__(self, n_feat,
                 hidden_size=32,
                 num_layers=3,
                 num_heads=1,
                 dropout=0.2,
                 mask_size=10):
        super(GAT, self).__init__()

        self.num_layers = num_layers
        self.heads = num_heads
        self.n_feat = n_feat
        self.hidden_size = hidden_size
        self.gat_layers = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        self.dropout = nn.Dropout(dropout)
        self.mask_size = mask_size

        gat_layer = GATv2Conv(self.n_feat, self.hidden_size, heads=num_heads)  #1
        self.gat_layers.append(gat_layer)  #1
        middle_size = self.hidden_size * num_heads
        batch_layer = nn.BatchNorm1d(num_features=middle_size)  #2
        self.batch_norms.append(batch_layer)  #2

        for _ in range(num_layers - 2):  #3
            gat_layer = GATv2Conv(input_size, self.hidden_size, heads=num_heads)
            self.gat_layers.append(gat_layer)
            batch_layer = nn.BatchNorm1d(num_features=middle_size)  #4
            self.batch_norms.append(batch_layer)

        gat_layer = GATv2Conv(middle_size, self.n_feat)
        self.gat_layers.append(gat_layer)  #5

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        for i in range(self.num_layers):
            x = self.gat_layers[i](x, edge_index)
            if i < self.num_layers - 1:  #6
                x = self.batch_norms[i](x)  #6
                x = torch.relu(x)  #6
                x = self.dropout(x)  #6

        n_nodes = edge_index.max().item() + 1  #7
        x = x.view(-1, n_nodes, self.n_feat)
        return x[-self.mask_size:].view(-1, self.n_feat)
  • #1 第一个 GAT 层
  • #2 第一个 GAT 层的 BatchNorm 层
  • #3 中间 GAT 层
  • #4 中间 GAT 层的 BatchNorm 层
  • #5 最后一个 GAT 层
  • #6 不对最后一个 GAT 层的输出应用 BatchNorm 和 Dropout
  • #7 只输出最后一帧

这个模型遵循了第 4 章中概述的基本结构。我们为模型定义了层数和头数,以及输入大小,这取决于我们要预测的特征数量。每个 GAT 层都有一个隐藏大小,我们加入了 dropout 和 batch normalization 来提高性能。然后我们遍历模型中的层数,确保维度匹配以符合目标输出。我们还定义了前向函数,用于预测被遮挡节点的节点特征。通过将每个时间步展开成一个更大的图,我们开始引入时间效应,作为模型可以学习的额外网络结构。

现在,我们已经定义了模型和数据集,接下来让我们开始训练模型,看看它的表现如何。回顾一下,RNN 和 GRU 分别在测试集上的准确率为 60% 和 65%。在代码清单 6.8 中,我们展示了 GAT 模型的训练循环。这个训练循环遵循了前几章使用的结构。我们使用 MSE 作为损失函数,并将学习率设置为 0.0005。我们使用 GAT 计算被遮挡节点的节点特征,然后将这些特征与存储在数据中的真实数据进行比较。首先训练模型,然后使用验证集比较模型预测的结果。需要注意的是,由于我们现在要预测多个图序列,这个训练循环比以前的模型需要更多时间。在 Google Colab 的 V100 GPU 上,训练时间不到一小时。

代码清单 6.8 GAT 训练循环

lr = 0.001
criterion = torch.nn.MSELoss()                            #1
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

for epoch in tqdm(range(epochs), ncols=300):
    model.train()
    train_loss = 0.0
    for data in train_dataset:
        optimizer.zero_grad()
        out = model(data)  #2

        loss = criterion(out, \
        data.y.reshape(out.shape[0], -1))  #3
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    model.eval()  #4
    val_loss = 0.0  #4
    with torch.no_grad():  #4
        for val_data in val_dataset:  #4
            val_out = model(val_data)  #5
            val_loss += criterion(out, \
            data.y.reshape(out.shape[0], \
            -1)).item()  #6

    val_loss /= len(val_dataset)
    train_loss /= len(train_dataset)
  • #1 使用学习率初始化损失和优化器
  • #2 为输入生成模型的预测
  • #3 计算输出和目标之间的损失
  • #4 验证循环
  • #5 为输入生成模型的预测
  • #6 计算输出和目标之间的损失

最后,我们使用测试集和代码清单 6.9 中显示的代码来测试训练后的模型。

代码清单 6.9 GAT 测试循环

test_loss = 0
for test_data in test_dataset:
    test_out = model(test_data)  #1
    test_loss += criterion(out, \
    data.y.reshape(out.shape[0], -1)).item()  #2
  • #1 为输入生成模型的预测
  • #2 计算输出和目标之间的损失

我们发现这种简单的方法无法预测姿势。我们的总体测试准确率为 55%,且预测的图形与我们预期的姿势外观有很大不同。这是因为我们现在将大量数据存储在一个图中。我们将节点特征和时间数据压缩到一个图中,并且在定义模型时没有强调时间属性。可以通过使用时间编码提取未使用的边数据来改进这一点,例如在时间 GAT (TGAT) 模型中。TGAT 将边视为动态的,而不是静态的,因此每条边还会编码时间戳。

然而,在没有这些时间数据的情况下,我们的模型变得过于表达性,导致姿势的整体结构与原始结构产生了显著偏离,如图 6.10 中所示。接下来,我们将研究如何将两种方法的优点结合起来,构建一个使用 RNN 预测的 GNN,学习每个图快照。

image.png

6.4 神经关系推理(NRI)

我们的RNN专注于时间序列数据,但忽略了潜在的关系数据。这导致模型能够在平均意义上朝正确的方向移动,但并没有很好地改变各个传感器的位置。另一方面,我们的GAT模型忽略了时间数据,将所有单独的时间图编码为一个单一的图,并尝试对未知的未来图进行节点预测。该模型导致传感器的移动剧烈,最终的图像与我们预期的人工移动方式相差甚远。

如前所述,神经关系推理(NRI)是一种略有不同的方法,它使用更复杂的编码框架,将RNN和GNN的优点结合在一起[2]。该模型的架构如图6.11所示。具体而言,NRI使用自编码器结构来嵌入每个时间步的信息。因此,嵌入架构类似于我们在第5章中讨论的GAE,应用于整个图形。然后,使用RNN更新这些编码的图形数据。一个关键点是,NRI演化了嵌入的潜在表示。

image.png

让我们探讨一下这个模型如何应用于我们的人体姿态估计问题,以便更好地理解模型中的不同组件。我们将使用与训练过程中遮蔽一些数据并在测试时识别这些遮蔽节点相同的格式。回想一下,这相当于推断我们视频中的未来帧。然而,现在我们需要改变模型架构和损失函数。我们需要改变模型架构,以适应新的自编码器结构,并且需要调整损失函数,以包括最小化重构损失和Kullback-Leibler散度(KL散度)。有关NRI模型和相关更改的更多信息,请参见本章末尾的第6.5节。

NRI模型的基类代码见清单6.10。如代码中所示,在调用此类时,我们需要定义编码器和解码器。除了编码器和解码器外,还有一些其他模型特定的细节需要注意。首先,我们需要定义变量的数量。这与我们图中的节点数量有关,而不是每个节点的特征数量。在我们的案例中,这将是31,代表每个不同的传感器追踪一个关节的位置。我们还需要定义节点之间的不同边类型。这将是1或0,表示是否存在边。

我们假设节点或传感器之间的连接方式是静态的,也就是说图的结构不变。请注意,这个模型也支持动态图,其中连接性会随时间变化,例如,当不同的球员在篮球场上移动时。球员总数是固定的,但可以传球给的球员数是变化的。实际上,这个模型也被用来预测不同球员如何传球,使用了NBA的录像。

最后,这个模型需要设置一些超参数,包括Gumbel温度和先验方差。Gumbel温度控制在进行离散采样时探索和开发之间的权衡。在这里,我们需要使用离散概率分布来预测边的类型。我们在第6.5节中将详细讨论这一点。先验方差反映了在开始之前,我们对图的连接性有多大的不确定性。我们需要设置它,因为模型假设我们并不清楚连接性。实际上,模型学习出最有助于提高预测的连接性。这正是我们在调用_initialize_log_prior函数时设置的内容。我们告诉模型我们对可能的连接模式的最佳猜测。例如,如果我们将这个模型应用于一支运动队,我们可以使用一个高均值的高斯分布来表示球员之间频繁传球的边,或者是同一队球员之间的边。

为了演示我们的模型,我们将假设一个均匀的先验,这意味着所有边的概率是一样的,或者用日常话来说就是“我们不知道”。先验方差为每个边设置了我们的不确定性范围。在以下代码中,我们将其设置为5 × 10^–5,以确保数值稳定性,但由于我们的先验是均匀的,它不应该有太大影响。

清单6.10 NRI模型的基类

class BaseNRI(nn.Module):
    def __init__(self, num_vars, encoder, decoder,
            num_edge_types=2,
            gumbel_temp=0.5, 
            prior_variance=5e-5):
        super(BaseNRI, self).__init__()
        self.num_vars = num_vars  #1
        self.encoder = encoder  #2
        self.decoder = decoder  #3
        self.num_edge_types = num_edge_types 
        self.gumbel_temp = gumbel_temp  #4
        self.prior_variance = prior_variance  #5

        self.log_prior = self._initialize_log_prior()

    def _initialize_log_prior(self): 
         prior = torch.zeros(self.num_edge_types)
         prior.fill_(1.0 / self.num_edge_types)  #6
         log_prior = torch.log(prior)\
.unsqueeze(0).unsqueeze(0)  #7
         return log_prior.cuda(non_blocking=True)

注释说明:

  • #1 模型中的变量数量
  • #2 编码器神经网络
  • #3 解码器神经网络
  • #4 用于采样类别变量的Gumbel温度
  • #5 先验方差
  • #6 将先验张量填充为均匀概率
  • #7 取对数并添加两个单例维度

正如我们在第5章中发现的,变分自编码器(VAE)具有两部分损失——重构误差和表示数据分布属性的误差——通过KL散度来捕捉。总损失函数在清单6.11中给出。

我们的编码器接收边的嵌入,并输出边类型的对数概率。Gumbel-Softmax函数将这些离散的logits转换为可微分的连续分布。解码器接收这个分布和边的表示,然后将其转换回节点数据。此时,我们可以使用VAE的标准损失计算方法,因此我们计算重构损失作为MSE和KL散度。有关VAE损失和KL散度计算的进一步细节,请重新查看第5章。

清单6.11 NRI模型的损失函数

def calculate_loss(self, inputs,
    is_train=False,
    teacher_forcing=True,
    return_edges=False,
    return_logits=False):

    encoder_results = self.encoder(inputs)
    logits = encoder_results['logits']
    hard_sample = not is_train
    edges = F.gumbel_softmax\
           (logits.view(-1, self.num_edge_types),
           tau=self.gumbel_temp,
           hard=hard_sample).view\
                   (logits.shape)  #1

    output = self.decoder(inputs[:, :-1], edges)

    if len(inputs.shape) == 3: \
        target = inputs[:, 1:] 
    else:
        target = inputs[:, 1:, :, :]

    loss_nll = F.mse_loss(\
    output, target) / (2 * \
    self.prior_variance)  #2

    probs = F.softmax(logits, dim=-1)
    log_probs = torch.log(probs + 1e-16)  #3
    loss_kl = (probs * \
    (log_probs - torch.log(\
    torch.tensor(1.0 /  #4
       self.num_edge_types)))).\
    sum(-1).mean() 

    loss = loss_nll + loss_kl

    return loss, loss_nll, loss_kl, logits, output

注释说明:

  • #1 使用PyTorch的功能API计算Gumbel-Softmax
  • #2 高斯分布的负对数似然(NLL)
  • #3 加上一个小常数以避免取对数时为零
  • #4 使用均匀类别分布计算KL散度

最后,我们需要让模型能够预测传感器的未来轨迹。预测图未来状态的代码见清单6.12。一旦我们训练了编码器和解码器,这个函数相对简单。我们将当前图传递给编码器,它返回一个潜在的表示,指示是否存在边。然后,我们将这些概率转换为适当的分布,使用Gumbel-Softmax,并将其传递给解码器。解码器的输出就是我们的预测。我们可以直接获取预测结果,或者同时获取预测结果和边的存在情况。

清单6.12 预测未来

def predict_future(self, inputs, prediction_steps, 
  return_edges=False, 
  return_everything=False): #1
    encoder_dict = self.encoder(inputs) #1
    logits = encoder_dict['logits'] 
    edges = nn.functional.gumbel_softmax(  #2
        logits.view(-1, \
        self.num_edge_types),  
        tau=self.gumbel_temp,\
        hard=True).view(logits.shape) 
    tmp_predictions, decoder_state =\
      self.decoder(  #3
      inputs[:, :-1],  #3
      edges,  #3
      return_state=True  #3
    )  #3
    predictions = self.decoder(  #4
      inputs[:, -1].unsqueeze(1),   #4
      edges,   #4
      prediction_steps=prediction_steps,   #4
      teacher_forcing=False,  #4
      state=decoder_state  #4
    ) #4
    if return_everything:  #5
        predictions = torch.cat([\
          tmp_predictions,  #5
          predictions  #5
        ], dim=1)  #5

    return (predictions, edges)\
        if return_edges else predictions  #6

注释说明:

  • #1 运行编码器以获取边类型的logits
  • #2 应用Gumbel-Softmax到边
  • #3 运行解码器获取初始预测和解码器状态
  • #4 使用最后的输入和解码器状态预测未来步骤
  • #5 如果需要,连接初始预测和未来预测
  • #6 如果指定,返回预测和边

这是NRI模型的基础。我们有一个编码器,将初始节点数据转换为边的预测,然后是解码器,用于给出节点数据的未来轨迹。

6.4.1 编码姿态数据

现在我们已经了解了NRI模型的不同部分,让我们来定义编码器。这个编码器将作为瓶颈来简化问题。经过编码后,我们将得到边数据的低维表示,因此在此阶段不需要担心时间数据。然而,通过将时间数据一起提供,我们将时间结构传递到潜在空间中。具体而言,编码器从输入数据中提取时间模式和关系,并将这些信息保留在压缩的低维表示中。这使得从低维表示中解码变得更加容易,从而简化了姿态预测问题。

实现编码器有几个子步骤。首先,我们传递输入数据,该数据由不同实验中的不同帧的不同传感器组成。然后,编码器接收这些数据x,并执行消息传递步骤,将边数据转换为节点数据,再转换回边数据。接着,边数据再次转换为节点数据,并最终在潜在空间中进行编码。这等同于三个消息传递步骤:从边到节点、从边到边,再从边到节点。重复的转换步骤有助于通过反复的消息传递进行信息聚合,并捕捉图中的高阶交互。通过节点和边之间的反复转换,模型可以了解局部和全局的结构信息。

在本书中,我们探索了如何使用消息传递将节点或边的特征转换为复杂的节点或边表示。这是所有GNN方法的核心。NRI模型与我们之前探索过的方法稍有不同,因为消息是通过节点和边之间传递的,而不是节点到节点或边到边的传递。为了明确这些步骤的作用,我们将不使用PyG,而是直接用PyTorch编写我们的模型。

在清单6.13中,我们展示了编码器的基类,它需要几个关键功能。首先,请注意,我们没有描述用于编码数据的实际神经网络。我们稍后会介绍这一点。相反,我们有两个消息传递函数,edge2nodenode2edge,以及一个编码函数one_hot_recv

清单 6.13 编码器基类

class BaseEncoder(nn.Module):
    def __init__(self, num_vars):
        super(BaseEncoder, self).__init__()
        self.num_vars = num_vars
        edges = torch.ones(num_vars) - torch.eye(num_vars)  #1
        self.send_edges, self.recv_edges = torch.where(edges)  #2

        one_hot_recv = torch.nn.functional.one_hot(  #3
            self.recv_edges,  #3
            num_classes=num_vars  #3
        ) #3
        self.edge2node_mat = \
            nn.Parameter(one_hot_recv.float().T, requires_grad=False)  #4

    def node2edge(self, node_embeddings):
        send_embed = node_embeddings[:, self.send_edges]  #5
        recv_embed = node_embeddings[:, self.recv_edges] 
        return torch.cat([send_embed, recv_embed], dim=2)  #6

    def edge2node(self, edge_embeddings):
        incoming = torch.matmul(self.edge2node_mat, edge_embeddings)  #7
        return incoming / (self.num_vars - 1)  #8
  • #1 创建一个表示变量之间边的矩阵
  • #2 查找边存在的索引
  • #3 为接收边创建一个独热编码表示
  • #4 创建一个用于边到节点转换的参数张量
  • #5 提取发送者和接收者的嵌入
  • #6 连接发送者和接收者的嵌入
  • #7 将边嵌入与边到节点矩阵相乘
  • #8 对传入的嵌入进行归一化

我们编码器类的第一步是构建邻接矩阵。在这里,我们假设图是完全连接的,即所有节点都与其他节点相连,但不与自己相连。node2edge函数接收节点嵌入数据并识别这些消息的发送方向。图6.12展示了我们如何构建邻接矩阵的一个例子。

image.png

接下来的函数调用确定哪些节点发送或接收数据,通过返回两个向量,这些向量包含连接节点的行和列。回想一下,在邻接矩阵中,行表示接收节点,列表示发送节点。输出结果如下:

send_edges = tensor([0, 0, 1, 1, 2, 2])
recv_edges = tensor([1, 2, 0, 2, 0, 1])

我们可以将其解释为:位于第0行的节点将数据发送到位于第1列和第2列的节点,以此类推。这使我们能够提取节点之间的边。一旦构建了节点嵌入,我们就可以使用发送和接收数据将节点数据转换为边数据。这就是node2edge函数的原理。

接下来我们需要的函数是如何基于我们的边嵌入构建edge2node。我们首先构建一个edge2node矩阵。在这里,我们使用了一种独热编码方法,将接收边转换为独热编码表示。具体来说,我们创建一个矩阵,每一行表示该类别(接收节点)是否存在。对于我们简单的三节点案例,接收边的独热编码方法如图6.13所示。

然后,我们对这个矩阵进行转置,交换行和列,这样它的维度将变为(节点数,边数),并将其转换为PyTorch参数,以便我们可以对其进行微分。获得edge2node矩阵后,我们将其与边嵌入相乘。我们的边嵌入的形状是(边数,嵌入大小),因此将edge2node矩阵与边嵌入相乘会得到一个形状为(节点数,嵌入大小)的对象。这就是我们新的节点嵌入!最后,我们对这个矩阵按节点数进行归一化,以确保数值稳定性。

这一部分是理解模型中消息传递步骤的关键。(有关消息传递的更多信息,请回顾第2章和第3章。)正如那里所讨论的,一旦我们有了一种原则性的方法来在节点、边或它们的组合之间传递消息,我们就可以将神经网络应用于这些嵌入,获得非线性表示。为了做到这一点,我们需要定义我们的嵌入架构。完整的编码器代码在清单6.14中给出。

image.png

RefMLPEncoder 如清单6.14所示。这个编码器使用了四个MLP(多层感知器)进行消息处理,每个MLP都具有指数线性单元(ELU)激活函数和批量归一化(在本章的代码库中定义了RefNRIMLP)。

注意
指数线性单元(ELU)是一种激活函数,它在多个层之间平滑输出,并防止梯度消失。与ReLU相比,ELU对负输入有内置的平滑梯度,并允许负输出。

网络的最后部分(self.fc_out)是一个线性层序列,层与层之间有ELU激活,最终以一个线性层结束,该层输出所需的嵌入或预测。该序列的最后一层是一个全连接层。

清单 6.14 NRI MLP 编码器

class RefMLPEncoder(BaseEncoder):
    def __init__(self, 
            num_vars=31, 
            input_size=6, 
            input_time_steps=50, 
            encoder_mlp_hidden=256, 
            encoder_hidden=256, 
            num_edge_types=2, 
            encoder_dropout=0.):
        super(RefMLPEncoder, self).__init__(num_vars)
        inp_size = input_size * input_time_steps
        hidden_size = encoder_hidden
        num_layers = 3
        self.input_time_steps = input_time_steps

        self.mlp1 = RefNRIMLP(
            inp_size, hidden_size, hidden_size, encoder_dropout)  #1
        self.mlp2 = RefNRIMLP(
            hidden_size*2, hidden_size, hidden_size, encoder_dropout)
        self.mlp3 = RefNRIMLP(
            hidden_size, hidden_size, hidden_size, encoder_dropout)
        mlp4_inp_size = hidden_size * 2
        self.mlp4 = RefNRIMLP(
            mlp4_inp_size, hidden_size, hidden_size, encoder_dropout)

        layers = [nn.Linear(
            hidden_size, encoder_mlp_hidden), 
            nn.ELU(inplace=True)]  #2
        layers += [nn.Linear(
            encoder_mlp_hidden, encoder_mlp_hidden), 
            nn.ELU(inplace=True)] * (num_layers - 2)
        layers.append(nn.Linear(
            encoder_mlp_hidden, num_edge_types))
        self.fc_out = nn.Sequential(*layers)
        self.init_weights()

#1 定义了MLP层。RefNRIMLP是一个带有批量归一化的2层全连接ELU网络。
#2 定义了最终的全连接层

在这里,我们定义了与编码器相关的架构细节。如前所述,我们使用num_vars变量表示31个传感器。特征数为6,即网络的input_size。我们的训练和验证集的时间步数仍为50,编码器网络的大小为256。num_edge_types为2,并假设权重没有丢弃(dropout为0)。然后,我们初始化了我们的网络,这些网络是典型的MLP,已在我们的共享代码库中描述。网络包括一个批量归一化层和两个全连接层。定义网络后,我们还预初始化了权重,如清单6.15所示。在这里,我们遍历所有不同的层,然后使用Xavier初始化方法初始化权重。这确保了各层的梯度大致相同,这减少了损失快速发散(即梯度爆炸)的风险。对于组合不同架构的多个网络,这一步非常重要。我们还将初始偏差设置为0.1,这进一步有助于训练的稳定性。

清单 6.15 权重初始化

def init_weights(self):
    for m in self.modules():
        if isinstance(m, nn.Linear):  #1
            nn.init.xavier_normal_(m.weight.data)  #2
            m.bias.data.fill_(0.1)  #3
#1 仅适用于线性层
#2 使用Xavier正态初始化来初始化权重
#3 将偏差设置为0.1

最后,我们需要定义我们的前向传播方法,如清单6.16所示。这里就是我们的消息传递步骤发生的地方。

清单 6.16 编码器前向传播

def forward(self, inputs, state=None, return_state=False):
    if inputs.size(1) > self.input_time_steps:
        inputs = inputs[:, -self.input_time_steps:]
    elif inputs.size(1) < self.input_time_steps:
        begin_inp = inputs[:, 0:1].expand(
            -1, 
            self.input_time_steps-inputs.size(1),
            -1, -1
        )
        inputs = torch.cat([begin_inp, inputs], dim=1) #1

    x = inputs.transpose(1, 2).contiguous()  #1
    x = x.view(inputs.size(0), inputs.size(2), -1) 

    x = self.mlp1(x)  #2
    x = self.node2edge(x)  #3
    x = self.mlp2(x)  #4

    x = self.edge2node(x)  #5
    x = self.mlp3(x)

    x = self.node2edge(x)  #6
    x = self.mlp4(x)

    result =  self.fc_out(x)  #7
    result_dict = {
        'logits': result,
        'state': inputs,
    }
    return result_dict
#1 新的形状:[num_sims, num_atoms, num_timesteps*num_dims]
#2 通过第一个MLP层(每个节点的二层ELU网络)
#3 将节点嵌入转换为边嵌入
#4 通过第二个MLP层
#5 将边嵌入转换回节点嵌入
#6 再次将节点嵌入转换为边嵌入
#7 最终的全连接层获取logits

我们的编码器使得模型能够将不同的传感器图像帧集转换为边概率的潜在表示。接下来,让我们探讨如何构建解码器,将潜在的边概率转换为基于近期传感器数据的轨迹。

6.4.2 使用GRU解码姿态数据

为了将潜在表示转换为未来的帧,我们需要考虑轨迹的时间演变。为此,我们训练一个解码器网络。在这里,我们将遵循NRI论文的原始结构[2],使用GRU作为我们的RNN。

我们在6.2.2节中介绍了GRU的概念。简而言之,门控循环单元(GRU)是一种RNN,使用门控过程来捕捉数据中的长期行为。它由两种类型的门组成——重置门和更新门。

对于NRI模型,我们将GRU应用于边数据,而不是应用于整个图。更新门将用于根据接收的数据确定应该更新节点的隐藏状态多少,重置门则决定应该擦除或“遗忘”多少。换句话说,我们将使用GRU来预测节点的未来状态,这基于来自编码器网络的边类型概率。

让我们逐步构建这个过程。解码器的初始化代码如清单6.17所示。首先,我们注意到传递给该网络的一些变量。我们再次定义了图中变量或节点的数量(31),以及输入特征的数量(6)。我们假设权重没有丢弃,并且每一层的隐藏层大小为64。同样,我们需要明确指出,解码器应该预测两种不同类型的边。我们在做预测时将跳过第一种边类型,因为它表示没有边。

一旦定义了输入参数,我们就可以介绍网络架构。第一层是一个简单的线性网络,它需要将输入维度乘以2,以考虑编码器提供的均值和方差,并为每个边类型定义这个网络。接下来,我们定义第二层来进一步增强网络的表达能力。这两层线性层的输出将传递到我们的RNN,即GRU。这里,我们必须使用自定义的GRU,以同时考虑节点数据和边数据。GRU的输出将传递到另外三个神经网络层,以提供未来的预测。最后,我们需要定义我们的edge2node矩阵以及发送和接收节点,正如我们在编码器中所做的那样。

清单 6.17 RNN解码器

class GraphRNNDecoder(nn.Module):
    def __init__(self, 
        num_vars=31, 
        input_size=6, 
        decoder_dropout=0., 
        decoder_hidden=64, 
        num_edge_types=2, 
        skip_first=True):
        super(GraphRNNDecoder, self).__init__()
        self.num_vars = num_vars
        self.msg_out_shape = decoder_hidden
        self.skip_first_edge_type = skip_first
        self.dropout_prob = decoder_dropout
        self.edge_types = num_edge_types

        self.msg_fc1 = nn.ModuleList(
            [nn.Linear(2 * decoder_hidden, decoder_hidden) for _ in range(self.edge_types)])  #1
        self.msg_fc2 = nn.ModuleList(
            [nn.Linear(decoder_hidden, decoder_hidden) for _ in range(self.edge_types)])

        self.custom_gru = CustomGRU(input_size, decoder_hidden)  #2

        self.out_fc1 = nn.Linear(decoder_hidden, decoder_hidden)  #3
        self.out_fc2 = nn.Linear(decoder_hidden, decoder_hidden)
        self.out_fc3 = nn.Linear(decoder_hidden, input_size)

        self.num_vars = num_vars
        edges = np.ones(num_vars) - np.eye(num_vars)
        self.send_edges = np.where(edges)[0]
        self.recv_edges = np.where(edges)[1]
        self.edge2node_mat = torch.FloatTensor(encode_onehot(self.recv_edges))
        self.edge2node_mat = self.edge2node_mat.cuda(non_blocking=True)

#1 边相关的层
#2 GRU层
#3 全连接层

在清单6.18中,我们提供了自定义GRU的架构。这个网络的第一层架构与典型的GRU结构相同。我们定义了三个隐藏层,分别表示由hidden_rinput_r定义的重置门,hidden_iinput_i定义的更新门,以及hidden_hinput_h定义的激活网络。然而,前向网络需要考虑来自编码器的消息传递输出的聚合消息。这在前向传播中得以体现。我们将边概率(在agg_msgs中)和输入的节点数据一起传递,这些结合起来将返回未来的预测。这可以在我们基类NRI中的predict_future代码中看到:

predictions = self.decoder(inputs[:, -1].unsqueeze(1), edges,
prediction_steps=prediction_steps, teacher_forcing=False, 
state=decoder_state)

我们的解码器接收图的最后一个时间帧。来自编码器的边数据也会传递给解码器。

清单 6.18 自定义GRU网络

class CustomGRU(nn.Module):
    def __init__(self, input_size, n_hid, num_vars=31):
        super(CustomGRU, self).__init__()
        self.num_vars = num_vars
        self.hidden_r = nn.Linear(n_hid, n_hid, bias=False)  #1
        self.hidden_i = nn.Linear(n_hid, n_hid, bias=False) 
        self.hidden_h = nn.Linear(n_hid, n_hid, bias=False) 

        self.input_r = nn.Linear(input_size, n_hid, bias=True)  #2
        self.input_i = nn.Linear(input_size, n_hid, bias=True) 
        self.input_n = nn.Linear(input_size, n_hid, bias=True) 

    def forward(self, inputs, agg_msgs, hidden):
        inp_r = self.input_r(inputs).view(inputs.size(0), self.num_vars, -1)
        inp_i = self.input_i(inputs).view(inputs.size(0), self.num_vars, -1)
        inp_n = self.input_n(inputs).view(inputs.size(0), self.num_vars, -1)

        r = torch.sigmoid(inp_r + self.hidden_r(agg_msgs))  #3
        i = torch.sigmoid(inp_i + self.hidden_i(agg_msgs))  #4
        n = torch.tanh(inp_n + r * self.hidden_h(agg_msgs))  #5
        hidden = (1 - i) * n + i * hidden  #6

        return hidden

#1 定义重置、输入和新门的隐藏层转换
#2 定义重置、输入和新门的输入层转换
#3 计算重置门的激活
#4 计算输入门的激活
#5 计算新门的激活
#6 更新隐藏状态

解码器网络的输出是未来预测的时间步。为了更好地理解这一点,让我们看看解码器的前向传播方法,如清单6.19所示。我们的前向传播方法接收输入和采样的边来构建预测。还有四个额外的参数帮助控制行为。首先,我们定义了一个teacher_forcing变量。教师强迫是一种常用于训练序列模型(如RNN)的方法。如果为True,我们使用真实的图(地面真相)来预测下一个时间帧。如果为False,我们使用模型上一时刻的输出。这确保模型在训练过程中不会被错误预测所误导。接下来,我们包括一个return_state变量,它允许我们访问解码器网络给出的隐藏表示。我们在预测未来图的演化时使用这个表示,如下所示:

tmp_predictions, decoder_state = self.decoder(inputs[:, :-1], edges, return_state=True)
predictions = self.decoder(inputs[:, -1].unsqueeze(1), edges, prediction_steps=prediction_steps, teacher_forcing=False, state=decoder_state)

现在让我们讨论预测过程。首先,我们预测一个临时预测集。然后,我们使用隐藏表示预测尽可能多的未来步骤。这在我们希望预测多个时间步时特别有用,正如我们在模型的测试阶段所展示的那样。这个过程由prediction_steps变量控制,它告诉我们要在RNN中循环多少次,也就是我们想预测多少个未来时间步。最后,我们有一个state变量,用于控制传递给解码器的信息。当它为空时,我们初始化一个全为零的张量,以确保没有信息被传递。否则,我们将使用来自前一时间步的信息。

清单 6.19 解码器前向传播

def forward(self, inputs, sampled_edges,
    teacher_forcing=False,
    return_state=False,
    prediction_steps=-1,
    state=None):

    batch_size, time_steps, num_vars, num_feats = inputs.size()
    pred_steps = prediction_steps if prediction_steps > 0 else time_steps  #1

    if len(sampled_edges.shape) == 3:  #2
        sampled_edges = sampled_edges.unsqueeze(1) 
        sampled_edges = sampled_edges.expand(batch_size, pred_steps, -1, -1) 

    if state is None:  #3
        hidden = torch.zeros(batch_size, num_vars, self.msg_out_shape, device=inputs.device)  #3
    else:  #3
        hidden = state  #3
        teacher_forcing_steps = time_steps  #4

    pred_all = []
    for step in range(pred_steps):  #5
        if step == 0 or (teacher_forcing and step < teacher_forcing_steps): 
            ins = inputs[:, step, :] 
        else: 
            ins = pred_all[-1] 

        pred, hidden = self.single_step_forward(  #6
            ins, sampled_edges[:, step, :], hidden)  #6
        pred_all.append(pred)

    preds = torch.stack(pred_all, dim=1)

    return (preds, hidden) if return_state else preds  #7
#1 确定预测步数
#2 如果需要,扩展采样边的张量
#3 如果没有提供,则初始化隐藏状态
#4 确定应用教师强迫的步数
#5 根据教师强迫决定本步的输入
#6 使用计算出的`ins`通过单步前向传播
#7 返回预测和隐藏状态

为了预测未来的时间步,我们进行一个额外的前向传播,这基于单个时间步,如清单6.20所定义。这时,网络执行额外的消息传递步骤。我们从编码器的边概率中获取接收节点和发送节点。我们忽略第一种边类型,因为这些节点是未连接的,然后网络通过不同的边类型网络循环,获取所有依赖于边的数据。这个关键步骤使得我们的预测依赖于图数据。然后,GRU从连接的节点获取消息,以便预测轨迹。在这一步,我们正在学习如何根据已知的身体连接关系预测身体的行走状态。输出即为预测的身体传感器轨迹以及网络数据,解释了为何做出这些预测,这些数据编码在隐藏权重中。这完成了NRI模型的姿态估计。

清单 6.20 解码器单步前向传播

def single_step_forward(self, inputs, rel_type, hidden):
    receivers = hidden[:, self.recv_edges, :]  #1
    senders = hidden[:, self.send_edges, :]  #1

    pre_msg = torch.cat([receivers, senders], dim=-1)  #2

    all_msgs = torch.zeros(
        pre_msg.size(0), 
        pre_msg.size(1), 
        self.msg_out_shape, 
        device=inputs.device
    )

    start_idx = 1 if self.skip_first_edge_type else 0
    norm = float(len(self.msg_fc2) - start_idx)

    for i in range(start_idx, len(self.msg_fc2)):  #3
        msg = torch.tanh(self.msg_fc1[i](pre_msg))  #3
        msg = F.dropout(msg, p=self.dropout_prob)  #3
        msg = torch.tanh(self.msg_fc2[i](msg))  #3
        msg = msg * rel_type[:, :, i:i+1]  #3
        all_msgs += msg / norm  #3

    agg_msgs = all_msgs.transpose(-2, -1)  #4
    agg_msgs = agg_msgs.matmul(self.edge2node_mat) 
    agg_msgs = agg_msgs.transpose(-2, -1) / (self.num_vars - 1) 

    hidden = self.custom_gru(inputs, agg_msgs, hidden)  #5

    pred = F.dropout(F.relu(self.out_fc1(hidden)), p=self.dropout_prob)  #6
    pred = F.dropout(F.relu(self.out_fc2(pred)), p=self.dropout_prob) 
    pred = self.out_fc3(pred) 

    pred = inputs + pred   
    return pred, hidden
#1 节点到边的步骤
#2 消息大小:[batch, num_edges, 2*msg_out]
#3 对每个边类型运行一个单独的MLP
#4 每个节点汇总所有消息
#5 GRU风格的门控聚合
#6 构建输出MLP

6.4.3 训练NRI模型

现在我们已经定义了模型的不同部分,接下来让我们训练模型并看看它的表现。为了训练我们的模型,我们将采取以下步骤:

  1. 训练一个编码器,将传感器数据转换为边概率的表示,指示一个传感器是否与另一个传感器连接。
  2. 训练一个解码器,预测未来的轨迹,条件是存在一个边将不同的传感器连接起来。
  3. 使用GRU运行解码器,预测未来的轨迹,并将边概率传递给它。
  4. 基于重构的姿态减少损失。该损失包含两个部分:重构损失和KL散度。
  5. 重复步骤1至步骤4,直到训练收敛。

如图6.14所示,训练循环在清单6.21中给出。

image.png

清单 6.21 NRI 训练循环

pbar = tqdm(range(start_epoch, num_epochs + 1), desc='Epochs')
for epoch in pbar:
    model.train()  #1
    model.train_percent = epoch / num_epochs
    total_training_loss = 0
    for batch in train_data_loader:
        inputs = batch['inputs'].cuda(non_blocking=True)
        loss, _, _, _, _ = model.calculate_loss(inputs, 
                                                is_train=True, 
                                                return_logits=True)
        loss.backward()  #2
        optimizer.step() 
        optimizer.zero_grad()  #3
        total_training_loss += loss.item()

    if training_scheduler is not None:
        training_scheduler.step()

    total_nll, total_kl = 0, 0
    for batch in val_data_loader:
        inputs = batch['inputs'].cuda(non_blocking=True)
        loss_nll, loss_kl, _, _ = model.calculate_loss(inputs,
                                                      is_train=False, 
                                                      teacher_forcing=True, 
                                                      return_logits=True)
        total_kl += loss_kl.sum().item()
        total_nll += loss_nll.sum().item()

    total_kl /= len(val_data)
    total_nll /= len(val_data)
    total_loss = total_kl + total_nll
    tuning_loss = total_nll 

    if tuning_loss < best_val_result:
        best_val_epoch, best_val_result = epoch, tuning_loss
  • #1 训练循环
  • #2 更新权重
  • #3 为验证传递清零梯度

我们将使用学习率为0.0005、每500次前向传播将学习率降低一半的学习率调度器、批量大小为8来训练50个epoch。大部分训练基于我们在清单6.14中定义的calculate_loss方法。我们发现,随着验证损失的下降,模型损失也下降,基于负对数似然(nll)的验证损失达到了1.21。这个结果不错,但我们需要看看它在测试数据上的表现,特别是在需要预测多个时间步的情况下。为此,我们需要定义一个新函数,如下所示。

清单 6.22 评估未来预测

def eval_forward_prediction(model, 
  dataset, 
  burn_in, 
  forward_steps, 
  gpu=True, batch_size=8, 
  return_total_errors=False):

  dataset.return_edges = False

  data_loader = DataLoader(
    dataset, batch_size=batch_size, pin_memory=gpu)
  model.eval()
  total_se = 0
  batch_count = 0
  all_errors = []

  for batch_ind, batch in enumerate(data_loader):
    inputs = batch['inputs']
    with torch.no_grad():
      model_inputs = inputs[:, :burn_in]
      gt_predictions = inputs[:, burn_in:burn_in+forward_steps]
      model_inputs = model_inputs.cuda(non_blocking=True)
      model_preds = model.predict_future(
          model_inputs,
          forward_pred_steps
          ).cpu()
      batch_count += 1
      if return_total_errors:
          all_errors.append(
            F.mse_loss(
              model_preds, 
              gt_predictions,
              reduction='none'
             ).view(
               model_preds.size(0), 
               model_preds.size(1), -1
             ).mean(dim=-1)
          )
      else:
          total_se += F.mse_loss(
            model_preds, 
            gt_predictions,
            reduction='none'
          ).view(
            model_preds.size(0),
            model_preds.size(1),
            -1
          ).mean(dim=-1).sum(dim=0)

  if return_total_errors:
         return torch.cat(all_errors, dim=0)
     else:
            return total_se / len(dataset)

这个函数加载我们的测试数据,然后计算不同时间范围内的MSE(均方误差)来评估预测的准确度。当我们测试模型时,发现它能够预测下一个时间步,MSE为0.00008。更令人印象深刻的是,它能准确预测40个时间步,准确率达到94%。这远远优于我们的LSTM和GAT模型,后者分别达到了65%和55%。未来时间步数的准确度下降情况见图6.15,示例输出见图6.16。

image.png

image.png

我们已经涵盖了NRI模型的所有核心组件,完整的工作代码可以在GitHub仓库中找到(mng.bz/4a8D)。该模型的准确性令人印象深刻,展示了将生成方法和基于图的技术与时间模型结合的强大能力。这一点在图6.15中得到了体现,我们可以看到预测的姿态与结果估计的姿态非常接近。

此外,这种方法不仅能预测图形,还能在没有全部图数据的情况下学习潜在的结构。在这个问题中,我们知道预期的交互网络。然而,在许多情况下,我们并不知道交互网络。例如,在受限空间内移动的粒子。当它们处于某个交互半径内时,它们会相互影响,但当它们远离时则不会。这种情况不仅适用于粒子,还适用于从细胞到运动员等生物体。事实上,世界上的大多数情况都涉及相互作用的代理,它们之间的交互网络往往是隐秘的。NRI模型不仅提供了一种预测这些代理行为和运动的工具,还能够学习它们与其他代理之间的交互模式。实际上,原始的NRI论文通过使用篮球比赛的视频跟踪数据演示了这一点,并显示该模型能够学习球、持球人、屏障者和防守对位球员之间的典型模式。(更多信息请参阅Kipf等人[2])。

6.5 深入探讨

在本章中,我们展示了如何处理时间序列或动态问题。这里,我们将进一步详细探讨我们所使用的一些关键模型组件。

6.5.1 循环神经网络(RNN)

在图6.16中,我们展示了RNN模型的示意图。与我们所见的所有其他模型相比,RNN模型的主要区别在于它能够处理顺序数据。这意味着每个时间步都有一个隐藏层,并且该隐藏层的输出与后续时间步的输入结合。在图6.17中,这一点通过两种方式展示。首先,在左侧,我们将时间更新显示为一个单一的自循环,记作Whh。为了更好地理解这个自循环的作用,我们将模型在时间上“展开”,以便显式地看到模型是如何更新的。在这里,我们将输入、输出和隐藏层(x、y、h)更改为时间变量(xt、yt、ht)。在初始步骤t,我们通过输入数据xt和上一隐藏层ht-1的权重来更新当前的隐藏层,然后使用这个更新的隐藏层来输出yt。然后,ht的权重被传递到ht+1,并与xt+1的新输入一起用于推断yt+1。

这个模型的一个关键特点是,当我们通过反向传播更新权重时,我们需要进行时间反向传播(BPTT)。这是所有RNN的特有功能。然而,大多数现代深度学习框架使得这一过程变得非常简单,并且隐藏了对实践者来说困难的计算细节。

image.png

让我们看看如何使用PyTorch实现一个RNN。这个过程非常简单,只需定义一个神经网络类,然后在网络中引入特定的RNN层。例如,在清单6.23中,我们展示了定义一个包含单个RNN层的网络的代码。由于这里只有一个隐藏层,这个定义是非常基础的。然而,看到这个示例对于理解模型如何训练非常有帮助。在每个时间步,我们的输入同时传递到隐藏层和输出层。当我们执行前向传播时,输出会返回到输出层和隐藏层。最后,我们需要用一些东西来初始化我们的隐藏层,因此我们使用了一个全连接层。

清单 6.23 定义一个RNN

class PoseEstimationRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super(PoseEstimationRNN, self).__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.rnn = nn.RNN(
            input_size, hidden_size, 
            num_layers, batch_first=True)  #1
        self.fc = nn.Linear(hidden_size, output_size)  #2

    def forward(self, x):    
        h0 = torch.zeros(self.num_layers,  #3
                         x.size(0), self.hidden_size)  #3
        h0 = h0.to(x.device)

        out, _ = self.rnn(x, h0)  #4
        out = self.fc(out[:, -10:, :])  #5
        return out
#1 RNN层
#2 全连接层
#3 设置初始的隐藏状态和细胞状态
#4 前向传播RNN
#5 将最后一个时间步的输出传递到全连接层

在实践中,我们通常会使用更复杂的RNN。包括对RNN的扩展,如LSTM网络或GRU网络。我们甚至可以使用深度学习库将RNN、LSTM和GRU堆叠在一起。GRU与RNN相似,适用于数据序列。它们特别设计来解决RNN的一个关键缺点——梯度消失问题。GRU使用两个门,分别决定保留多少过去的信息(更新门)和忘记或丢弃多少信息(重置门)。我们在图6.18中展示了GRU的一个示例设计。在这里,zt表示更新门,rt表示重置门。~ht项被称为候选激活,反映了表示的新状态的候选,而ht项是实际的隐藏状态。

image.png

在清单6.24中,我们展示了如何构建一个带有GRU层的模型。在这里,大部分实现工作由PyTorch处理,GRU层从标准的PyTorch库中导入。模型的其余部分是一个典型的神经网络定义。

清单 6.24 GRU

class PoseEstimationGRU(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super(PoseEstimationGRU, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = nn.GRU(
            input_size, hidden_size, 
            num_layers, batch_first=True)  #1
        self.fc = nn.Linear(hidden_size, output_size)  #2

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)  #3
        h0 = h0.to(x.device)  #3
        out, _ = self.gru(x, h0)  #4
        out = self.fc(out[:, -10:, :])  #5
        return out
#1 GRU层
#2 全连接层
#3 设置初始隐藏状态
#4 前向传播GRU
#5 将最后一个时间步的输出传递到全连接层

6.5.2 时间邻接矩阵

在考虑时间图时,我们可能从两个节点通过一条边连接开始,然后在每个后续的时间步中添加其他几个节点和/或边。这会导致几个不同的图,每个图都有不同大小的邻接矩阵。

这可能会在设计我们的GNN时带来困难。首先,在每个时间步,我们有不同大小的图。这意味着我们不能使用节点嵌入,因为节点的数量会在输入数据中不断变化。一种方法是使用每个时间步的图嵌入,将整个图作为低维表示存储。这种方法是许多时间序列方法的核心,其中图嵌入随着时间演变,而不是实际的图。我们甚至可以对图进行更复杂的变换,例如在我们的NRI模型中使用自编码器模型。

另一种方法是通过创建时间邻接矩阵,将每个时间步的所有个别图转换为一个单一的更大图。这涉及将每个时间步包装成一个图,该图跨越每个时间步的数据以及动态时间数据。如果图较小且我们只对未来几个时间步感兴趣,时间邻接矩阵可能会非常有用。然而,它们通常会变得非常大且难以处理。另一方面,使用时间嵌入方法往往涉及多个复杂的子组件,且训练起来可能非常困难。不幸的是,没有一种通用的时间图,最佳方法几乎总是特定于问题的。

6.5.3 将自编码器与RNN结合

在这一节中,为了构建关于NRI模型的直觉,我们将总结它的组件,并说明它在预测图结构和节点轨迹中的应用。首先,在图6.19中,我们重复了NRI模型的示意图。

image.png

在这个模型中,有两个关键组件。首先,我们训练一个编码器,将每个时间帧的图形编码到潜在空间中。明确地说,我们使用编码器预测潜在交互(z)上的概率分布 qj(z∣x)q_j(z|x)qj​(z∣x),给定初始图形(x)。一旦我们训练好了编码器,我们就使用解码器将来自该概率分布的样本转换为轨迹,使用潜在编码以及前几个时间步的数据。在实践中,我们使用编码器-解码器结构来推断具有不同交互类型(或边)的节点轨迹。

在本章中,我们只考虑了两种边类型:传感器之间是否存在物理连接。然而,这种方法可以扩展到考虑许多不同的连接,且这些连接随着时间变化。此外,解码器模型需要一个RNN,以有效捕捉我们图中的时间数据。为了构建对NRI模型的直觉,让我们再次重复这个过程。

输入 — 节点数据。

编码

  • 编码器接收节点数据。
  • 编码器将节点数据转换为边数据。
  • 编码器在潜在空间中表示边数据。

潜在空间 — 潜在空间表示不同边类型的概率。在这里,我们有两种边类型(连接和未连接),尽管对于更复杂的关系,可能存在多种边类型。我们始终需要至少两种类型的边,否则模型会假设所有节点都连接在一起,或者更糟糕的是,所有节点都没有连接。

解码

  • 解码器从潜在空间中获取边类型概率。
  • 解码器根据这些概率学习重建未来的图形状态。

预测 — 模型通过学习预测图连接性来预测未来的轨迹。

请注意,这个模型同时为我们提供了图和轨迹的预测!虽然这对我们的问题可能没有帮助,但对于我们不知道底层图结构的情况(例如社交媒体网络或运动队),这可以为我们提供发现系统中新交互模式的方法。

6.5.4 Gumbel-Softmax

在NRI模型中,在计算这两种损失之前,有一个额外的步骤,即使用Gumbel-Softmax计算边的概率。我们需要引入Gumbel-Softmax的关键原因是,我们的自编码器学习预测表示我们边的邻接矩阵,也就是网络连接性,而不是节点及其特征。因此,自编码器的最终预测必须是离散的。然而,我们同时也在推断一个概率。每当概率数据需要变为离散时,Gumbel-Softmax是一个常用的方法。

在这里,我们有两种离散类型的边,也就是是否连接。这意味着我们的数据是分类的——每条边要么属于类别0(未连接),要么属于类别1(已连接)。Gumbel-Softmax用于从一个分类分布中绘制并评分样本。在实践中,Gumbel-Softmax将近似我们编码器的输出,这些输出以对数概率或logits的形式给出,作为Gumbel分布,这是一种极值分布。它将我们的数据的连续分布近似为离散分布(边类型),然后我们可以对这个分布应用损失函数。

Gumbel分布的温度(我们的一种超参数)反映了分布的“尖锐度”,类似于方差控制高斯分布的尖锐度。在本章中,我们使用了0.5的温度,这大约是中等尖锐度。我们还指定了“Hard”作为超参数,这表示是否存在一个或多个类别。如前所述,我们希望在训练时有两个类别,用来表示边是否存在。这允许我们将分布近似为连续分布,然后我们可以将其通过网络反向传播作为损失。然而,在测试时,我们可以将Hard设置为True,这意味着只有一个类别。这使得分布完全离散,意味着我们不能使用损失进行优化,因为离散变量按定义是不可微分的。这是一个有用的控制,以确保我们的测试循环不会传播任何梯度。

总结

虽然一些系统可以使用单一的数据快照来进行预测,但其他系统则需要考虑时间变化,以避免错误或漏洞。 时空图神经网络(Spatiotemporal GNNs)考虑了前几个时间步,以建模图在时间上的演化。 时空GNNs可以解决姿态估计问题,即根据一些关于身体位置如何变化的数据,预测身体的下一个位置。在这种情况下,节点代表放置在身体关节上的传感器,边代表关节之间的身体连接。 邻接矩阵可以通过沿对角线连接不同的邻接矩阵来适应时间信息。 可以将记忆引入到模型中,包括图神经网络(GNNs),例如使用递归神经网络(RNN)或门控递归单元网络(GRU)。 神经关系推断(NRI)模型将递归网络(如GRU)与自编码器GNN结合。这些模型可以推断时间模式,即使邻接信息未知。