Motif-based Graph Self-Supervised Learning for Molecular Property Prediction论文笔记

97 阅读3分钟

Motif-based Graph Self-Supervised Learning for Molecular Property Prediction

image-20241017155543696

任务:分子性质的预测

会议:NIPS 2021

Motivation

  • GNNs监督训练框架的问题:数据标签的获取代价高,且在小批量数据上训练容易过拟合,不能泛化到其他数据
  • GNNs自监督预训练框架的问题:只关注节点级或图级任务,这些方法不能捕获子图或模体中的丰富信息。例如,官能团(分子图中经常出现的子图)通常携带有关分子性质的指示性信息。

Contribution

提出了一种基于模体的图自监督学习框架,分为三个部分:

  • 首先,为了从分子图中提取模体,论文设计了一种基于BRICS算法的分子分解方法,并通过额外的规则控制模体词汇表的大小
  • 设计了一个通用的、基于模体的生成式预训练框架(Motif-based Graph Self-supervised Learning,MGSSL),在该框架中,GNNs需要进行拓扑和标签预测(类似GPT)。该生成框架可以通过广度优先或深度优先两种不同的方式实现。
  • 为了考虑分子图中的多尺度信息,论文引入了一种多级自监督预训练方法

现有的自监督预训练方法可以分为两类:

  • 对比方法:拉近同一个图不同视图的表示,而不同图的视图的表示则推远
  • 预测方法:利用数据的内在属性来构建预测任务。例如预测原子的上下文(context);或者屏蔽节点/边的属性,让模型预测;或者进行图的重构(生成式)

然而,现有的这些自监督学习方法都未能利用图中模体丰富的语义信息(论文的方法属于预测方法这一个大类,只是加入了模体树的拓扑和标签预测)


化学启发的分子分解

给定一个分子数据集,该方法的第一步是将分子分解为若干模体。基于这些模体,一个分子图可以被转换为模体树结构,其中每个节点代表一个模体,边表示基元之间的空间关系。我们选择用树结构表示模体之间的关系,因为树结构有利于后续的模体生成任务

形式化地,给定一个分子图 G=(V,E)G = (V, E),模体树 T(G)=(V,E,X)\mathcal{T}(G)=(\mathcal{V}^′,\mathcal{E}^′,\mathcal{X}^′) 是一个连通的标记树,其中的V\mathcal{V}^′表示模体集,边集 E\mathcal{E}^′表示模体之间的连接, X\mathcal{X}^′是诱导的模体词汇表。每个模体MiM_iGG 的一个子图。分解一个图的方式有很多种,但设计分子分解方法需要达到以下目标:

  1. 在模体树 T(G)\mathcal{T}(G) 中,所有模体的并集应等于原分子图 GG的全部节点和边;
  2. 各模体之间不应有交集,即不同模体之间的节点和边应互不重叠;
  3. 所生成的模体应具有语义意义,类似于化学中的功能基团;
  4. 模体的出现频率应该足够高,以便自监督学习中,GNN 能够从中学习到可以泛化到下游任务的语义信息。

分子分解的总体流程,包括三步:

  1. 使用 BRICS(Breaking of Retrosynthetically Interesting Chemical Substructures)算法进行初步分解。BRICS 使用化学反应中的规则打断分子中具有策略意义的键,并在分解位点的末端添加“虚拟原子”以标记片段间的连接位置;
  2. 进一步分解 BRICS 的输出片段,以减少冗余;
  3. 构建模体树。

树是无环的连通图,由于分子图本身是连通的,而其中的环(比如苯环、呋喃)都可被看作是一个模体,所以可以确保分子图经过模体的转换后是一个树

image-20241023090644733

尽管 BRICS 是一种有效的分子分解方法,但它在生成大分子片段时存在一定的局限性,尤其是由于分子结构的组合爆炸,生成了大量结构相似但有细微变化的基元(如具有不同卤素原子的呋喃环)。为了解决这些问题,作者在 BRICS 之后引入了两条额外的分解规则:

  1. 当一个键的两端分别连接环状和非环状原子时,将其断开;
  2. 将邻接原子数超过 3 的非环状原子视为新模体,并断开它们的相邻键。

这些额外规则有效地减少了模体词汇的规模,并提高了模体在整个数据集中的出现频率。最终,通过该分子分解方法,可以建立一个适中的模体词汇集,为后续的 GNN 预训练提供支持。

论文的分子分解方法依赖于BRICS这个化学领域的算法做初步分解,之后引入两条额外的规则做进一步分解以减小模体词汇集的大小。但问题是这里的BRICS并不能推广到其他领域的图数据集,比如图相似度计算中常用的IMDB(演员关系网络)和Linux(内核程序调用图)


模体生成

论文的整体框架是生成式的自监督预训练,预训练的过程就是【给定已生成的模体树】,【预测下一个模体和边】,将生成训练数据中相应模体树的概率最大化,从而让网络学习到图模体的数据分布,使其经过微调后可以推广到下游任务。有点类似于GPT(Generative Pre-trained Transformer)的训练过程

