图神经网络实战——图自编码器

694 阅读54分钟

本章内容包括:

  • 区分判别模型和生成模型
  • 将自编码器和变分自编码器应用于图
  • 使用 PyTorch Geometric 构建图自编码器
  • 过度压缩和图神经网络
  • 链接预测和图生成

到目前为止,我们已经讨论了如何将经典的深度学习架构扩展到图结构数据。在第 3 章中,我们讨论了卷积图神经网络(GNNs),它们应用卷积运算符来识别数据中的模式。在第 4 章中,我们探索了注意力机制以及它如何用于提高图学习任务的性能,例如节点分类。

卷积 GNN 和注意力 GNN 都是判别模型的例子,因为它们学习区分数据的不同实例,例如一张照片是猫还是狗。在本章中,我们介绍了生成模型的话题,并通过两种最常见的架构——自编码器和变分自编码器(VAE)来探讨它们。生成模型旨在学习整个数据空间,而不是像判别模型那样划分数据空间中的边界。例如,生成模型学习如何生成猫和狗的图像(学习重现猫或狗的某些方面,而不是仅仅学习区分两个或更多类别的特征,例如猫的尖耳朵或猎犬的长耳朵)。

正如我们将发现的那样,判别模型学习在数据空间中划分边界,而生成模型学习建模数据空间本身。通过逼近数据空间,我们可以从生成模型中采样,创建新的训练数据示例。在前面的例子中,我们可以使用我们的生成模型生成新的猫或狗的图像,甚至是具有两者特征的混合版本。这是一个非常强大的工具,对于初学者和资深数据科学家来说都是重要的知识。近年来,深度生成模型(使用人工神经网络的生成模型)在许多语言和视觉任务中展现出了惊人的能力。例如,DALL-E 系列模型能够根据文本提示生成新的图像,而像 OpenAI 的 GPT 模型这样的模型则极大地改变了聊天机器人的能力。

在本章中,您将学习如何扩展生成架构以处理图结构数据,从而导致图自编码器(GAEs)和变分图自编码器(VGAEs)。这些模型与前几章中的判别模型不同。正如我们所看到的,生成模型建模整个数据空间,并且可以与判别模型结合用于下游的机器学习任务。

为了展示生成方法在学习任务中的强大能力,我们回到第 3 章中介绍的亚马逊产品共同购买网络。然而,在第 3 章中,您学习了如何根据项目在网络中的位置预测它可能属于哪个类别。在本章中,我们将展示如何根据项目的描述预测它应该在网络中的位置。这被称为边(或链接)预测,这在设计推荐系统时经常出现。我们将在这里利用我们对 GAE 的理解来执行边预测,构建一个能够预测图中节点连接的模型。我们还将讨论过度压缩的问题,这是 GNN 的一个特定考虑因素,以及如何应用 GNN 来生成潜在的化学图。

到本章结束时,您应该了解何时以及在哪里使用图的生成模型(而不是判别模型),并且知道在需要时如何实现它们。

注意:本章的代码可以在 GitHub 仓库(mng.bz/4aGQ)中以笔记本形式找到。本章的 Colab 链接和数据也可以在同一位置访问。

5.1 生成模型:学习如何生成

深度学习的经典例子是,给定一组带标签的图像,如何训练模型来学习如何给新出现的、未见过的图像贴上标签。如果我们考虑一个包含船只和飞机的图像集,我们希望模型能够区分这些不同的图像。如果我们给模型输入一张新图像,我们希望它能正确地识别它,例如将其识别为船只。判别模型学习基于特定目标标签区分不同类别。卷积架构(第 3 章讨论)和基于注意力的架构(第 4 章讨论)通常用于创建判别模型。然而,正如我们所看到的,它们也可以被融入到生成模型中。为了理解这一点,我们首先需要了解判别模型和生成模型之间的区别。

5.1.1 生成模型与判别模型

如前几章所述,我们用来训练模型的原始数据集被称为训练数据,而我们希望预测的标签是训练目标。未见过的数据是我们的测试数据,我们希望学习训练数据中的目标标签,以便对测试数据进行分类。另一种描述方法是使用条件概率。我们希望模型返回在给定数据实例 X 的情况下某个目标 Y 的概率。我们可以将其写为 P(Y|X),其中竖线表示 Y 是“基于”X的。

正如我们所说,判别模型学习区分类别。这等同于学习数据在数据空间中的分隔边界。相比之下,生成模型学习建模数据空间本身。它们捕捉数据空间中数据的整体分布,并且在给定新示例时,告诉我们这个新示例的可能性。使用概率语言,我们可以说它们建模数据和目标之间的联合概率 P(X,Y)。一个典型的生成模型的例子可能是用于预测句子中下一个词的模型(例如,许多现代手机中的自动补全功能)。生成模型会为每个可能的下一个词分配一个概率,并返回那些具有最高概率的词。判别模型可以告诉你一个词具有某种特定的情感,而生成模型则会建议一个词来使用。

回到我们的图像示例,生成模型近似图像的整体分布。这可以在图 5.1 中看到,生成模型已经学习了数据空间中点的位置(而不是它们如何被分隔)。这意味着生成模型必须比判别模型学习更复杂的数据相关性。例如,生成模型学习到“飞机有翅膀”和“船只出现在水面附近”。另一方面,判别模型只需要学习“船”与“非船”之间的区别。它们可以通过查找图像中的标志性特征,如桅杆、龙骨或帆,来做到这一点。然后,它们可以在很大程度上忽略图像的其余部分。因此,生成模型的训练可能计算开销更大,并且可能需要更大的网络架构。(在第 5.5 节中,我们将描述过度压缩问题,这是大型 GNN 的一个特定问题。)

image.png

5.1.2 合成数据

考虑到判别模型在训练时计算开销较小,并且比生成模型更能抵抗离群值,你可能会想,为什么我们还要使用生成模型。然而,生成模型在数据标注相对昂贵但生成数据集比较容易时,是一种高效的工具。例如,生成模型在药物发现中越来越多地被使用,它们生成可能具有某些属性的新候选药物,例如能够减轻某些疾病的效果。从某种意义上说,生成模型试图学习如何创建合成数据,这使我们能够生成新的数据实例。例如,图 5.2 中显示的所有人都不存在,而是通过从数据空间中采样,并使用生成模型进行逼近后创建的。

image.png

生成模型创建的合成示例可以用来扩充数据集,尤其是在收集数据集时非常昂贵的情况下。与其在各种条件下拍摄大量面部照片,不如使用生成模型创建新的数据示例(例如,戴着帽子、眼镜和口罩的人),以增加数据集,包含一些棘手的边缘案例。这些合成示例可以进一步用来改善我们的其他模型(例如,识别某人是否戴口罩的模型)。然而,在引入合成数据时,我们还必须小心避免将其他偏差或噪声引入我们的数据集。

此外,判别模型通常用于生成模型的下游。这是因为生成模型通常以“自监督”的方式进行训练,无需依赖数据标签。它们学习将复杂的高维数据压缩(或编码)为低维数据。这些低维表示可以帮助我们更好地发掘数据中的潜在模式。这被称为维度约简,通常在数据聚类或分类任务中非常有用。稍后,我们将看到生成模型如何在从未看到标签的情况下将图分为不同的类别。在每个数据点标注昂贵的情况下,生成模型可以大大节省成本。接下来,我们将介绍我们的第一个生成 GNN 模型。

5.2 图自编码器用于链接预测

