经典论文阅读-17:GraphSAGE

191 阅读14分钟

大图上的归纳表示学习

论文地址: [PDF] Inductive Representation Learning on Large Graphs | Semantic Scholar

摘要

大图中节点的低维嵌入已被证明在许多预测任务上非常有用,从内容推荐到识别蛋白质功能。然而,大多数已有方法需要在训练嵌入时输入图中所有已有节点,这些方法本质上属于直推式学习,不能泛化到未知节点。本文提出 GraphSAGE 模型,一种通用归纳式利用节点特征信息的框架,能有效生成未知节点嵌入。不再是对每个节点训练单独的嵌入,我们的模型基于节点局部邻居的采样与聚合来生成嵌入。我们算法性能优于效果最好的基准模型,这些基准模型在三个归纳式节点分类数据集上表现优异:我们在引文与 Reddit 文章数据集上做未知节点分类实验,并在蛋白质交互网络上展示算法的泛化能力。

简介

大图中节点低维嵌入已被证明在各种预测与图分析认为中非常有用。节点嵌入方法的基本思想是使用降维技术浓缩关于节点图邻居的高维信息到密度嵌入向量中。这些节点向量随后输入给下游机器学习系统,辅助完成如节点分类、链路预测、社区挖掘等任务。

然而,之前的研究只关注在单个静态图上嵌入节点,而许多真实世界中的应用需要快速嵌入未知节点或整个新的子图部分。这种归纳能力对高吞吐、产品级机器学习系统很关键,这些系统处理动态图,经常遇到未知节点。这种归纳式节点嵌入也有助于提升同一种特征的跨图泛化性。例如,人们可以在某个组织的蛋白质交互网络上训练嵌入生成器,然后将其用在新组织上生成节点嵌入。

与直推式相比,解决归纳式节点嵌入问题非常困难,因为泛化未知节点需要将新观测到的子图对齐到已训练好的节点嵌入空间。归纳式框架必须学着识别节点邻居的结构属性,既要考虑节点的局部角色,又要考虑其全局位置。

大多数已有生成节点嵌入的方法本质上是直推式的。这些方法的大多数直接通过基于矩阵因式分解的目标函数来优化节点嵌入,因为它们处理的是单个静态图。这些方法可被修改为归纳式处理,但修改后的计算复杂度通常会很高,需要更多的梯度下降轮次来优化模型。最近也有通过卷积运算学习图结构的方法,提供一种可靠的嵌入方法论。目前为止,图卷积网络 GCNs 只被用在静态图的直推式嵌入学习。本文我们将 GCNs 扩展到归纳式无监督学习任务上,并提出一种使用可训练聚合函数来泛化 GCN 的框架。

pic1.png

本文贡献是,我们提出一种通用框架,称为 GraphSAGE(SAmple and aggreGatE),做归纳式节点嵌入。不像那些基于矩阵分解到嵌入方法,我们利用节点特征来训练可泛化到未知节点的嵌入函数。通过在学习算法中结合节点特征,我们同时学习每个节点邻居的拓扑结构与节点特征在邻居中的分布。虽然我们重点关注特征丰富的图,我们的方法也可用于只使用结构特征的所有图,即这些图没有节点特征。

不再是给每个节点训练独立嵌入向量,我们训练一组聚合函数来学习节点局部邻居的聚合特征。每个聚合函数从给定节点的不同跳数(搜索深度)的邻居中聚合信息。在测试时或推理时,我们在训练好的系统中使用聚合函数对位置节点做处理,生成嵌入向量。遵循之前的生成节点嵌入的研究,我们设计一种无监督损失函数使 GraphSAGE 能在无需特定任务监督信息的情况下完成训练。我们也展示了 GraphSAGE 以全监督形式训练的场景。

我们在三个节点分类基准上评估算法,测试 GraphSAGE 能生成有用的未知节点嵌入。我们在测试时使用了两个基于引文网络和 Reddit 文章网络的动态文档网络,和一个用于多图泛化实验的蛋白质交互网络。通过使用这些基准数据集,我们展示了我们的模型可以高效生成未知节点表示,且性能显著优于相关基准模型。在各种领域里,我们的监督方法与只使用节点特征的方法相比,平均提升 51% 的分类 F1 值。GraphSAGE 性能与最强的直推式基准模型一致,但基准模型在处理未知节点时要多耗费 100 倍时间。我们也展示了新型聚合架构的性能提升。最后,我们通过理论分析探索该模型的表达能力,GraphSAGE 能够学习节点在图中的结构性角色信息,尽管其本质上是基于特征的。

相关研究

我们的算法在概念上涉及之前的节点嵌入方法、通用有监督图学习方法、最近新的图结构卷积神经网络方法。

基于分解的嵌入方法