GPT是在大量的无标注文本数据上进行训练,任务是给定一个序列的前半部分,模型被要求生成后续的单词

给定一个分子图 G=(V,E)G = (V, E) 和一个GNN模型 fθf_{\theta},首先将分子图转换为模体树 T(G)=(V,E,X)\mathcal{T}(G) = (\mathcal{V}, \mathcal{E}, \mathcal{X})。然后可以使用GNN模型对该模体树的可能性进行建模,表示为 p(T(G);θ)p(\mathcal{T}(G); \theta),表示模体是如何被标记和连接的。一般来说,我们的方法旨在通过最大化模体树的可能性来预训练GNN模型,即:

θ=argmaxθp(T(G);θ)\theta^* = \arg \max_{\theta} p(\mathcal{T}(G); \theta)

为了建模模体树的可能性,设计了特殊的拓扑和模体标签预测头,并与fθf_{\theta}一起优化。预训练后,只有GNN模型 fθf_{\theta}被转移到下游任务。

大多数现有的图生成工作都遵循自回归方式来分解概率目标,即本文中的 p(T(G);θ)p(\mathcal{T}(G); \theta)。对于每个分子图,它们将其分解为一系列生成步骤。类似地,论文交替地添加一个新模体,以及连接该模体与现有部分模体树的边。使用排列向量 π\pi 来确定模体的顺序,其中 iπi^{\pi} 表示排列 π\pi 中第 ii 个位置的模体ID。因此,概率 p(T(G);θ)p(\mathcal{T}(G); \theta) 相当于所有可能排列的期望可能性,即:

p(T(G);θ)=Eπ[pθ(Vπ,Eπ)]p(\mathcal{T}(G); \theta) = E_{\pi} [p_{\theta}(\mathcal{V}^{\pi}, \mathcal{E}^{\pi})]

其中,Vπ\mathcal{V}^{\pi}表示经过排列的模体标签,Eπ\mathcal{E}^{\pi}表示模体之间的边。

论文的形式化允许多种顺序。为简化起见,假设任何模体顺序 π\pi 的概率相等,并且在接下来的部分中,我们在说明生成过程时忽略了脚注 π\pi。给定一个排列顺序,生成模体树 T(G)\mathcal{T}(G) 的概率可以分解为:

logpθ(V,E)=i=1Vlogpθ(Vi,EiV<i,E<i)\log p_{\theta}(\mathcal{V}, \mathcal{E}) = \sum_{i=1}^{|\mathcal{V}|} \log p_{\theta}(\mathcal{V}_i, \mathcal{E}_i | \mathcal{V}_{<i}, \mathcal{E}_{<i})

在每一步 ii,我们使用之前生成的所有模体 V<i\mathcal{V}_{<i} 和它们的结构 E<i\mathcal{E}_{<i} 来生成新的模体 Vi\mathcal{V}_i 及其与现有模体的连接 Ei\mathcal{E}_i

接下来的问题是如何选择高效的生成顺序,以及如何建模条件概率 logpθ(Vi,EiV<i,E<i)\log p_{\theta}(\mathcal{V}_i, \mathcal{E}_i | \mathcal{V}_{<i}, \mathcal{E}_{<i})

生成顺序

论文提出了广度优先搜索(BFS)和深度优先搜索(DFS)两种顺序。

要从头生成一个模体树,首先需要选择模体树的根。在我们的实验中,我们简单地选择在规范顺序中第一个原子所在的模体作为根节点。然后,MGSSL以DFS或BFS顺序生成模体。

DFS

在DFS顺序中,对于每个访问到的模体,MGSSL首先进行拓扑预测:该节点是否有子节点需要生成。如果生成了新的子模体节点,我们预测其标签并递归这一过程。当没有更多的子节点需要生成时,MGSSL会回溯

BFS

对于BFS顺序,MGSSL按层次生成模体节点。对于第 kk 层的模体节点,MGSSL进行拓扑预测和标签预测。如果第 kk 层的所有子节点都已生成,MGSSL将移动到下一层。

我们注意到,在BFS和DFS中,模体节点的顺序并不唯一,因为同级节点之间的顺序是模糊的。在实验中,我们在一种顺序下进行预训练,并将此问题及其他可能的生成顺序留待未来研究。

image-20241023162420598

在每个时间步,模体节点从其他已生成的模体接收信息以进行(拓扑和标签)预测。当模体树逐步构建时,信息通过消息向量 hi,j\mathrm{h}_{i,j}传播。令 E^t\hat{\mathcal{E}}_t 是时间 tt 的消息集,模型在时间 tt 访问模体 iixix_i 表示模体 ii 的嵌入,它可以通过对模体 ii 中原子的嵌入进行池化获得。消息 hi,j\mathrm{h}_{i,j} 通过先前的消息进行更新:

hi,j=GRU(xi,{hk,i}(k,i)E^t,kj)\mathrm{h}_{i,j} = \text{GRU}(x_i, \{\mathrm{h}_{k,i}\}_{(k,i) \in \hat{\mathcal{E}}_t, k \neq j})