深度生成模型的一个基础且流行的模型是自编码器。自编码器框架之所以被广泛使用,是因为它具有极高的适应性。正如第 3 章中提到的注意力机制可以用来改善许多不同的模型一样,自编码器也可以与许多不同的模型结合使用,包括不同类型的 GNN。当我们理解了自编码器的结构后,编码器和解码器可以替换为任何类型的神经网络,包括不同的 GNN,如第 2 章中的图卷积网络(GCN)和 GraphSAGE 架构。

然而,在将自编码器应用于图数据时,我们需要小心。当重构我们的数据时,我们也必须重构邻接矩阵。在本节中,我们将使用第 3 章中的亚马逊产品数据集 [2] 来实现一个 GAE。我们将为链接预测任务构建一个 GAE,这是处理图时常见的问题。这样,我们就可以重构邻接矩阵,这在处理缺失数据的数据集时尤为有用。我们将遵循以下过程:

  1. 定义模型:

    • 创建编码器和解码器。
    • 使用编码器创建一个潜在空间进行采样。
  2. 定义训练和测试循环,包含适用于构建生成模型的损失函数。

  3. 将数据准备为图形,包含边列表和节点特征。

  4. 训练模型,将边数据传入计算损失。

  5. 使用测试数据集测试模型。

5.2.1 回顾第 3 章中的亚马逊产品数据集

在第 3 章中,我们学习了包含共同购买信息的亚马逊产品数据集。这个数据集包含了关于各种不同商品的信息,包括谁购买了这些商品、如何购买的详细信息,以及商品的类别,这些类别就是第 3 章中的标签。我们已经了解了如何将这个表格数据集转化为图结构,从而使我们的学习算法更加高效和强大。我们还在不知不觉中使用了一些维度约简。主成分分析(PCA)被应用于亚马逊产品数据集来创建特征。每个产品描述通过词袋算法转换为数值,接着应用 PCA 将(现在是数值化的)描述降至 100 个特征。

在本章中,我们将重新回顾亚马逊产品数据集,但这次我们有不同的目标。我们将使用这个数据集来学习链接预测。实际上,这意味着学习图中节点之间的关系。这有许多应用场景,例如预测用户接下来会喜欢看哪些电影或电视节目,建议社交媒体平台上的新连接,甚至预测哪些客户更可能违约。在这里,我们将用它来预测亚马逊电子产品数据集中应该连接在一起的商品,如图 5.3 所示。有关链接预测的更多细节,请参见本章末尾的第 5.5 节。

image.png

与所有数据科学项目一样,首先查看数据集并理解问题是什么是非常值得的。我们从加载数据开始,就像我们在第 3 章中做的那样,代码见列表 5.1。数据已经预处理和标记,可以使用 NumPy 加载。有关数据集的更多细节,可以参见 [2]。

列表 5.1 加载数据

import numpy as np

filename = 'data/new_AMZN_electronics.npz'

data = np.load(filename)

loader = dict(data)
print(loader)

上述输出将打印如下内容:

{'adj_data': array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
 'attr_data': array([[0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 1., 0., ..., 0., 0., 0.],
       [1., 1., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 1.]], dtype=float32),
 'labels': array([6, 4, 3, ..., 1, 2, 3]),
 'class_names': array(['Film Photography', 'Digital Cameras', 'Binoculars & Scopes',
       'Lenses', 'Tripods & Monopods', 'Video Surveillance',
       'Lighting & Studio', 'Flashes'], dtype='<U19')}

数据加载后,我们可以接着查看数据的一些基本统计信息和细节。我们感兴趣的是边或链接预测,因此了解不同边的数量是很有价值的。我们也可能想知道图中有多少个组件以及平均度数,以了解图的连接性。我们展示了计算这些的代码,如下列表所示。

列表 5.2 探索性数据分析

adj_matrix = torch.tensor(loader['adj_data'])
if not adj_matrix.is_sparse:
    adj_matrix = adj_matrix.to_sparse()

feature_matrix = torch.tensor(loader['attr_data'])
labels = loader['labels']

class_names = loader.get('class_names')
metadata = loader.get('metadata')

num_nodes = adj_matrix.size(0)
num_edges = adj_matrix.coalesce().values().size(0)   #1
density = num_edges / (num_nodes * (num_nodes - 1) / 2) if num_nodes > 1 else 0  #2

#1 这是因为邻接矩阵是无向的。
#2 实际边数与可能边数的比例。

我们还绘制了度分布图,以查看连接的变化,如下列表和图 5.4 所示。

列表 5.3 绘制图形

degrees = adj_matrix.coalesce().indices().numpy()[0]    #1
degree_count = np.bincount(degrees, minlength=num_nodes)

plt.figure(figsize=(10, 5))
plt.hist(degree_count, bins=25, alpha=0.75, color='blue')
plt.xlabel('Degree')
plt.ylabel('Frequency')
plt.grid(True)
plt.show()

#1 获取每个非零值的行索引。

我们发现图中有 7,650 个节点,超过 143,000 条边,整体密度为 0.0049。因此,我们的图是中等大小的(约 10,000 个节点),但非常稀疏(密度远小于 0.05)。我们看到大多数节点的度数较低(少于 10),但还有第二个度数较高的峰值(约 30),并且有一个较长的尾部。总体来说,我们看到很少有度数很高的节点,这与图的低密度是相符的。

image.png

5.2.2 定义图自编码器

接下来,我们将使用生成模型——自编码器来估计和预测亚马逊电子产品数据集中的链接。在此过程中,我们并不孤单,因为链接预测正是 2012 年 Kipf 和 Welling 首次发布图自编码器(GAE)时应用的任务 [3]。在他们的开创性论文中,他们介绍了 GAE 及其变分扩展,我们稍后将讨论,并将这些模型应用于图深度学习中的三个经典基准数据集:Cora 数据集、CiteSeer 和 PubMed。今天,大多数图深度学习库使得创建并开始训练 GAE 变得非常简单,因为这些模型已经成为最流行的基于图的深度生成模型之一。在本节中,我们将详细了解构建一个图自编码器所需的步骤。

GAE 模型类似于典型的自编码器。唯一的区别是我们网络中的每一层都是一个图神经网络(GNN),例如 GCN 或 GraphSAGE 网络。在图 5.5 中,我们展示了 GAE 架构的示意图。大体来说,我们将使用编码器网络将边数据压缩为低维表示。

image.png

我们为 GAE 定义的第一个部分是编码器,它将接收数据并将其转换为潜在表示。实现编码器的代码片段在列表 5.4 中给出。我们首先导入库,然后构建一个 GNN,其中每一层的规模逐渐减小。

列表 5.4 图编码器

from torch_geometric.nn import GCNConv   #1

class GCNEncoder(torch.nn.Module):                         #2
    def __init__(self, input_size, layers, latent_dim):    #2
        super().__init__() #2
        self.conv0 = GCNConv(input_size, layers[0])    #3
        self.conv1 = GCNConv(layers[0], layers[1])     #3
        self.conv2 = GCNConv(layers[1], latent_dim)    #3

    def forward(self, x, edge_index):           #4
        x = self.conv0(x, edge_index).relu()    #4
        x = self.conv1(x, edge_index).relu()    #4
        return self.conv2(x, edge_index)        #4

#1 从 PyG 加载 GCNConv 模型
#2 定义编码器层,并使用预定义的大小初始化
#3 定义每个编码器层的网络
#4 编码器的前向传递,包含边数据