有大量的节点嵌入方法使用随机游走和矩阵分解目标函数来学习低维嵌入。这些方法与更经典的谱聚类、多维缩放、PageRank 算法关系很近。因为这些方法直接对单个节点训练嵌入,其本质上是直推式的,至少需要耗时的额外训练才能处理未知节点。此外,许多这些方法的目标函数具有嵌入垂直变换的不变性,意味着嵌入空间不能天然泛化到不同的图上,需要通过重新训练来迁移。这种方法里的一个显著例外是 Yang 等人提出的 Planetoid-I 算法,是一种归纳式半监督学习的嵌入方法。然而,该模型在推理时不使用图结构信息,只在训练时将图结构信息用于正则化。不像这些方法,我们的方法使用节点特征信息训练模型,可产生未知节点的嵌入。

图上监督学习

除了节点嵌入方法,还有许多关于图结构监督学习的文献,包括大量核方法,从各种图核里衍生出图的特征向量。也有许多关于图结构监督学习的神经网络方法。我们的方法从概念上受这些方法的启发。然而,这些方法尝试对整个图做分类,而本文关注生成单个节点的可用表示。

图卷积网络

近几年,一些图卷积网络架构被提出,然而这些方法中的大多数不能扩展到大规模图上,或只是设计给全图分类使用。而我们的方法与图卷积紧密相关。原始图卷积被设计来做直推式半监督学习,算法在训练时需要全图拉普拉斯矩阵。我们方法的简单变体可视为 GCN 框架做归纳式学习的扩展。

GraphSAGE

我们方法背后的关键思想是,学习如何聚合节点局部邻居的特征信息。首先介绍 GraphSAGE 嵌入生成算法,假设模型参数已完成训练,描述节点嵌入生成过程。然后介绍如何使用标准随机梯度下降和反向传播训练 GraphSAGE 模型参数。

嵌入生成算法

本节介绍嵌入生成或前向传播过程,假设模型参数已完成训练并固定不变。特别地,假设已经学习到 K 个聚合器的参数,这些聚合器能从节点局部邻居中聚合信息,带有参数 WkW^k ,用于在模型不同层之间传递信息。

pic2.png

上述算法描述了嵌入生成过程,其中图 G=(V,E)G=(V,E) 中作为模型输入的节点特征为 xv,vVx_v,\forall v\in V 。以下表述如何在最小批中生成嵌入表示。最外层循环的每一步处理如下,k 表示当前步数,hkh^k 表示当前步中的节点表示:首先,每个节点 vVv\in V 聚合其之间邻居的节点表示 huk1,uN(v)h_u^{k-1},\forall u \in N(v) ,得到向量 hN(v)kh_{N(v)}^k 。注意该聚合步骤取决于上一轮迭代后的节点表示,k = 0 时节点表示为输入的节点特征 xx。聚合邻居特征向量后,GraphSAGE 将其连接到节点当前表示 hvk1h_v^{k-1} 之后,然后将连接结果输入给带激活函数 σ\sigma 的全连接层,得到下一层使用的节点表示。为了标记方便,我们用 K 表示输出深度,即最后一层的节点输出 zv=hvKz_v=h_v^K 。对邻居表示的聚会可使用下文讨论的各种聚合器架构。

为了将上述算法扩展到最小批过程,给定一组输入节点,首先前向采样需要的邻居集合(深度为 K),然后运行内部循环时不再是迭代所有节点,只计算这组输入中满足递归要求的节点表示。

与 WL 同构性测试的关系

GraphSAGE 算法概念上受经典图同构性测试算法影响。如果在算法中设置 K=|V|,权重矩阵为单位矩阵,使用适当哈希函数作为聚合函数,那么算法就是某种 WL 同构性测试的实例,也称为原生节点提炼 naive vertex refinement。如果算法对两个子图的输出表示 {zv,vV}\{z_v,\forall v\in V\} 相等,那么可以说这两个子图经 WL 测试后判定为同构图。该测试在有些情况下会失败,但适用于一大类图。GraphSAGE 是 WL 测试的持续估计,使用训练好的神经网络聚合器替换哈希函数。当然,我们使用 GraphSAGE 是为了生成节点的有用表示,而不是测试同构性。

邻居定义

本文中,我们统一采样固定数量的邻居,而不是用全部邻居,以保持每批计算量一致。换句话说,使用重载的标记 N(v) N(v) 表示从集合 {uV:(u,v)E}\{u\in V:(u,v)\in E\} 中均匀采样固定数量的点,而且在不同迭代轮次 k 中使用不同的均匀采样。不经过此采样处理的话,每批数据在训练时的内存与期望时间是不可预测的,最差情况为 O(V)O(|V|) 。相比之下,使用采样后每批训练的空间和时间复杂度固定为 O(i=1KSi)O(\prod_{i=1}^KS_i) ,S 为用户指定的采样数。通过实践,我们发现模型能达到最佳效果的参数为 K=2 且 S1S2500S_1\cdot S_2\le 500