其中,GRU是门控循环单元,应用于模体树的消息传递。

GRU的输入有两个部分,即当前时间步的输入和上一个时间步的隐藏状态;输出是当前时间步的隐藏状态。在MGSSL中:

当前时间步的输入是模体 ii 的嵌入 xix_i,上一时间步的隐藏状态是之前所有传递给模体 ii 的消息向量求和:

si,j=(k,i)E^t,kj hk,is_{i, j} =\sum_{(k, i) \in \hat{\mathcal{E}}_{t}, k \neq j} \mathrm{~h}_{k, i}

更新门控制了从上一时间步的隐藏状态中保留多少信息到当前时间步:

zi,j=σ(Wzxi+Uzsi,j+bz)z_{i, j} =\sigma\left(\mathrm{W}^{z} x_{i}+\mathrm{U}^{z} s_{i, j}+b^{z}\right)

重置门控制了将多少上一步的隐藏状态信息引入到新的候选隐藏状态的计算中:

rk,i=σ(Wrxi+Ur hk,i+br)r_{k, i} =\sigma\left(\mathrm{W}^{r} x_{i}+\mathrm{U}^{r} \mathrm{~h}_{k, i}+b^{r}\right)

在更新当前时间步的隐藏状态时,GRU首先会计算候选的隐藏状态。候选隐藏状态结合了当前时间步的输入和上一时间步隐藏状态,并由重置门控制

h~i,j=tanh(Wxi+Uk=N(i)\jrk,ihk,i)\tilde{\mathrm{h}}_{i, j} =\tanh \left(\mathrm{W} x_{i}+U \sum_{k=\mathcal{N}(i) \backslash j} r_{k, i} \odot \mathrm{h}_{k, i}\right)

最终的隐藏状态基于更新门对候选隐藏状态和上一时间步隐藏状态的加权和:

hi,j=(1zij)sij+zij h~i,j\mathrm{h}_{i, j} =\left(1-z_{i j}\right) \odot s_{i j}+z_{i j} \odot \tilde{\mathrm{~h}}_{i, j}
拓扑预测

当MGSSL访问模体 ii 时,它需要对该节点是否有子节点生成进行二元预测。通过一个包含隐藏层的网络和sigmoid函数来计算概率,该网络考虑了传递给该模体的消息和模体嵌入:

pt=σ(Udτ(W1dxi+W2d(k,i)E^thk,i))p_t = \sigma \left( U^d \cdot \tau(W_1^d x_i + W_2^d \sum_{(k,i) \in \hat{\mathcal{E}}_t} \mathrm{h}_{k,i}) \right)
模体标签预测

当拓扑预测模体 ii 具有子节点 jj 时,我们通过以下公式预测子节点 jj 的模体标签:

qj=softmax(Ulτ(Wlhi,j))q_j = \text{softmax}(U^l \tau(W^l \mathrm{h}_{i,j}))

其中,qjq_j 是对模体词汇集X\mathcal{X}的概率分布,ll 是隐藏层维度。设 p^t{0,1}\hat{p}_t \in \{0,1\}q^j\hat{q}_j 分别为拓扑预测和模体标签预测的真实值,模体生成的损失是拓扑和模体标签预测的交叉熵损失之和:

Lmotif=tLtopo(pt,p^t)+jLpred(qj,q^j)\mathcal{L}_{\text{motif}} = \sum_t \mathcal{L}_{\text{topo}}(p_t, \hat{p}_t) + \sum_j \mathcal{L}_{\text{pred}}(q_j, \hat{q}_j)

在优化过程中,最小化上述损失函数相当于最大化上面的条件概率 。在训练过程中,每个步骤后,预测的拓扑和模体标签会被替换为其真实值,以确保MGSSL基于正确的历史信息进行预测。


多层次自监督预训练

为了捕获分子中的多尺度信息,MGSSL 被设计为一个包含原子级任务和模体级任务的分层框架。对于原子级预训练,我们利用属性掩码(attribute masking)来让GNNs首先学习节点/边属性的规律性。在属性掩码中,随机采样的节点和键的属性(如原子序号、键类型)被替换为特殊的掩码标记。接下来,我们应用GNNs获取相应的节点/边嵌入(边的嵌入可以通过该边端点节点嵌入的组合得到)。最后,在嵌入的基础上通过一个全连接层来预测节点/边的属性。交叉熵预测损失分别记为 Latom\mathcal{L}_{atom}Lbond\mathcal{L}_{bond}

image-20241023202913700

为了避免在顺序预训练中的灾难性遗忘,我们统一了多层次任务,并在预训练过程中最小化混合损失:

Lssl=λ1Lmotif+λ2Latom+λ3Lbond\mathcal{L}_{ssl} = \lambda_1 \mathcal{L}_{motif} + \lambda_2 \mathcal{L}_{atom} + \lambda_3 \mathcal{L}_{bond}

其中,λi\lambda_i 是损失的权重。

以DFS为例:

image-20241023212231658