注意,我们还必须确保前向传递能够返回图中的边数据,因为我们将使用自编码器从潜在空间重建图。换句话说,自编码器将学习如何从我们特征空间的低维表示重建邻接矩阵。这意味着它也在学习如何从新数据中预测边。为了做到这一点,我们需要修改自编码器结构,使其学习重建边,特别是通过修改解码器。在这里,我们将使用内积来从潜在空间预测边。这个过程在列表 5.5 中展示。(要理解为什么使用内积,请参见第 5.5 节中的技术细节。)

列表 5.5 图解码器

class InnerProductDecoder(torch.nn.Module):     #1
    def __init__(self):                        
         super().__init__()                    

def forward(self, z, edge_index):     #2
        value = (z[edge_index[0]] * \
z[edge_index[1]]).sum(dim=1)   #3
        return torch.sigmoid(value)

#1 定义解码器层
#2 声明解码器的形状和大小(这与编码器相反)
#3 解码器的前向传递

现在,我们准备将编码器和解码器结合到 GAE 类中,该类包含这两个子模型(见列表 5.6)。请注意,我们现在没有为解码器初始化任何输入或输出大小,因为它只是将内积应用于编码器的输出和边数据。

列表 5.6 图自编码器

   class GraphAutoEncoder(torch.nn.Module):
        def __init__(self, input_size, layers, latent_dims):
            super().__init__()
            self.encoder = GCNEncoder(input_size, \
   layers, latent_dims)     #1
            self.decoder = InnerProductDecoder()      #2

        def forward(self, x):
            z = self.encoder(x)
            return self.decoder(z)

#1 定义 GAE 的编码器
#2 定义解码器

在 PyTorch Geometric (PyG) 中,GAE 模型可以通过简单地导入 GAE 类来简化,这样一来,一旦传递给编码器,解码器和自编码器就会自动构建。我们将在本章稍后的部分构建 VGAE 时使用此功能。

5.2.3 训练图自编码器执行链接预测

在构建了我们的 GAE 后,我们可以继续使用它来执行亚马逊产品数据集的边预测。整个框架将遵循典型的深度学习问题格式,其中我们首先加载数据,准备数据,并将数据划分为训练集、测试集和验证集;定义我们的训练参数;然后训练和测试我们的模型。这些步骤如图 5.6 所示。

image.png

我们首先加载数据集并为学习算法做好准备,这部分我们已经在第5.1节中完成。为了使用PyG模型进行GAE和VGAE,我们需要从邻接矩阵构造一个边索引,这可以通过PyG的工具函数to_edge_index轻松完成,具体方法在以下代码示例中描述。

Listing 5.7 构造边索引

from torch_geometrics.utils import to_edge_index   #1

edge_index, edge_attr = to_edge_index(adj_matrix)   #2
num_nodes = adj_matrix.size(0)
#1 从PyG工具库加载to_edge_index
#2 将邻接矩阵转换为边索引和边属性向量

接下来,我们加载PyG库并将数据转换为PyG数据对象。我们还可以对数据集应用转换,其中特征和邻接矩阵按照第3章的方式加载。首先,我们对特征进行归一化,然后根据图的边或链接将数据集拆分为训练集、测试集和验证集,如以下代码所示。这个步骤在进行链接预测时至关重要,确保数据正确拆分。在代码中,我们使用了5%的数据作为验证集,10%作为测试集,并注意我们的图是无向图。在这里,我们不添加任何负训练样本。

Listing 5.8 转换为PyG对象

data = Data(x=feature_matrix,         #1
            edge_index=edge_index,   
            edge_attr=edge_attr,     
            y=labels)                

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = T.Compose([\
     T.NormalizeFeatures(),\                  #2
     T.ToDevice(device),                     
     T.RandomLinkSplit(num_val=0.05,\
     num_test=0.1, is_undirected=True,       
     add_negative_train_samples=False)])     
train_data, val_data, test_data = transform(data)
#1 将数据转换为PyG数据对象
#2 转换数据并将链接拆分为训练集、测试集和验证集

一切准备就绪后,我们可以将GAE应用到Amazon Products数据集。首先,我们定义模型、优化器和损失函数。我们使用二元交叉熵损失函数来处理解码器的预测值,并与真实的边索引进行比较,以检查我们的模型是否正确重构了邻接矩阵,具体代码如下。

Listing 5.9 定义模型

input_size, latent_dims = feature_matrix.shape[1], 16   #1
layers = [512, 256]                                    
model = GraphAutoEncoder(input_size, layers, latent_dims)   #2
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()   #3
#1 指定编码器的输入维度
#2 定义一个形状正确的GAE模型
#3 我们的损失函数是二元交叉熵。

使用二元交叉熵损失函数非常重要,因为我们需要计算每个边是否是真正的边的概率,其中真正的边对应于那些未被隐藏、不需要预测的边(即正样本)。编码器学习压缩边数据,但不改变边的数量,而解码器学习预测边。从某种意义上说,我们将判别性和生成性步骤结合在一起。因此,二元交叉熵给出了一个概率,表示这些节点之间可能存在边的概率。它是二元的,因为边要么存在(标签1),要么不存在(标签0)。在每个训练周期中,我们可以将所有二元交叉熵概率大于0.5的边与实际的真实边进行比较,具体代码如下。

Listing 5.10 训练函数

def train(model, criterion, optimizer):

    model.train() 

    optimizer.zero_grad() 
    z = model.encoder(train_data.x,\
    train_data.edge_index)   #1

    neg_edge_index = negative_sampling(\         #2
    edge_index=train_data.edge_index,\
    num_nodes=train_data.num_nodes,             
    num_neg_samples=train_data.\
    edge_label_index.size(1), method='sparse')  

    edge_label_index = torch.cat(                     #3
    [train_data.edge_label_index, neg_edge_index],    #3
    dim=-1,)                                          #3

    out = model.decoder(z, edge_label_index).view(-1)   #4

    edge_label = torch.cat([        #5
    train_data.edge_label,         
train_data.edge_label.new_zeros\
(neg_edge_index.size(1))           
    ], dim=0)                      
    loss = criterion(out, edge_label)   #6
    loss.backward()                     #6
    optimizer.step() #6

    return loss
#1 编码图形数据为潜在表示
#2 进行新的负采样
#3 将新的负样本与边标签索引合并
#4 生成边预测
#5 将边标签与负样本的0标签合并
#6 计算损失并进行反向传播

在这里,我们首先将图形编码为潜在表示。然后,我们进行一轮负采样,每个周期都会抽取新的样本。负采样在训练过程中从不存在的标签中随机选择一个子集,而不是选择现有的正标签,以应对真实标签和不存在标签之间的类别不平衡。一旦我们获得这些新的负样本,我们将它们与原始的边标签索引连接,并将这些数据传递给解码器来重建图形。最后,我们将真实的边标签与负边的0标签连接,并计算预测边与真实边之间的损失。请注意,我们这里没有进行批量学习,而是选择在每个周期内对所有数据进行训练。

我们的测试函数(如Listing 5.11所示)比训练函数要简单得多,因为它不需要执行任何负采样。相反,我们只使用真实和预测的边,并返回一个接收者操作特性(ROC)/曲线下面积(AUC)分数来衡量模型的准确性。回想一下,ROC/AUC曲线的值范围在0到1之间,一个完美的模型,其预测完全正确,AUC将为1。

Listing 5.11 测试函数

from sklearn.metrics import roc_auc_score

@torch.no_grad() 
def test(data):
    model.eval() 
    z = model.encode(data.x, data.edge_index)   #1
    out = model.decode(z, \
    data.edge_label_index).view(-1).sigmoid()   #2
    loss = roc_auc_score(data.edge_label.cpu().numpy(),   #3
                        out.cpu().numpy())                #3
    return loss