学习 GraphSAGE 参数

为了在完全无监督情况下学习有用可预测的表示,我们使用基于图的损失函数来输出表示zu,uVz_u,\forall u \in V,然后使用梯度下降法微调权重矩阵 Wk,k{1,..,K}W^k,\forall k \in \{1,..,K\} 和聚合函数的参数。基于图的损失函数鼓励相邻节点具有相似表示,强制提升分离节点的表示距离。

JG(zu)=log(σ(zuTzv))QEvnPn(v)log(σ(zuTzvn))J_G(z_u)=-\log(\sigma(z_u^Tz_v))-Q\cdot\mathbb E_{v_n\sim P_n(v)}\log(\sigma(-z_u^Tz_{v_n}))

其中 vv 表示与节点 uu 在固定随机游走中临近的节点,σ\sigma 为 sigmoid 函数,PnP_n 为负采样分布,QQ 为负采样样本数量。重要的是,不像以前的嵌入方法,我们输入给损失函数的节点表示 zuz_u 来自包含节点局部邻居的特征,而不是给每个节点训练独立的嵌入。

该无监督方法模拟了将节点特征提供给下游机器学习任务的情况,在特定任务中,可以将无监督损失函数替换为任务特定的目标函数。

聚合架构

不像在 N 维格子上的机器学习,如句子、图片、3D 形状,节点的邻居天然无序,因此聚合函数必须能够处理无序的向量集合。理想情况下,聚合函数具有对称性,同时具有可训练性和维持高维表示能力。聚合函数的对称性保证了神经网络模型可用于任何顺序的节点邻居特征集合。我们测试过以下三种聚合函数。

平均聚合

首先是平均聚合,简单求取各邻居向量平均值。平均聚合基本等价于 GCN 中的卷积传播规则。特别地,我们可以替换算法中的聚合过程,衍生一种 GCN 方法的变体。

hvkσ(WMEAN({hvk1}{huk1,uN(v)}))h_v^k \gets \sigma (W \cdot MEAN(\{h_v^{k-1}\} \cup \{h_u^{k-1},\forall u \in N(v)\}))

我们可称其为基于平均聚合的卷积,因为它是一种粗化、线性估计的局部谱卷积。这种卷积聚合和我们方法中的聚合器的重要区别是,卷积聚合不做各层表示的连接操作。GraphSAGE 算法将聚合后的邻居向量与当前节点表示连接在一起,可视为不同搜索深度之间的某种近路连接,通常会带来显著性能提升。

LSTM 聚合

我们也试过更复杂的 LSTM 架构的聚合器。与平均聚合相比,LSTM 聚合有更大的表示能力。然而,要注意到 LSTM 本质上不具有对称性,其处理有序形式的输入。我们使用 LSTM 处理无序集合时要对节点邻居做随机打乱。

池化聚合

我们最后检验的聚合器为池化聚合,既有对称性也可训练。在池化处理中,每个邻居向量独立输入到全连接神经网络,通过转换后,使用元素级最大池化运算做邻居向量集合的聚合。

AGGRAGATEkpool=max({σ(Wpoolhuik+b),uiN(v)})AGGRAGATE_k^{pool}=max(\{\sigma(W_{pool}h_{u_i}^k+b), \forall u_i \in N(v)\})

原理上,在最大池化之前可以使用任意的深度多层感知器,但本文只用单层结构。该方法灵感来自神经网络学习通用点集的相关研究。直觉地讲,多层感知器可被认为是计算邻居集合中节点表示特征的一组函数。通过使用最大池化,模型可高效捕获邻居集合的不同方面特征。要注意,在原理上任何对称向量函数都可用来替换 max 操作。我们通过实验发现最大池化与平均池化没有显著差别,因此在后续实验中使用最大池化。

实验

理论分析

本节,我们探索 GraphSAGE 的表达能力,了解模型如何在基于特征的情况下学习网络结构。通过案例,我们考虑 GraphSAGE 能否学习预测节点的聚类系数,即节点一跳关系内的封闭三角形数量。我们的算法能在任意度上估计聚类系数。

pic3.png

上述定理说明,如果节点的特征具有区分度,那么存在一组 GraphSAGE 的参数,能使模型估计给定节点的聚类系数。该定理同时表示了,GraphSAGE 可以学习局部图结构信息。

总结

我们提出一种新的节点嵌入方法,能高效生成未知节点的表示。GraphSAGE 的性能与最优模型一致,但运行效率要好得多。通过理论分析展示了模型如何学习局部图结构。该模型还有许多扩展和潜在提升,可将 GraphSAGE 模型扩展到结合有向与多模态图。一个非常有趣的研究方向是探索非均匀邻居采样方法,甚至在模型中学习采样函数。