#1 将图形编码为潜在表示
#2 使用完整的边标签索引解码图形
#3 计算整体的ROC/AUC分数

在每个时间步,我们将使用验证数据中的所有边数据计算模型的整体成功率。训练完成后,我们使用测试数据计算最终的测试准确性,如下所示。

Listing 5.12 训练循环

best_val_auc = final_test_auc = 0 
for epoch in range(1, 201): 
    loss = train(model, criterion, optimizer)  #1
    val_auc = test(val_data)   #2
    if val_auc > best_val_auc: 
        best_val_auc = val_auc 
test_auc = test(test_data)   #3
#1 执行一个训练步骤
#2 在验证数据上测试我们的更新模型
#3 在测试数据上测试我们的最终模型

我们发现,在200个周期后,我们的准确率超过了83%。更好的是,当我们使用测试集来检查模型如何预测边时,我们得到了86%的准确率。我们可以解释我们的模型表现为:假设所有未来的数据与当前数据集相同,那么模型能够86%的时间为购买者推荐有意义的项目。这是一个很好的结果,并展示了GNN在推荐系统中的有用性。我们还可以使用模型更好地理解数据集的结构,或通过探索新构建的潜在空间来应用额外的分类和特征工程任务。接下来,我们将学习图自动编码器模型的一个常见扩展——VGAE。

5.3 变分图自编码器

自编码器将数据映射到潜在空间中的离散点。为了在训练数据集之外进行采样并生成新的合成数据,我们可以在这些离散点之间进行插值。这正是我们在图5.1中描述的过程,在那里我们生成了未见过的数据组合,例如飞行船。然而,自编码器是确定性的,每个输入都映射到潜在空间中的一个特定点。这可能导致在采样时出现明显的不连续性,进而影响数据生成的性能,导致合成数据不能很好地重现原始数据集。为了改善我们的生成过程,我们需要确保潜在空间结构良好,或者说是有规律的。例如,在图5.7中,我们展示了如何使用Kullback-Leibler散度(KL散度)来重新构建潜在空间,以改善重构效果。

image.png

KL散度是衡量一个概率分布与另一个分布之间差异的度量。它计算了将一个分布(原始数据分布)中的值编码到另一个分布(潜在空间)所需的“额外信息”量。在左侧,数据组(xi)之间重叠较少,这意味着KL散度较高;而在右侧,不同数据组之间有更多的重叠(相似性),这意味着KL散度较低。当构建具有较高KL散度的更规则的潜在空间时,我们可以得到非常好的重构效果,但插值效果较差;而当KL散度较低时,则正好相反。关于这一点的更多细节将在第5.5节中提供。

“规则”意味着空间满足两个属性:连续性和紧凑性。连续性意味着潜在空间中相邻的点被解码为大致相似的内容,而紧凑性意味着潜在空间中的任何点都应该导致一个有意义的解码表示。术语“近似相似”和“有意义”有明确的定义,您可以在《Learn Generative AI with PyTorch》(Manning, 2024; mng.bz/AQBg)中阅读更多内容。然而,在本章中,您需要知道的是,这些属性使得从潜在空间进行采样更容易,从而生成更干净的样本,并可能提高模型的准确性。

当我们对潜在空间进行正则化时,我们使用变分方法,该方法通过概率分布(或密度)来建模整个数据空间。正如我们所看到的,使用变分方法的主要好处是潜在空间的结构良好。然而,变分方法并不一定保证更高的性能,因此在使用这些模型时,通常需要测试自编码器和变分对偶体。这可以通过查看测试数据集上的重构得分(例如,均方误差)、对潜在编码应用某些维度缩减方法(例如,t-SNE或统一流形近似与投影[UMAP])或使用任务特定的度量(例如,图像的Inception Score或文本生成的ROUGE/METEOR)来实现。特别对于图形,可以使用最大均值差异(MMD)、图形统计或图形核方法等度量,来与不同的合成生成图副本进行比较。

在接下来的几节中,我们将更详细地讨论将数据空间建模为概率密度的意义,以及如何通过几行代码将我们的图自编码器转换为VGAE。这些依赖于一些关键的概率机器学习概念,例如KL散度和重参数化技巧,我们将在第5.5节中概述这些概念。为了深入了解这些概念,我们推荐《Probabilistic Deep Learning》(Manning, 2020)。让我们构建一个VGAE架构,并将其应用于与之前相同的Amazon Products数据集。

5.3.1 构建变分图自编码器

VGAE架构与GAE模型类似。主要的区别在于,变分图编码器的输出是通过从概率密度中采样生成的。我们可以通过其均值和方差来表征密度。因此,编码器的输出将是我们之前空间的每个维度的均值和方差。解码器随后将采样的潜在表示解码为与输入数据相似的内容。如图5.8所示,高层次的模型是,我们现在将之前的自编码器扩展为输出均值和方差,而不是来自潜在空间的点估计。这使得我们的模型能够从潜在空间中进行概率采样。

image.png

我们必须调整架构,并且改变我们的损失函数,以包括一个额外的项来正则化潜在空间。Listing 5.13 提供了VGAE的代码片段。Listing 5.4与Listing 5.13中的VariationalGCNEncoder层之间的相似之处包括:我们已经将潜在空间的维度加倍,并且现在在前向传递结束时返回编码器的均值和对数方差。

Listing 5.13 变分GCN编码器(VariationalGCNEncoder)

class VariationalGCNEncoder(torch.nn.Module):            #1
  def __init__(self, input_size, layers, latent_dims):
    super().__init__()
    self.layer0 = GCNConv(input_size, layers[0])
    self.layer1 = GCNConv(layers[0], layers[1])
    self.mu = GCNConv(layers[1], latent_dims)           
    self.logvar = GCNConv(layers[1], latent_dims)       

  def forward(self, x, edge_index):
    x = self.layer0(x, edge_index).relu()
    x = self.layer1(x, edge_index).relu()
    mu = self.mu(x, edge_index)
    logvar = self.logvar(x, edge_index)
    return mu, logvar                      #2
#1 增加了均值和对数方差变量以进行采样
#2 前向传递返回均值和对数方差变量

当我们讨论GAE时,我们了解到解码器使用内积来返回邻接矩阵或边列表。之前我们明确地实现了内积。然而,在PyG中,这个功能是内建的。为了构建一个VGAE结构,我们可以调用VGAE函数,如下所示。

Listing 5.14 变分图自编码器(VGAE)

from torch_geometric.nn import VGAE   #1
model = VGAE(VariationalGCNEncoder(input_size,\
 layers, latent_dims))
#1 使用PyG库中的VGAE函数构建自编码器

这个功能使得构建VGAE变得更加简单,PyG中的VGAE函数处理了重参数化技巧。现在我们有了VGAE模型,接下来我们需要做的是修改训练和测试函数,以包含KL散度损失。训练函数如下所示。

Listing 5.15 训练函数

def train(model, criterion, optimizer):
    model.train() 
    optimizer.zero_grad() 
    z = model.encode(train_data.x, train_data.edge_index)      #1

    neg_edge_index = negative_sampling( 
    edge_index=train_data.edge_index, num_nodes=train_data.num_nodes,
    num_neg_samples=train_data.edge_label_index.size(1), method='sparse')

    edge_label_index = torch.cat( 
    [train_data.edge_label_index, neg_edge_index], 
    dim=-1,) 
    out = model.decode(z, edge_label_index).view(-1)          

    edge_label = torch.cat([ 
    train_data.edge_label,
    train_data.edge_label.new_zeros(neg_edge_index.size(1))
    ], dim=0)

    loss = criterion(out, edge_label)            #2
+ (1 / train_data.num_nodes) * model.kl_loss()  

    loss.backward() 
    optimizer.step()

    return loss
#1 由于我们使用的是PyG的VGAE函数,我们需要使用encode和decode方法。
#2 添加了由KL散度给出的正则化项

这与我们在Listing 5.12中用于训练GAE模型的训练循环相同。唯一的区别是,我们在损失中加入了一个额外的项,用于最小化KL散度,并且我们将编码器和解码器方法调用改为encode和decode(这也需要在测试函数中进行更新)。除此之外,训练过程没有改变。需要注意的是,得益于PyG新增的功能,这些更改比我们之前在PyTorch中做的更简单。然而,了解这些额外步骤可以帮助我们更好地理解GAE的底层架构。

现在我们可以将VGAE应用到Amazon Products数据集,并用它来进行边预测,最终得到88%的整体测试准确率。这比我们的GAE准确率稍高。值得注意的是,VGAE不一定会提供更高的准确率。因此,使用这种架构时,您应该始终尝试GAE和VGAE,并进行细致的模型验证。

5.3.2 何时使用变分图自编码器

鉴于VGAE的准确性与GAE相似,重要的是要意识到两种方法的局限性。一般来说,当你想构建一个生成模型,或者当你希望利用数据的某个方面来学习另一个方面时,GAE和VGAE都是非常适合的模型。例如,我们可能想要构建一个基于图的姿势预测模型。我们可以使用GAE和VGAE架构,根据视频镜头预测未来的姿势。(我们将在后面的章节中看到类似的例子。)这样做时,我们使用GAE/VGAE学习一个身体的图,条件是每个身体部位的未来位置。然而,如果我们特别想生成新的数据,例如用于药物发现的新化学图,VGAE通常更好,因为其潜在空间更加结构化。

一般而言,GAE非常适合特定的重构任务,如链接预测或节点分类,而VGAE更适合那些任务需要更大或更多样化合成样本的情况,例如你想生成全新的子图或小图。相比GAE,VGAE通常更适用于基础数据集噪声较多的情况,而GAE则更适合图数据结构清晰的情况,且其计算速度较快。最后,值得注意的是,由于VGAE采用变分方法,它更不容易过拟合,因此可能会更好地泛化。正如总是一样,选择哪种架构取决于手头的问题。

在本章中,我们学习了两种生成模型的示例——GAE和VGAE模型,并了解了如何实现这些模型以处理图结构数据。为了更好地理解如何使用这一模型类,我们将模型应用于边预测任务。然而,这只是应用生成模型的第一步。

在许多需要生成模型的实例中,我们使用连续的自编码器层进一步降低系统的维度,并增强我们的重构能力。在药物发现和化学科学的背景下,GAE允许我们重构邻接矩阵(如我们在这里所做的),以及重构分子类型,甚至是分子的数量。GAE在许多科学和工业中被广泛使用。现在,您也有了尝试这些工具的能力。

在下一节中,我们将演示如何使用VGAE生成具有特定特性的新的图形,例如具有高属性值的新分子,这些分子可能是潜在的药物候选物。

5.4 通过GNN生成图形

到目前为止,我们已经考虑了如何使用图的生成模型来估算节点之间的边。然而,有时我们不仅仅想生成一个节点或一条边,而是整个图。这在尝试理解或预测图级别的数据时尤其重要。在这个例子中,我们将通过使用GAE和VGAE来生成新的潜在分子,这些分子具有某些特性。

图神经网络(GNN)对药物发现领域产生了巨大影响,特别是在新分子或潜在药物的识别上。2020年,使用GNN发现了一种新的抗生素,2021年也发布了一种新的方法,利用GNN识别食品中的致癌物。从那时起,许多其他论文开始使用GNN作为加速药物发现流程的工具。

5.4.1 分子图

我们将考虑先前已筛选用于药物的较小分子,如ZINC数据集中约250,000个单独分子的描述。该数据集中的每个分子包含以下附加数据:

  • 简化分子输入行条目系统(SMILES)——以ASCII格式描述分子结构或分子图。
  • 重要属性——合成可达性评分(SAS)、水-辛醇分配系数(logP),以及最重要的药物相似性量化估算(QED),它突出显示了该分子作为潜在药物的可能性。

为了使我们的GNN模型能够使用该数据集,我们需要将其转换为合适的图结构。在这里,我们将使用PyG来定义我们的模型并运行深度学习过程。因此,我们首先下载数据,然后使用NetworkX将数据集转换为图对象。我们在Listing 5.16中下载数据集,生成如下输出:

smiles     logP     qed     SAS
0     CC(C)(C)c1ccc2occ(CC(=O)Nc3ccccc3F)c2c1
     5.05060     0.702012     2.084095
1     C[C@@H]1CC(Nc2cncc(-c3nncn3C)c2)C[C@@H](C)C1
     3.11370     0.928975     3.432004
2     N#Cc1ccc(-c2ccc(O[C@@H](C(=O)N3CCCC3)c3ccccc3)...
     4.96778     0.599682     2.470633
3     CCOC(=O)[C@@H]1CCCN(C(=O)c2nc
      (-c3ccc(C)cc3)n3c...     
      4.00022     0.690944     2.822753
4     N#CC1=C(SCC(=O)Nc2cccc(Cl)c2)N=C([O-])
      [C@H](C#...     3.60956     0.789027     4.035182

Listing 5.16 创建分子图数据集

import requests
import pandas as pd

def download_file(url, filename):
     response = requests.get(url)
     response.raise_for_status() 
     with open(filename, 'wb') as f:
     f.write(response.content)

url = "https://raw.githubusercontent.com/
aspuru-guzikgroup/chemical_vae/master/models/
zinc_properties/250k_rndm_zinc_drugs_clean_3.csv"
filename = "250k_rndm_zinc_drugs_clean_3.csv"

download_file(url, filename)

df = pd.read_csv(filename)
df["smiles"] = df["smiles"].apply(lambda s: s.replace("\n", ""))

在Listing 5.17中,我们定义了一个函数,将SMILES转换为小图,我们随后用它来创建PyG数据集。我们还向数据集中的每个对象添加了一些附加信息,如可以用于进一步数据探索的重原子数。这里,我们使用了递归SMILES深度优先搜索(DFS)工具包(RDKit)包(<www.rdkit.org/docs/index.…

Listing 5.17 创建分子图数据集

from torch_geometric.data import Data
import torch
from rdkit import Chem

def smiles_to_graph(smiles, qed):
     mol = Chem.MolFromSmiles(smiles)
        if not mol:
             return None

        edges = []
        edge_features = []
        for bond in mol.GetBonds():
             edges.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])
             bond_type = bond.GetBondTypeAsDouble()
             bond_feature = [1 if i == bond_type\
             else 0 for i in range(4)]
             edge_features.append(bond_feature)

        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_features, dtype=torch.float)
        x = torch.tensor([atom.GetAtomicNum()\
 for atom in mol.GetAtoms()], \
 dtype=torch.float).view(-1, 1)

        num_heavy_atoms = mol.GetNumHeavyAtoms()

        return Data(x=x, edge_index=edge_index,\
 edge_attr=edge_attr, \
qed=torch.tensor([qed], \
dtype=torch.float), \
num_heavy_atoms=num_heavy_atoms)

数据集中的一个随机样本显示在图5.9中,突出了我们的分子图的多样性以及它们的小尺寸,每个分子图的节点和边都少于100个。

image.png

5.4.2 识别新的药物候选物

在图5.10中,我们开始看到QED如何随着不同的分子结构而变化。药物发现的主要障碍之一是分子组合的多样性,以及如何知道哪些组合需要合成并测试其药效。这一过程远远早于将药物引入人体、动物(体内)或有时甚至细胞(体外)试验的阶段。即使是评估分子如溶解度等属性,也可能成为挑战,尤其是当我们仅使用分子图时。在这里,我们将专注于预测分子的QED,以了解哪些分子最有可能作为药物使用。为了给出QED如何变化的例子,可以参见图5.10,其中包含四个具有高QED(~0.95)和低QED(~0.12)的分子。我们可以看到这些分子之间的定性差异,例如低QED的分子具有更多的强键。然而,从图形中直接估计QED是一个挑战。为了帮助我们完成这一任务,我们将使用GNN来生成并评估新的潜在药物。

image.png

我们的工作将基于两篇重要的论文,这些论文展示了生成模型如何成为识别新分子的有效工具(Gómez-Bombarelli等人[4]和De Cao等人[5])。具体来说,Gómez-Bombarelli等人表明,通过构建数据空间的平滑表示(即我们在本章前面描述的潜在空间),可以优化以找到具有特定属性的新候选分子。这项工作在很大程度上借鉴了Keras库中的一个等效实现,该实现由Victor Basu在一篇文章中概述[6]。图5.11重现了[5]中的基本思想。

image.png

在图5.11中,我们可以看到底层模型结构是一个自编码器,就像我们在本章中讨论的那些一样。在这里,我们将分子的SMILES作为输入传递给编码器,然后使用它构建不同分子表示的潜在空间。这通过不同颜色的区域表示不同的分子群体来展示。接着,解码器的设计是将潜在空间忠实地转换回原始分子。这类似于我们之前在图5.5中展示的自编码器结构。

除了潜在空间,我们现在还引入了一个附加功能,即预测分子属性。在图5.11中,我们预测的属性也是我们优化的目标属性。因此,通过学习如何将分子和属性(在我们的例子中是QED)编码到潜在空间中,我们可以优化药物发现过程,生成具有高QED的新候选分子。

在我们的例子中,我们将使用VGAE。该模型包括两个损失函数:一个重构损失,用于衡量传递给编码器的原始输入数据与解码器输出之间的差异;以及一个衡量潜在空间结构的损失,我们使用KL散度来实现。

除了这两个损失函数,我们还将添加一个额外的函数:属性预测损失。属性预测损失估算了在通过属性预测模型运行潜在表示后,预测属性与实际属性之间的均方误差(MSE),如图5.11中间所示。

为了训练我们的GNN,我们将之前在Listing 5.15中提供的训练循环进行适应,以包括这些个别的损失函数。该过程在Listing 5.18中展示。在这里,我们使用二元交叉熵(BCE)作为邻接矩阵的重构损失,而属性预测损失仅考虑QED,并可以基于MSE。

Listing 5.18 分子图生成的损失函数

def calculate_loss(self, pred_adj, \
   true_adj, qed_pred, qed_true, mu, logvar):
             adj_loss = F.binary_cross_entropy\
   (pred_adj, true_adj)   #1

             qed_loss = F.mse_loss\
(qed_pred.view(-1), qed_true.view(-1))     #2

             kl_loss = -0.5 * torch.mean\
(torch.sum(1 + logvar - mu.pow(2)\     #3
 - logvar.exp(), dim=1))

             return adj_loss + qed_loss + kl_loss
#1 重构损失
#2 属性预测损失
#3 KL散度损失

5.4.3 使用VGAE生成图形

现在我们有了训练数据和损失函数,我们可以开始思考模型。总体来说,这个模型将类似于本章之前讨论的GAE和VGAE模型。然而,我们需要对模型进行一些细微的修改,以确保它能够很好地应用于当前问题:

  • 使用异构GCN以考虑不同的边类型。
  • 训练解码器生成整个图形。
  • 引入属性预测层。

让我们逐一看看这些修改。

异构GCN

我们生成的小图将包含不同类型的边来连接图中的节点。具体来说,分子之间可能有不同数量的键,如单键、双键、三键,甚至芳香键,这些键与分子形成环状结构有关。具有多种边类型的图称为异构图,因此我们需要使我们的GNN适应异构图。

到目前为止,我们考虑的所有图都是同构的(仅有一种边类型)。在Listing 5.19中,我们展示了如何将第3章中讨论的GCN适配到异构图中。这里,我们明确列出了异构图的一些不同特征。然而,重要的是要注意,许多GNN库已经原生支持异构图的模型。例如,PyG有一个特定的模型类,称为HeteroConv。

Listing 5.19 异构GCN

from torch.nn import Parameter
from torch_geometric.nn import MessagePassing

class HeterogeneousGraphConv(MessagePassing):
    def __init__(self, in_channels, out_channels, num_relations, bias=True):
        super(HeterogeneousGraphConv, self).__init__(aggr='add')  #1
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations

        self.weight = Parameter(torch.Tensor(num_relations, in_channels, out_channels))  #2
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, x, edge_index, edge_type):
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x, edge_type=edge_type)  #3

    def message(self, x_j, edge_type, index, size):  #4
        W = self.weight[edge_type]  #5
        x_j = torch.matmul(x_j.unsqueeze(1), W).squeeze(1)
        return x_j

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out += self.bias
        return aggr_out
#1 "Add"聚合操作
#2 权重参数
#3 使用edge_type选择权重
#4 x_j形状为[E, in_channels],edge_type形状为[E]
#5 选择相应的权重

通过前面的GNN,我们可以将我们的编码器构建为这些单独的GNN层的组合。这在Listing 5.20中展示,我们遵循与定义边编码器(参见Listing 5.13)时相同的逻辑,只是现在我们将GCN层替换为异构GCN层。由于我们有不同的边类型,我们现在还必须指定不同类型(关系)的数量,并将特定的边类型传递给我们图编码器的前向函数。同样,我们返回对数方差和均值,以确保潜在空间是通过分布而非点样本构建的。

Listing 5.20 小图编码器

class VariationalGCEncoder(torch.nn.Module):
    def __init__(self, input_size, layers, latent_dims, num_relations):
        super().__init__()
        self.layer0 = HeterogeneousGraphConv(input_size, layers[0], num_relations)  #1
        self.layer1 = HeterogeneousGraphConv(layers[0], layers[1], num_relations)
        self.layer2 = HeterogeneousGraphConv(layers[1], latent_dims, num_relations)

    def forward(self, x, edge_index, edge_type):
        x = F.relu(self.layer0(x, edge_index, edge_type))  #2
        x = F.relu(self.layer1(x, edge_index, edge_type))
        mu = self.mu(x, edge_index)
        logvar = self.logvar(x, edge_index)
        return mu, logvar
#1 异构GCNs
#2 前向传递

图解码器

在之前的例子中,我们使用GAE生成并预测单一图中节点之间的边。然而,现在我们希望使用我们的自编码器生成整个图。因此,我们不再仅仅考虑内积解码器来表示图中边的存在,而是解码每个小分子图的邻接矩阵和特征矩阵。这在Listing 5.21中展示。

Listing 5.21 小图解码器

class GraphDecoder(nn.Module):
    def __init__(self, latent_dim, adjacency_shape, feature_shape):
        super(GraphDecoder, self).__init__()

        self.dense1 = nn.Linear(latent_dim, 128)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.1)

        self.dense2 = nn.Linear(128, 256)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.1)

        self.dense3 = nn.Linear(256, 512)
        self.relu3 = nn.ReLU()
        self.dropout3 = nn.Dropout(0.1)

        self.adjacency_output = nn.Linear(512, torch.prod(torch.tensor(adjacency_shape)).item())
        self.feature_output = nn.Linear(512, torch.prod(torch.tensor(feature_shape)).item())

    def forward(self, z):
        x = self.dropout1(self.relu1(self.dense1(z)))
        x = self.dropout2(self.relu2(self.dense2(x)))
        x = self.dropout3(self.relu3(self.dense3(x)))

        adj = self.adjacency_output(x)  #1
        adj = adj.view(-1, *self.adjacency_shape)
        adj = (adj + adj.transpose(-1, -2)) / 2  #2
        adj = F.softmax(adj, dim=-1)  #3

        features = self.feature_output(x)  #4
        features = features.view(-1, *self.feature_shape)
        features = F.softmax(features, dim=-1)  #5

        return adj, features
#1 生成邻接矩阵
#2 对邻接矩阵进行对称化
#3 应用softmax
#4 生成特征
#5 应用softmax

这段代码大部分是典型的解码器风格网络。我们从一个小型网络开始,该网络的维度与使用编码器创建的潜在空间相匹配。然后,我们通过后续的网络层逐步增大图的大小。这里,我们可以使用简单的线性网络,并且包括了用于性能优化的丢弃层。在最后一层,我们将解码器输出重新塑形为邻接矩阵和特征矩阵。我们还确保在应用softmax之前,对邻接矩阵进行对称化。通过将邻接矩阵与其转置相加并除以2,我们确保节点i连接到j,同时j也连接到i。接着,我们应用softmax来归一化邻接矩阵,确保每个节点的所有出边之和为1。在这里,我们也可以做其他选择,如使用最大值、应用阈值或使用sigmoid函数代替softmax。一般来说,平均+softmax是一个不错的选择。

属性预测层

最后,我们只需将编码器和解码器网络结合成一个最终模型,用于分子图生成,如Listing 5.22所示。总体而言,这遵循了与前面Listing 5.14中相同的步骤,我们定义了编码器和解码器,并使用了重参数化技巧。唯一的区别是,我们还包括了一个简单的线性网络,用于预测图的属性,这里是QED。该预测层作用于重参数化后的潜在表示(z)。

Listing 5.22 用于分子图生成的VGAE

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing

class VGAEWithPropertyPrediction(nn.Module):
    def __init__(self, encoder, decoder, latent_dim):
        super(VGAEWithPropertyPrediction, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.property_prediction_layer = nn.Linear(latent_dim, 1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(logvar / 2)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def forward(self, data):
        mu, logvar = self.encoder(data.x, data.edge_index, data.edge_attr)
        z = self.reparameterize(mu, logvar)
        adj_recon, x_recon = self.decoder(z)
        qed_pred = self.property_prediction_layer(z)
        return adj_recon, x_recon, qed_pred, mu, logvar, z

模型的输出包括均值和对数方差,这些被传递给KL散度;重构的邻接矩阵和特征矩阵,被传递给重构损失;以及预测的QED值,这些值用于预测损失。使用这些,我们可以计算网络的损失,并通过反向传播优化网络权重,从而调整生成的图,确保它们具有较高的QED值。接下来,我们将展示如何在训练和测试循环中实现这一目标。

5.4.4 使用GNN生成分子

在上一节中,我们讨论了使用GNN生成分子所需的所有单独部分。现在,我们将这些不同的元素结合起来,展示如何使用GNN创建优化特定属性的新图。在图5.12中,我们展示了生成具有高QED的分子的步骤。这些步骤包括创建合适的图形来表示小分子,将这些图形传递通过我们的自编码器,预测特定的分子特征(如QED),然后重复这些步骤,直到我们能够重新创建具有特定特征的新分子图。

image.png

剩下的关键部分是将我们的损失函数与我们调整过的VGAE模型结合起来。这在Listing 5.23中展示,其中定义了我们的训练循环。这与前面章节和示例中的训练循环类似。主要思想是,我们的模型用于预测图的某些属性。然而,在这里,我们预测的是整个图,具体定义为预测的邻接矩阵(pred_adj)和预测的特征矩阵(pred_feat)。

我们模型的输出和真实数据被传递到我们的损失计算方法中,该方法包含了重构损失、KL散度损失和属性预测损失。最后,我们计算梯度惩罚,作为我们模型的进一步正则化器(在第5.5节中有更详细的定义)。在计算了损失和梯度后,我们通过模型进行反向传播,更新优化器,并返回损失值。

Listing 5.23 分子图生成的训练函数

def train(model, optimizer, data, test=False):
    model.train()
    optimizer.zero_grad()

    pred_adj, pred_feat, pred_qed, mu, logvar, _ = model(data)

    real_adj = create_adjacency_matrix(data.edge_index, data.edge_attr, num_nodes=NUM_ATOMS)
    real_x = data.x
    real_qed = data.qed

    loss = calculate_loss(pred_adj[0], real_adj, pred_qed, real_qed, mu, logvar)   #1

    total_loss = loss

    if not test:
        total_loss.backward()
    optimizer.step()
    return total_loss.item()
#1 计算损失

在训练过程中,我们发现模型的损失逐渐减少,表明模型正在有效地学习如何重建新的分子。我们在图5.13中展示了一些生成的分子。

image.png

为了更好地理解我们潜在空间中预测的QED属性的分布,我们将编码器应用于一个新的数据子集,并查看数据在潜在空间中表示的前两个轴,如图5.14所示。在这里,我们可以看到潜在空间已经被构建成将具有更高QED的分子聚集在一起。因此,通过从该区域周围的区域进行采样,我们可以识别出新的分子进行测试。未来的工作需要验证我们的结果,但作为发现新分子的第一步,我们已经展示了GNN模型可能被用来提出新的、潜在有价值的药物候选物。

image.png

在本章中,我们专注于生成任务,而非传统的判别模型。我们展示了生成模型,如GAE和VGAE,如何用于边预测,学习在信息可能不可用的情况下识别节点之间的连接。接着,我们展示了生成性GNN不仅可以发现图中未知的部分,如节点或边,还可以生成完全新的、复杂的图,当我们将GNN应用于生成具有高QED的新小分子时。这些结果突显了GNN在化学、生命科学以及许多处理个体图的学科中是至关重要的工具。

此外,我们还了解到,GNN在判别任务和生成任务中都非常有用。这里,我们考虑的是小分子图的主题,但GNN也已经应用于知识图谱和小型社交群体。在下一章中,我们将探讨如何通过将生成性GNN与时间编码结合,学习生成在时间上保持一致的图形。在这个过程中,我们更进一步,学习如何教GNN“行走”。

5.5 引擎内部

深度生成模型使用人工神经网络来建模数据空间。深度生成模型的经典例子之一是自编码器。自编码器包含两个关键组件:编码器和解码器,二者均由神经网络表示。它们学习如何将数据编码(压缩)成低维表示,并将其解码(解压)回去。图5.15展示了一个基本的自编码器,它将图像作为输入并进行压缩(步骤1)。这将产生低维表示或潜在空间(步骤2)。然后,自编码器重建图像(步骤3),并重复此过程,直到输入图像(x)和输出图像(x*)之间的重构误差尽可能小为止。自编码器是GAE和VGAE背后的基本思想。

image.png

5.5.1 理解链接预测任务

链接预测是图学习中的一个常见问题,尤其是在我们对数据的知识不完全的情况下。这可能是因为图随时间变化,例如,当我们预期新客户将使用电子商务服务时,我们希望有一个模型能够在那个时刻提供最佳的产品推荐。另一方面,获取这些知识可能成本较高,例如,如果我们希望模型预测哪些药物组合会导致特定的疾病结果。最后,我们的数据可能包含错误信息或故意隐藏的细节,例如社交媒体平台上的假账户。链接预测使我们能够推断图中节点之间的关系。实质上,这意味着创建一个模型,预测节点何时以及如何连接,如图5.16所示。

image.png

对于链接预测,模型将以节点对作为输入,预测这些节点是否连接(是否应该被链接)。为了训练模型,我们还需要真实标签。我们通过隐藏图中的一部分链接来生成这些标签。这些隐藏的链接成为缺失的数据,我们需要学习推断它们,这些数据被称为负样本。然而,我们还需要一种方法来编码节点对之间的信息。通过使用图卷积自编码器(GAEs),这两个部分可以同时解决,因为自编码器既可以编码边的信息,也可以预测边是否存在。

5.5.2 内积解码器

内积解码器用于图数据,因为我们希望从特征数据的潜在表示中重建邻接矩阵。图卷积自编码器(GAE)学习如何在给定节点潜在表示的情况下重建图(推断边)。高维空间中的内积计算了两个位置之间的距离。我们使用内积,并通过sigmoid函数进行缩放,以获得节点之间是否存在边的概率。本质上,我们使用潜在空间中点之间的距离作为节点解码时是否会连接的概率。这使我们能够构建一个解码器,从潜在空间中采样并返回边是否存在的概率,即执行边预测,如图5.17所示。

image.png

内积解码器通过获取数据的潜在表示,并使用传递的边索引对数据应用内积来工作。然后,我们对这个值应用sigmoid函数,返回一个矩阵,其中每个值表示两个节点之间是否存在边的概率。

正则化潜在空间

简单来说,KL散度告诉我们如果我们在估计某个事件的概率时使用了错误的概率密度,结果会有多糟糕。假设我们有两个硬币,并想猜测一个已知是公平的硬币(我们知道它是公平的)与另一个硬币(我们不知道它是否公平)之间的匹配程度。我们试图用已知概率的硬币来预测未知概率硬币的概率。如果它是一个好的预测器(即未知硬币实际上是公平的),那么KL散度将为零。两个硬币的概率密度是相同的;然而,如果我们发现这个硬币是一个差的预测器,那么KL散度将会很大。这是因为这两个概率密度会相距很远。在图5.18中,我们可以明确看到这一点。我们正试图使用条件概率密度P(Z|X)来建模未知的概率密度Q(z)。当这些密度重叠时,KL散度将很低。

image.png

实际上,我们通过在损失中引入KL散度,将自编码器转换为变分图卷积自编码器(VGAE)。我们的目的是在最小化编码器和解码器之间的差异(如自编码器损失中所示)的同时,最小化编码器给出的概率分布与用于生成数据的“真实”分布之间的差异。这是通过将KL散度添加到损失中来完成的。对于许多标准的VGAE,这个损失函数为:

image.png

其中,(p||q) 表示概率p相对于概率q的散度。项m是潜在特征的均值,log(var)是方差的对数。当我们构建VGAE时,我们在损失函数中使用这个公式,确保前向传递同时将均值和方差返回给解码器。

过度压缩

我们已经讨论了如何通过消息传递传播节点和边的表示来获取节点信息。这些表示用于创建单个节点或边的嵌入,帮助引导模型执行一些特定任务。在本章中,我们讨论了如何构建一个模型,通过将消息传递层创建的所有嵌入传播到潜在空间中来构建潜在表示。这两者都执行了图数据的维度缩减和表示学习。

然而,图神经网络(GNN)在能使用的信息量方面有一个特定的限制。GNN遭遇了一种被称为“过度压缩”(over-squashing)的问题,这指的是信息在图中传播多个跳数(即消息传递)时,导致性能显著下降。这是因为每个节点从其邻域接收到的信息量,也称为其感受野,随着GNN层数的增加而呈指数级增长。当更多的信息通过这些层的消息传递进行聚合时,来自远距离节点的重要信号会被稀释,相比之下,来自较近节点的信息占主导地位。这使得节点表示变得更加相似或更为同质,最终会趋于相同的表示,这也被称为过度平滑(over-smoothing),我们在第4章中讨论过这一点。

实证研究表明,这种情况可能在仅有三到四层时就开始发生[7],如图5.19所示。这突显了GNN与其他深度学习架构之间的一个关键区别:我们很少想要构建一个具有许多层堆叠的深度模型。对于具有许多层的模型,通常还会引入其他方法以确保包括远程信息,例如跳跃连接或注意力机制。

image.png

在本章之前的示例中,我们讨论了使用图神经网络(GNN)进行药物发现。在这个例子中,我们考虑的是相对较小的图。然而,当图变得更大时,长程交互变得越来越重要。这在化学和生物学中尤其如此,其中图的极端节点可能对图的整体属性产生过大的影响。在化学的背景下,这可能是两个原子,它们是一个大分子的两端,决定了分子的整体属性,如其毒性。为了有效地建模问题,我们需要考虑的交互范围或信息流被称为问题半径。在设计GNN时,我们需要确保层数至少与问题半径一样大。

通常,有几种方法可以解决GNN的过度压缩问题:

  1. 确保不要堆叠太多的层。
  2. 在节点之间添加新的“虚拟”边,特别是那些相距很远或跳数很多的节点,或者引入一个与所有其他节点相连的单一节点,从而将问题半径缩小到2。
  3. 使用采样方法,如GraphSAGE,它从邻域中采样或引入跳跃连接,类似地跳过一些局部邻居。对于采样方法,平衡局部信息的丧失与长程信息的获取是很重要的。

所有这些方法都是高度问题特定的,在决定长程交互是否重要时,你应该仔细考虑图中节点之间的交互类型。例如,在下章中,我们讨论的是运动预测,其中头部与脚部之间的影响可能远小于头部与膝盖之间的影响。另一方面,本章描述的分子图可能会受到更远节点的较大影响。因此,解决过度压缩等问题的最重要部分是确保你对自己的问题和数据有扎实的理解。

总结

  • 判别模型学习如何区分数据类别,而生成模型则学习如何建模整个数据空间。
  • 生成模型通常用于执行维度降维。主成分分析(PCA)是一种线性降维方法。
  • 自编码器包含两个关键组件,即编码器和解码器,二者均由神经网络表示。自编码器学习如何将数据编码(压缩)为低维表示,并将其解码(解压缩)回原始数据。对于自编码器,低维表示称为潜在空间。
  • 变分自编码器(VAE)将自编码器扩展为在损失中加入正则化项。这个正则化项通常是Kullback-Leibler(KL)散度,用于衡量两个分布之间的差异——即学习到的潜在分布和先验分布。VAE的潜在空间更加结构化和连续,其中每个点代表一个概率密度,而不是一个固定的编码点。
  • 自编码器和变分自编码器(VAE)也可以应用于图。分别为图自编码器(GAE)和变分图自编码器(VGAE)。它们与典型的自编码器和VAE类似,但解码器通常是应用于边列表的内积运算。
  • GAE和VGAE对于边预测任务非常有用。它们可以帮助我们预测图中可能隐藏的边。