本章内容包括:
- 理解注意力机制以及它是如何应用于图注意力网络(GAT)的
- 了解何时在PyTorch Geometric中使用GAT和GATv2层
- 通过NeighborLoader类使用小批量处理
- 在垃圾邮件检测问题中实现和应用图注意力网络层
在本章中,我们通过研究一种特殊的卷积图神经网络(卷积GNN)架构变体——图注意力网络(GAT),扩展了我们对卷积图神经网络的讨论。虽然这些GNN在前一章中介绍了卷积操作,但它们通过引入注意力机制扩展了这个思想,以便在学习过程中突出重要的节点。与传统的卷积GNN不同,传统的卷积GNN对所有节点给予相同的权重,而注意力机制使得GAT能够学习在训练过程中需要特别关注的方面。
与卷积一样,注意力机制在深度学习中被广泛使用,尤其是在GNN之外的应用。依赖于注意力机制的架构(特别是变换器)在解决自然语言处理问题方面取得了巨大的成功,现已主导了该领域。目前尚不清楚注意力机制在图结构领域是否会产生类似的效果。
当处理某些节点比图结构所暗示的重要性更高的领域时,GAT尤其突出。有时,在图中,可能会有一个高度节点,它对其他节点具有过度的重要性,而普通的消息传递(在上一章中介绍)可能会通过该节点的许多邻居捕捉到它的意义。然而,有时一个节点可能对整个图的影响很大,即便它的度数与其他节点相似。典型的例子包括社交网络,在其中,某些成员对信息或新闻的传播具有更大的影响力;欺诈检测,在其中,一小部分行为者和交易引发了欺诈行为;以及异常检测,在其中,少数人的行为或事件会偏离常规。GAT尤其适用于这些类型的问题。
在本章中,我们将在欺诈检测领域应用GAT。在我们的案例中,我们将检测来自Yelp网站的虚假客户评论。为此,我们使用了一个用户评论网络,该网络来自一个包含芝加哥地区酒店和餐厅的Yelp评论数据集。
在介绍问题和数据集之后,我们首先训练一个不使用图结构的基准模型,然后将两种版本的GAT模型应用于该问题。最后,我们讨论类别不平衡及其解决方法。
代码片段将用于解释整个过程,但大部分代码和注释可以在代码仓库中找到。和前几章一样,理论部分将在本章末尾的4.5节进行深入探讨。
注:本章的代码可以在GitHub仓库中以笔记本形式找到(mng.bz/JYoP)。Colab链接和本章数据也可以在同一位置访问。
4.1 垃圾邮件和虚假评论检测
在面向消费者的网站和电子商务平台,如Yelp、Amazon和Google Business Reviews等,用户生成的评论和评分通常会伴随产品或服务的展示和描述。在美国,超过90%的成年人在做出购买决策时依赖并信任这些评论和评分【3】。与此同时,这些评论中有许多是虚假的。Capital One估计,2024年约有30%的在线评论并不真实【5】。在本章中,我们将训练模型来检测虚假评论。
垃圾邮件或虚假评论检测一直是机器学习和自然语言处理(NLP)领域的一个热门研究方向。因此,多个主流消费网站和平台的数据集都已经公开。在本章中,我们将使用来自Yelp.com的评论数据集,Yelp是一个专注于消费者服务的用户评论和评分平台。在Yelp.com上,用户可以查找附近的本地商家,并浏览有关商家的基本信息和用户写的反馈。Yelp使用内部开发的工具和模型来根据评论的可信度进行筛选。我们将用来解决该问题的过程如图4.1所示。
首先,我们将使用非GNN模型和表格数据建立基准模型:逻辑回归、XGBoost和scikit-learn的多层感知机(MLP)。然后,我们将引入图卷积网络(GCN)和GAT,应用到问题中,并加入图结构数据。
这个虚假评论问题可以作为一个节点分类问题来解决。我们将使用GAT对Yelp评论进行节点分类,从中筛选出虚假评论和真实评论。这个分类是二分类的:“垃圾邮件”或“非垃圾邮件”。
我们预计,图结构数据和注意力机制将使基于注意力的GNN模型具有优势。在本章中,我们将遵循以下过程:
- 加载和预处理数据集
- 定义基准模型及结果
- 实现GAT解决方案并与基准结果进行比较
4.2 探索评论垃圾邮件数据集
我们的数据集来源于更广泛的Yelp评论数据集,聚焦于芝加哥地区的酒店和餐厅评论。数据已经过预处理,具有图结构。这意味着我们将使用一个专门版本的Yelp多关系数据集,它的特点是图结构,并且专注于来自芝加哥多个酒店和餐厅的消费者评论。Yelp多关系数据集是从Yelp评论数据集派生并处理成图的。该数据集包含以下内容(数据集的最终版本总结见表4.1):
- 45,954个节点 — 每个节点代表一条单独的评论,其中14.5%被标记为可能的虚假评论,并且由机器人创建,以操纵评论内容。
- 预处理的节点特征 — 节点具有32个特征,已标准化以便机器学习算法使用。
- 3,892,933条边 — 边连接了具有共同作者或评论相同商家的评论。尽管原始数据集中有多种关系类型的边,我们使用了具有同质边的版本,以便更简化的分析。
- 无用户或商家ID — 标识ID已被移除。
表4.1:Yelp多关系数据集概览
| 项目 | 内容 |
|---|---|
| 芝加哥市Yelp评论数据集处理成图 | 节点特征基于评论文本和用户数据 |
| 节点数量(评论) | 45,954 |
| 过滤(虚假)节点 | 14.5% |
| 节点特征 | 32 |
| 总边数(在分析中假定边是同质的) | 3,846,979 |
| 共同作者的评论数 | 49,315 |
| 同月评论同一商家的评论数 | 73,616 |
| 同一商家且共享相同评分的评论数 | 3,402,743 |
接下来,表4.2显示了来自该数据集的一些评论样本,按星级评分系统降序排列。
表4.2:来自YelpChi数据集的某餐厅评论样本(按评分降序排列,5为最高)
| 评分(1-5) | 日期 | 评论* |
|---|---|---|
| 5 | 7/7/08 | 完美。Snack已经成为我最喜欢的午后/早晚餐地点。一定要试试黄油豆!!! |
| 4 | 7/1/13 | 上周五为15人订了Snack的午餐。准时到达,食物没缺,味道很好。我已将其添加到公司常规午餐清单中,因为大家都很喜欢。 |
| 3 | 12/8/14 | Snack的食物是几道受欢迎的希腊菜。开胃菜盘和希腊沙拉不错。主菜让人失望。这里只有4-5张桌子,所以有时很难找到座位。 |
| 2 | 9/10/13 | 一直想试试这个地方——朋友推荐的。吃了金枪鱼三明治…还行,但之后我非常难受。此外,鼠尾草茶还不错。 |
| 1 | 8/12/12 | 服务平平,菠菜饼湿软且温吞,黄瓜沙拉是两天前的。去Local吧! |
*这些评论中的拼写、语法和标点未作更正。
4.2.1 解释节点特征
本数据集的亮点是其节点特征。这些特征是从可用的元数据中提取的,如评分、时间戳和评论文本。它们被划分为以下几类:
- 评论文本的特征
- 评论者的特征
- 被评论商家的特征
这些特征进一步分为行为特征和文本特征:
- 行为特征 关注评论者的行为和行动模式。
- 文本特征 基于评论中找到的文本内容。
这些特征的计算过程是由Rayana和Akoglu【7】以及Dou【9】开发的。Dou在使用Rayana和Akoglu的原始公式的基础上,预处理并标准化了我们在本示例中使用的特征数据。特征的总结如图4.2所示。(有关定义和计算方法的更多细节,请参见原始论文【8】。)以下是这些节点特征的总结:
评论者和商家特征:
行为特征:
- 每日写评论的最大数量(MNR) — 高值通常表示垃圾评论。
- 正面评论的比例(4-5星)(PR) — 高值通常表示垃圾评论。
- 负面评论的比例(1-2星)(NR) — 高值通常表示垃圾评论。
- 平均评分偏差(avgRD) — 高值通常表示垃圾评论。
- 加权评分偏差(WRD) — 高值通常表示垃圾评论。
- 爆发性(BST) — 具体指用户首次和最后一次评论之间的时间框架。高值通常表示垃圾评论。
- 评分分布的熵(ERD) — 低值通常表示垃圾评论。
- 时间间隔熵(ETG) — 低值可能是垃圾评论的指示。
基于文本的特征:
- 评论的平均长度(RL) — 低值通常表示垃圾评论。
- 基于袋装二元组方法计算的平均/最大内容相似度(ACS,MCS) — 高值通常表示垃圾评论。
评论特征:
行为特征:
- 在所有产品评论中的排名顺序 — 低值通常表示垃圾评论。
- 与产品平均评分的绝对评分偏差(RD) — 高偏差通常表示可疑。
- 评分极端性(EXT) — 高评分(4-5星)通常被视为垃圾评论。
- 评论评分的阈值偏差(DEV) — 高偏差通常表示可疑。
- 早期时间框架(ETF) — 过早出现的评论通常表示可疑。
- 单一评论者检测(ISR) — 如果评论是某个用户唯一的评论,则被标记为可疑。
基于文本的特征:
- 全部大写字母的百分比(PCW) — 高值通常表示可疑。
- 大写字母的百分比(PC) — 高值通常表示可疑。
- 评论的字数长度 — 低值通常表示可疑。
- 第一人称代词的比例,如“I”,“my”(PP1) — 低值通常表示可疑。
- 感叹句的比例(RES) — 高值通常表示可疑。
- 主观词汇的比例 — 由sentiWordNet(SW)检测 — 高值通常表示可疑。
- 客观词汇的比例 — 由sentiWordNet(OW)检测 — 低值通常表示可疑。
- 评论频率 — 通过局部敏感哈希(F)近似 — 高值通常表示可疑。
- 基于单字和双字的描述长度(DLu,DLb) — 低值通常表示可疑。
图4.2给出了特征集的总结。
4.2.1 解释节点特征
这个数据集的亮点之一是其节点特征。这些特征从可用的元数据中提取,如评分、时间戳和评论文本。它们被分为以下几类:
- 评论文本的特征
- 评论者的特征
- 被评论商家的特征
这些特征进一步分为行为特征和文本特征:
- 行为特征 关注评论者的行为和行动模式。
- 文本特征 基于评论中找到的文本内容。
这些特征的计算过程由Rayana和Akoglu【7】以及Dou【9】开发。Dou在Rayana和Akoglu的原始公式的基础上,对特征数据进行了预处理和标准化。特征的总结如图4.2所示。(有关定义和计算方法的更多细节,请参见原始论文【8】。)以下是这些节点特征的总结:
评论者和商家特征:
行为特征:
- 每日写评论的最大数量(MNR) — 高值通常表示垃圾评论。
- 正面评论的比例(4-5星)(PR) — 高值通常表示垃圾评论。
- 负面评论的比例(1-2星)(NR) — 高值通常表示垃圾评论。
- 平均评分偏差(avgRD) — 高值通常表示垃圾评论。
- 加权评分偏差(WRD) — 高值通常表示垃圾评论。
- 爆发性(BST) — 具体指用户首次和最后一次评论之间的时间框架。高值通常表示垃圾评论。
- 评分分布的熵(ERD) — 低值通常表示垃圾评论。
- 时间间隔熵(ETG) — 低值可能是垃圾评论的指示。
基于文本的特征:
- 评论的平均长度(RL) — 低值通常表示垃圾评论。
- 基于袋装二元组方法计算的平均/最大内容相似度(ACS,MCS) — 高值通常表示垃圾评论。
评论特征:
行为特征:
- 在所有产品评论中的排名顺序 — 低值通常表示垃圾评论。
- 与产品平均评分的绝对评分偏差(RD) — 高偏差通常表示可疑。
- 评分极端性(EXT) — 高评分(4-5星)通常被视为垃圾评论。
- 评论评分的阈值偏差(DEV) — 高偏差通常表示可疑。
- 早期时间框架(ETF) — 过早出现的评论通常表示可疑。
- 单一评论者检测(ISR) — 如果评论是某个用户唯一的评论,则被标记为可疑。
基于文本的特征:
- 全部大写字母的百分比(PCW) — 高值通常表示可疑。
- 大写字母的百分比(PC) — 高值通常表示可疑。
- 评论的字数长度 — 低值通常表示可疑。
- 第一人称代词的比例,如“I”,“my”(PP1) — 低值通常表示可疑。
- 感叹句的比例(RES) — 高值通常表示可疑。
- 主观词汇的比例 — 由sentiWordNet(SW)检测 — 高值通常表示可疑。
- 客观词汇的比例 — 由sentiWordNet(OW)检测 — 低值通常表示可疑。
- 评论频率 — 通过局部敏感哈希(F)近似 — 高值通常表示可疑。
- 基于单字和双字的描述长度(DLu,DLb) — 低值通常表示可疑。
4.2.2 探索性数据分析
在本节中,我们将下载并探索数据集,重点分析节点特征。节点特征将作为我们非图基准模型中的主要表格特征。
数据集可以从Yingtong Dou的GitHub仓库(mng.bz/Pdyg)下载,压缩为zip文件。解压后,文件将是MATLAB格式。我们可以使用scipy库中的loadmat函数和Dou仓库中的实用函数来加载数据,并生成我们需要的对象(见列表4.1):
- 一个包含节点特征的
features对象 - 一个包含节点标签的
labels对象 - 一个邻接列表对象
列表4.1:加载数据
prefix = 'PATH_TO_MATLAB_FILE/'
data_file = loadmat(prefix + 'YelpChi.mat') #1
labels = data_file['label'].flatten() #2
features = data_file['features'].todense().A #2
yelp_homo = data_file['homo'] #3
sparse_to_adjlist(yelp_homo, prefix + 'yelp_homo_adjlists.pickle')
- #1
loadmat是一个scipy函数,用于加载MATLAB文件。 - #2 获取节点的标签和特征。
- #3 获取并保存邻接列表。“Homo”表示该邻接列表将基于同质的边集,即我们去除了边的多关系性质。
一旦邻接列表被提取并保存,它可以在以后通过以下代码调用:
with open(prefix + 'yelp_homo_adjlists.pickle', 'rb') as file:
homogenous = pickle.load(file)
加载完数据后,我们可以进行一些探索性数据分析(EDA),分析图结构和节点特征。
4.2.3 探索图结构
为了更好地理解数据集中的欺诈行为,我们探索了潜在的图结构。通过分析连接组件和各种图度量,我们可以概览网络的拓扑结构。这种理解将揭示数据固有的特征,并确保没有潜在的障碍影响有效的GNN训练。我们将详细分析连接组件、密度、聚类系数和其他关键度量。
为了执行这种结构化的EDA,我们使用邻接列表通过NetworkX库来分析图的结构。在以下代码片段中,我们加载了邻接列表对象,将其转换为NetworkX图对象,然后查询该图对象的三个基本属性。更长的代码可以在仓库中找到:
with open(prefix + 'yelp_homo_adjlists.pickle', 'rb') as file:
homogenous = pickle.load(file)
g = nx.Graph(homogenous)
print(f'Number of nodes: {g.number_of_nodes()}')
print(f'Number of edges: {g.number_of_edges()}')
print(f'Average node degree: {len(g.edges) / len(g.nodes):.2f}')
从EDA中,我们获得了表4.3所列的属性。
表4.3:图属性
| 属性 | 值/细节 |
|---|---|
| 节点数量 | 45,954 |
| 边数量 | 3,892,933 |
| 平均节点度 | 84.71 |
| 密度 | ~0.00 |
| 连通性 | 图不是连通的 |
| 平均聚类系数 | 0.77 |
| 连接组件数量 | 26 |
| 节点度分布(前10个节点) | [4, 4, 4, 3, 4, 5, 5, 6, 5, 19] |
让我们深入分析这些属性。图的规模相对较大,包含45,954个节点和3,892,933条边。这意味着该图具有相当复杂的结构,可能包含复杂的关系。平均节点度为84.71,表明图中的节点平均连接约85个其他节点。这表明图中的节点连接较为紧密,可能存在丰富的信息流动。图的密度接近0.00,这表明图非常稀疏。换句话说,实际连接(边)的数量远低于可能的连接数。图的密度是边的数量除以可能的边数。
图不是完全连通的,包含26个独立的连接组件。多个连接组件的存在可能需要在建模时特殊考虑,尤其是当不同组件代表不同的数据簇或现象时。平均聚类系数为0.77,表示图具有较强的“团体性”。高值意味着节点倾向于聚集在一起,形成紧密的群体。这可能表明数据中存在局部社区或簇,这对于理解模式或异常,尤其是欺诈检测非常重要。
鉴于我们有26个独立的组件,重要的是检查这些组件的情况,以便规划模型训练。我们想知道这些组件的大小是否相似,是否存在不同大小的混合,或是否有一个或两个组件占主导地位。我们对26个组件进行类似的分析,并在表4.4中总结了它们的属性,按节点数量降序排列。第一列显示了组件的标识符。从表中可以看出,组件3在数据集中占主导地位。
表4.4:26个图组件的属性,按节点数量降序排列
| 组件ID | 节点数量 | 边数量 | 平均节点度 | 密度 | 平均聚类系数 |
|---|---|---|---|---|---|
| 3 | 45,900 | 38,92810 | 169.62 | 0 | 0.77 |
| 4 | 13 | 60 | 9.23 | 0.77 | 0.77 |
| 2 | 6 | 14 | 4.67 | 0.93 | 0.58 |
| 1, 22 | 3 | 6 | 4 | 0 | 0 |
| 5–9, 14, 17, 24, 26 | 2 | 3 | 3 | 0 | 0 |
在最后三行中,多个组件具有相同的属性,因此合并到同一行以节省空间。
我们看到,组件3是占主导地位的组件,其后是25个较小的组件。这些小组件可能对我们的模型影响不大,因此我们将重点关注组件3。我们将组件3与整个图进行对比,见表4.5。大多数属性非常相似或相同,唯一的例外是平均节点度,组件3的平均节点度是整个图的两倍。
表4.5:将图的最大组件(组件3)与整体图进行比较
| 属性 | 组件3 | 整体图 | 见解/对比 |
|---|---|---|---|
| 节点数量 | 45,900 | 45,954 | 组件3包含了几乎整个图的节点。 |
| 边数量 | 3,892,810 | 3,892,933 | 组件3几乎包含了整个图的所有边。 |
| 平均节点度 | 169.62 | 84.71 | 组件3的节点比整体图的节点更紧密连接。 |
| 密度 | 0.00 | 0.00 | 组件3和整个图都是稀疏的;这个特性主要由组件3决定。 |
| 平均聚类系数 | 0.77 | 0.77 | 组件3的聚类系数与整体图相匹配,表明其在定义图结构中的主导地位。 |
对于我们的GNN建模目的,结构分析的主要启示是:组件3在节点和边的数量上占主导地位,强调了其在数据集中的重要性;几乎整个图的结构都封装在这个单一组件中。这表明,组件3中的模式、关系和异常将强烈影响模型的训练和结果。组件3的平均节点度高于整体图,表明其具有更丰富的互联性,这突出了有效捕捉这些密集连接的重要性。此外,组件3与整体图相同的密度和聚类系数值,强调了该组件在数据集整体结构属性中的代表性。
我们有两个选择:
- 假设其他组件对模型影响较小,直接训练而不做任何调整。
- 仅对组件3进行建模,完全排除小组件的数据进行训练和测试。
我们通过对图数据结构属性的分析,深入了解了图的特性,并获得了有价值的见解,指导了GNN模型设计和训练,帮助我们理解潜在的欺诈模式。接下来,我们将深入研究节点特征。
4.2.4 探索节点特征
在探索了图的结构特性之后,我们转向节点特征。在本节开始时的代码中,我们从数据文件中提取了节点特征:
features = data_file['features'].todense().A
注:如前所述,这些特征定义是由Rayana等人手工制作的【7,8】。在特征生成过程中,Dou等人【8】进一步处理了Yelp评论数据集,创建了一组标准化的节点特征。
通过一些额外的工作(见代码仓库),我们还为这些特征添加了一些标签和描述,然后为每个特征创建了分布图(示例图见图4.3到图4.5)。每组图表对应于描述评论文本、评论者和商家的特征。我们希望通过这些图表检查节点特征是否能有效区分虚假评论。图4.3展示了从评论特征派生的两个特征的分布。
图4.4展示了从评论者特征派生的两个特征的分布。
最后,图4.5展示了从被评论餐厅或酒店特征派生的两个特征的分布。
通过检查32个节点特征的直方图,我们可以做出几个观察。首先,许多特征呈现明显的偏斜。具体而言,像Rank、RD和EXT等特征呈现右偏分布。这表明大多数数据点集中在直方图的左侧,但少数较高的值将直方图拉向右侧。相反,像MNR_user、PR_user和NR_user等特征则显示出左偏分布。在这些情况下,大部分数据点集中在直方图的右侧,少数较低的值将直方图拉向左侧。
一些特征还表现出双峰分布,意味着数据中存在两个明显的峰值或群体。这表明,将数据分段并为每个群体创建单独的模型可能是一个有用的策略。
最后,几个直方图的长尾表明存在一些异常值。考虑到某些模型(如线性回归)对极端值非常敏感,处理这些异常值可能对优化和改进模型至关重要。这可能意味着选择抗异常值的模型,开发减轻异常值影响的策略,甚至完全去除它们。
基于这些总体见解,让我们更仔细地检查其中一个特征图。PP1是评论中第一人称代词(即I、me、us、our等)与第二人称代词(you、your等)的比例。该特征的开发源于观察到垃圾评论通常包含更多的第二人称代词。从PP1的分布图来看,我们观察到其分布呈左偏,尾部集中在低值处。因此,如果低比例是垃圾评论的一个指示器,这个特征在区分垃圾评论时会非常有效。
为了总结我们对节点特征的探索,这些数据展示了多样的特征,提供了许多用于模型训练的机会。进一步的预处理,如异常值处理、偏斜特征转换、数据分段和特征缩放,可能对优化模型的预测性能至关重要。
我们对评论垃圾邮件数据集的探索揭示了一些模式、异常和见解。从数据集的复杂结构特征(主要由占主导地位的组件3表示)到提供有效区分真实评论和虚假评论的节点特征,我们为模型训练奠定了基础。
在4.3节中,我们将开始训练我们的基准模型。这些初始模型作为基础,帮助我们评估基本模型性能的有效性。通过这些模型,我们将利用数据的图结构和节点特征的潜力,将垃圾评论与真实评论区分开来。
4.3 训练基准模型
在我们的数据集基础上,我们将通过首先开发三个基准模型来开始训练阶段:逻辑回归、XGBoost和MLP。请注意,对于这些模型,数据将采用表格格式,节点特征作为我们的列特征。每个节点的数据集将对应一行或一个观测值。接下来,我们将开发一个额外的GNN基准,通过训练一个GCN来评估引入图结构数据对问题的影响。
现在,我们将我们的表格数据分为测试集和训练集,并应用这三个基准模型。首先是测试集/训练集的划分:
from sklearn.model_selection import train_test_split
split = 0.2
xtrain, xtest, ytrain, ytest = train_test_split(
features, labels, test_size=split, stratify=labels, random_state=99) #1
print(f'Required shape is {int(len(features)*(1-split))}') #2
print(f'xtrain shape = {xtrain.shape}, xtest shape = {xtest.shape}')
print(f'Correct split = {int(len(features)*(1-split)) == xtrain.shape[0]}')
- #1 数据以80/20的比例分为训练集和测试集
- #2 检查数据对象的形状是否正确
我们可以使用这些分割后的数据来训练每个基准模型。在这个训练过程中,我们只使用了节点特征和标签,没有使用图数据结构或几何结构。对于基准模型和GNN,我们主要依赖接收者操作特征(ROC)和曲线下面积(AUC)来衡量性能,并与我们的GAT模型进行对比。
4.3.1 非GNN基准模型
我们首先使用scikit-learn实现的逻辑回归模型,并使用默认的超参数:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score, f1_score
clf = LogisticRegression(random_state=0).fit(xtrain, ytrain) #1
ypred = clf.predict_proba(xtest)[:, 1]
acc = roc_auc_score(ytest, ypred) #2
print(f"Model accuracy (logression) = {100*acc:.2f}%")
- #1 逻辑回归模型的实例化和训练
- #2 计算AUC得分
该模型的AUC为76.12%。对于ROC性能,我们还将使用scikit-learn中的函数。我们还会使用真正率(tpr)和假正率(fpr)来与其他基准模型进行比较:
from sklearn.metrics import roc_curve
fpr, tpr, _ = roc_curve(ytest, ypred) #1
plt.figure(1)
plt.plot([0, 1], [0, 1])
plt.plot(fpr, tpr)
plt.xlabel('False positive rate')
plt.ylabel('True positive rate')
plt.show()
- #1 计算ROC曲线,得到假正率(fpr)和真正率(tpr)
在图4.6中,我们看到ROC曲线。我们发现,曲线在假正率和假负率之间相对平衡,但由于曲线靠近对角线,整体特异性较差。
XGBoost
XGBoost基准模型紧接着逻辑回归模型,如列表4.2所示。我们使用一个简单的模型,使用相同的训练集和测试集进行训练。为了便于比较,我们将生成的预测命名为pred2,并区分真正率(tpr2)和假正率(fpr2)。
列表4.2:XGBoost基准模型及图形
import xgboost as xgb
xgb_classifier = xgb.XGBClassifier()
xgb_classifier.fit(xtrain, ytrain)
ypred2 = xgb_classifier.predict_proba(xtest)[:,1] #1
acc = roc_auc_score(ytest, ypred2)
print(f"Model accuracy (XGBoost) = {100*acc:.2f}%")
fpr2, tpr2, _ = roc_curve(ytest, ypred2) #2
plt.figure(1)
plt.plot([0, 1], [0, 1])
plt.plot(fpr, tpr)
plt.plot(fpr2, tpr2)
plt.xlabel('False positive rate')
plt.ylabel('True positive rate')
plt.show()
- #1 为了对比,我们将XGBoost的预测结果命名为
ypred2。 - #2 为了对比,我们区分XGBoost的真正率(tpr)和假正率(fpr),并将其与逻辑回归结果一起绘制。
图4.7展示了XGBoost和逻辑回归的ROC曲线。从图中可以明显看出,XGBoost在这个指标上具有更优的表现。
XGBoost在这组数据上表现优于逻辑回归,获得了94%的AUC,并且具有更优的ROC曲线。这突显了即使是一个简单的模型,也能适用于某些问题,并且检查性能总是一个好主意。
多层感知机(MLP)
对于MLP基准模型,我们使用PyTorch构建一个简单的三层模型,如列表4.3所示。与PyTorch类似,我们通过定义一个类来建立模型,定义各个层和前向传播。在MLP中,我们使用二元交叉熵(BCE)作为损失函数,这是二分类问题中常用的损失函数。
列表4.3:MLP基准模型及图形
import torch #1
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module): #2
def __init__(self, in_channels, out_channels, hidden_channels=[128,256]):
super(MLP, self).__init__()
self.lin1 = nn.Linear(in_channels, hidden_channels[0])
self.lin2 = nn.Linear(hidden_channels[0], hidden_channels[1])
self.lin3 = nn.Linear(hidden_channels[1], out_channels)
def forward(self, x):
x = self.lin1(x)
x = F.relu(x)
x = self.lin2(x)
x = F.relu(x)
x = self.lin3(x)
x = torch.sigmoid(x)
return x
model = MLP(in_channels = features.shape[1], out_channels = 1) #3
epochs = 100 #4
lr = 0.001
wd = 5e-4
n_classes = 2
n_samples = len(ytrain)
w = ytrain.sum() / (n_samples - ytrain.sum()) #5
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd) #6
criterion = torch.nn.BCELoss() #7
xtrain = torch.tensor(xtrain).float() #8
ytrain = torch.tensor(ytrain)
losses = []
for epoch in range(epochs): #9
model.train()
optimizer.zero_grad()
output = model(xtrain)
loss = criterion(output, ytrain.reshape(-1, 1).float())
loss.backward()
losses.append(loss.item())
ypred3 = model(torch.tensor(xtest, dtype=torch.float32))
acc = roc_auc_score(ytest, ypred3.detach().numpy())
print(f'Epoch {epoch} | Loss {loss.item():6.2f} | Accuracy = {100*acc:6.3f}% | # True Labels = {ypred3.detach().numpy().round().sum()}', end='\r')
optimizer.step()
fpr, tpr, _ = roc_curve(ytest, ypred)
fpr3, tpr3, _ = roc_curve(ytest, ypred3.detach().numpy()) #10
plt.figure(1) #11
plt.plot([0, 1], [0, 1])
plt.plot(fpr, tpr)
plt.plot(fpr2, tpr2)
plt.plot(fpr3, tpr3)
plt.xlabel('False positive rate')
plt.ylabel('True positive rate')
plt.show()
- #1 导入本节所需的包。
- #2 使用类定义MLP架构。
- #3 实例化已定义的模型。
- #4 设置关键超参数。
- #5 解决类别不平衡问题。
- #6 定义优化器和训练准则。
- #7 使用BCE损失作为损失函数。
- #8 将训练数据转换为PyTorch数据类型:torch张量。
- #9 训练循环。在此示例中,我们指定了100个epoch。
- #10 区分tpr和fpr以进行比较。
- #11 将三个ROC曲线一起绘制。
图4.8展示了逻辑回归、XGBoost和MLP的ROC结果。
MLP模型经过100个周期的训练,准确率为85.9%,位于我们的基准模型之中。它的ROC曲线仅略优于逻辑回归模型。这些结果在表4.6中总结。
表4.6 三个基准模型的对数损失和ROC AUC
| 模型 | 对数损失 | ROC AUC |
|---|---|---|
| 逻辑回归 | 0.357 | 75.90% |
| XGBoost | 0.178 | 94.17% |
| 多层感知机(MLP) | 0.295 | 85.93% |
总结这一部分,我们运行了三个基准模型作为我们的GNN模型的对比基准。这些基准模型没有使用结构化图数据,而只是使用了从节点特征中衍生出来的表格特征。我们没有对这些模型进行优化,结果显示XGBoost表现最佳,准确率为89.25%。接下来,我们将训练一个使用GCN的基准模型,并应用GAT模型。
4.3.2 GCN基准模型
在这一部分中,我们将应用GNNs解决我们的任务,首先使用第三章中的GCN模型,然后再应用GAT模型。我们预计由于图的结构数据,GNN模型会优于其他基准模型,而具有注意力机制的模型将表现最佳。对于GNN模型,我们需要对数据处理流程进行一些修改。许多修改与数据预处理和数据加载有关。
数据预处理
一个关键的第一步是为我们的GNN准备数据。这些步骤已在第二章和第三章中介绍过。代码如列表4.4所示,执行以下步骤:
- 建立训练集/测试集划分。我们使用之前的
test_train_split函数,稍作修改以产生索引,并只保留结果索引。 - 将数据集转换为PyG张量。为此,我们从前面生成的同质邻接列表开始,使用NetworkX将其转换为NetworkX图对象。然后,我们使用PyG的
from_networkx函数将其转换为PyG数据对象。 - 将训练集/测试集划分应用于转换后的数据对象。为此,我们使用第一步中的索引。
我们希望展示多种方式来安排训练数据的摄取方式。因此,对于GCN,我们将整个数据集输入模型,而在GAT的示例中,我们将训练数据批处理。
列表4.4 转换训练数据的类型
from torch_geometric.transforms import NormalizeFeatures
split = 0.2 #1
indices = np.arange(len(features)) #1
xtrain, xtest, ytrain, ytest, idxtrain, idxtest\
= train_test_split(features labels,indices, \
stratify=labels, test_size = split, \
random_state = 99) #2
g = nx.Graph(homogenous) #3
print(f'节点数: {g.number_of_nodes()}')
print(f'边数: {g.number_of_edges()}')
print(f'平均节点度: {len(g.edges) / len(g.nodes):.2f}')
data = from_networkx(g)
data.x = torch.tensor(features).float()
data.y = torch.tensor(labels)
data.num_node_features = data.x.shape[-1]
data.num_classes = 1 #二分类
A = set(range(len(labels))) #4
data.train_mask = torch.tensor([x in idxtrain for x in A]) #4
data.test_mask = torch.tensor([x in idxtest for x in A]) #4
#1 建立训练集/测试集划分。我们只使用索引变量。
#2 建立训练集/测试集划分。我们只使用索引变量。
#3 将邻接列表转换为PyG数据对象
#4 在数据对象中建立训练集/测试集划分
数据预处理完成后,我们准备应用GCN和GAT模型。在第三章中,我们详细介绍了GCN架构。在列表4.5中,我们建立了一个两层的GCN,训练1000个周期。我们选择两层是因为第三章中的洞察表明,一般来说,较浅的模型深度可以提高性能并防止过度平滑。
列表4.5 GCN定义和训练
class GCN(torch.nn.Module): #1
def __init__(self, hidden_layers = 64):
super().__init__()
torch.manual_seed(2022)
self.conv1 = GCNConv(data.num_node_features, hidden_layers)
self.conv2 = GCNConv(hidden_layers, 1)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
return torch.sigmoid(x)
device = torch.device("cuda"\
if torch.cuda.is_available() \
else "cpu") #2
print(device)
model = GCN()
model.to(device)
data.to(device)
lr = 0.01
epochs = 1000
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
criterion = torch.nn.BCELoss()
losses = []
for e in range(epochs): #3
model.train()
optimizer.zero_grad()
out = model(data) #4
loss = criterion(out[data.train_mask], \
data.y[data.train_mask].\
reshape(-1,1).float())
loss.backward()
losses.append(loss.item())
optimizer.step()
ypred = model(data).clone().cpu()
pred = data.y[data.test_mask].clone().cpu().detach().numpy()
true = ypred[data.test_mask].detach().numpy()
acc = roc_auc_score(pred,true)
print(f'第{e}个周期 | 损失 {loss:6.2f} \
| 准确率 = {100*acc:6.3f}% \
| # 真标签 =\ {ypred.round().sum()}')
fpr, tpr, _ = roc_curve(pred,true) #5
#1 定义一个两层的GCN架构
#2 初始化模型并将模型和数据移到GPU
#3 训练循环
#4 对每个周期,我们将整个数据对象输入模型,然后使用训练掩码计算损失。
#5 计算假阳性率(fpr)和真正率(tpr)
应用解决方案
需要注意的一点是我们在训练中使用了掩码。虽然我们使用训练掩码中的节点来建立损失,但在前向传播时,我们必须将整个图传递通过模型。这是为什么呢?与传统的独立数据点(例如,表格数据集中的行)处理的机器学习模型不同,GNN(图神经网络)在图结构数据上工作,其中节点之间的关系至关重要。在训练GCN时,每个节点的嵌入是基于其邻居的信息来更新的。由于这个消息传递过程涉及从节点的局部邻域聚合信息,模型需要访问整个图结构,以便能够正确计算这些聚合,并准确地执行这一过程。
因此,在训练过程中,即使我们只对某些节点(即训练集中的节点)的预测感兴趣,将整个图传递通过模型也能确保考虑到所有必要的上下文。如果只将图的一部分传递通过模型,网络将缺乏传播消息所需的完整信息,无法有效地更新节点表示。
GCN的100个周期训练会得到94.37%的准确率。引入图数据后,我们看到了相较于XGBoost模型的逐步改进。表4.7对比了四个基准模型的性能水平。
表4.7 四个基准模型的AUC
| 模型 | AUC |
|---|---|
| 逻辑回归 | 75.90% |
| XGBoost | 94.17% |
| 多层感知机(MLP) | 85.93% |
| GCN | 94.37% |
总结来说,我们看到,使用GNN模型引入图的结构信息,与纯特征或表格模型相比,性能略有提升。显然,XGBoost模型即使没有使用图结构,仍然表现出了令人印象深刻的结果。然而,GCN模型略微更好的表现突显了GNN在使用图数据中嵌入的关系信息的潜力。
在我们研究的下一阶段,注意力将转向图注意力网络(GAT)。GAT具有一个特别针对学习如何在消息传递步骤中对邻居的重要性进行加权的注意力机制,这可能会带来更好的模型性能。在接下来的部分,我们将深入探讨GAT模型的训练细节,并将其结果与我们建立的基准进行比较。让我们继续进行GAT模型训练。
4.4 训练GAT模型
为了训练我们的GAT模型,我们将应用两个PyG实现(GAT和GATv2)[2]。在这一部分,我们将直接进入模型训练,而不讨论注意力机制在机器学习模型中的含义及其为何有用。不过,关于注意力机制的简要概述及其可能的优势,见第4.5节。
我们将训练两个不同的GAT模型。这两个模型遵循相同的基本思路——即用注意力机制替换我们GCN中的聚合操作符,以学习模型应当关注哪些消息(节点特征)。第一个——GATConv——是对第三章GCN的简单扩展,加入了注意力机制。第二个是该模型的轻微变种,称为GATv2Conv。该模型与GATConv相同,唯一的区别是它解决了原始实现中的一个限制,即注意力机制在每个GNN层之间是静态的。而在GATv2Conv中,注意力机制在各层之间是动态的。
再次强调,原始的GAT模型每次训练循环只计算一次注意力权重,使用单独的节点和邻域特征,这些权重在所有层中是静态的。而在GATv2中,注意力权重是基于节点特征在通过层转换时计算的。这使得GATv2更加具有表达性,能够学习在训练过程中强调节点邻域的影响。
由于引入了注意力机制,这两个模型会增加显著的计算开销。为了解决这个问题,我们引入了小批量训练到我们的训练循环中。
4.4.1 邻域加载器和GAT模型
从实现角度来看,之前研究的卷积模型与我们的GAT模型之间的一个关键区别是,GAT模型的内存需求要大得多[9]。原因在于,GAT需要计算每个注意力头和每条边的注意力分数。这反过来要求PyTorch的自动梯度方法在内存中保存张量,内存需求会根据边的数量、头的数量和(节点特征数量的两倍)大幅增加。
为了解决这个问题,我们可以将图分成批次,并将这些批次加载到训练循环中。这与我们GCN模型的训练方式不同,GCN模型是基于一个批次(整个图)进行训练的。PyG的NeighborLoader(在其dataloader模块中)支持这样的批量训练,以下是实现代码(NeighborLoader基于“Inductive Representation Learning on Large Graphs”论文[10])。NeighborLoader的关键输入参数是:
num_neighbors——每个节点将被采样的邻居节点数,乘以迭代次数(即GNN层数)。在我们的示例中,我们指定了在两次迭代中采样1000个节点。batch_size——每个批次中选择的节点数。在我们的示例中,我们将批次大小设置为128。
列表4.6 设置NeighborLoader以进行GAT训练
from torch_geometric.loader import NeighborLoader
batch_size = 128
loader = NeighborLoader(
data,
num_neighbors=[1000]*2, #1
batch_size=batch_size, #2
input_nodes=data.train_mask)
sampled_data = next(iter(loader))
print(f'检查批次大小是否为 {batch_size}: {batch_size == sampled_data.batch_size}')
print(f'批次中的欺诈比例: {100*sampled_data.y.sum()/len(sampled_data.y):.4f}%')
sampled_data
#1 对每个节点在两次迭代中采样1000个邻居
#2 使用批次大小进行训练节点采样
在创建GAT模型时,相比于我们的GCN类,主要有两个关键的变化。首先,由于我们是批量训练,我们希望应用一个批量归一化层。批量归一化是一种技术,用于将每一层的输入标准化,使其均值为0,标准差为1。这有助于稳定并加速训练过程,通过减少内部协变量的偏移,允许使用更高的学习率,并提高模型的整体性能。
其次,我们注意到我们的GAT层有一个额外的输入参数——heads,即多头注意力的数量。在我们的示例中,我们的第一个GATConv层有两个头,如列表4.7所示。第二个GATConv层是输出层,只有一个头。在这个GAT模型中,由于我们希望最终层为每个节点生成一个单一的表示,使用一个头。多个头将导致输出混乱,生成多个节点表示。
列表4.7 基于GAT的架构
class GAT(torch.nn.Module):
def __init__(self, hidden_layers=32, heads=1, dropout_p=0.0):
super().__init__()
torch.manual_seed(2022)
self.conv1 = GATConv(data.num_node_features, \
hidden_layers, heads, dropout=dropout_p) #1
self.bn1 = nn.BatchNorm1d(hidden_layers * heads) #2
self.conv2 = GATConv(hidden_layers * heads, \
1, dropout=dropout_p)
def forward(self, data, dropout_p=0.0):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = self.bn1(x)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
return torch.sigmoid(x)
#1 GAT层有一个`heads`参数,决定每层中注意力机制的数量。在此实现中,第一层(conv1)使用多个头以提取更丰富的特征,而最终输出层(conv2)使用一个头,将学习到的信息聚合成每个节点的单一输出。
#2 由于进行批量训练,添加了批量归一化层。
我们的GAT训练循环与单批次GCN训练循环类似,以下是实现代码,唯一的区别是现在需要为每个批次使用嵌套循环。
列表4.8 GAT训练循环
lr = 0.01
epochs = 1000
model = GAT(hidden_layers=64, heads=2)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
criterion = torch.nn.BCELoss()
losses = []
for e in range(epochs):
epoch_loss = 0.
for i, sampled_data in enumerate(loader): #1
sampled_data.to(device)
model.train()
optimizer.zero_grad()
out = model(sampled_data)
loss = criterion(out[sampled_data.train_mask],\
sampled_data.y[sampled_data.train_mask].reshape(-1,1).float())
loss.backward()
epoch_loss += loss.item()
optimizer.step()
ypred = model(sampled_data).clone().cpu()
pred = sampled_data.y[sampled_data.test_mask].clone().cpu().detach().numpy()
true = ypred[sampled_data.test_mask].detach().numpy()
acc = roc_auc_score(pred, true)
losses.append(epoch_loss / batch_size)
print(f'第{e}个周期 | 损失 {epoch_loss:6.2f} | 准确率 = {100*acc:6.3f}% | 真标签数 = {ypred.round().sum()}')
#1 为每个批次训练添加了嵌套循环。每次迭代是`NeighborLoader`加载的一批节点。
GATConv和GATv2Conv的训练步骤相同,代码可以在我们的仓库中找到。训练GATConv和GATv2Conv分别得到了95.65%和95.10%的准确率。如表4.8所示,我们的GAT模型超越了基准模型和GCN模型。图4.9显示了GCN和GAT模型的ROC结果。图4.10显示了GCN、GAT和GATv2模型的ROC结果。
表4.8 各模型的ROC AUC
| 模型 | ROC AUC (%) |
|---|---|
| 逻辑回归 | 75.90 |
| XGBoost | 94.17 |
| 多层感知机(MLP) | 85.93 |
| GCN | 94.37 |
| GAT | 95.65 |
| GATv2 | 95.10 |
在观察ROC曲线时,我们看到两个GAT模型都优于GCN。我们还发现,它们的假阳性率也更低。这对于欺诈/垃圾邮件检测至关重要,因为假阳性会导致真实的交易/用户被错误标记,造成不便并丧失信任。对于GATv2,我们注意到其真正率的表现与GCN和GAT相同。这表明,虽然GATv2在避免将真实交易误标为欺诈时显得保守,但它可能错过一些真实的欺诈行为。这些洞察可以帮助我们改进模型或做出决策,选择使用哪个模型。尽管AUC曲线和得分有利,我们还必须解决影响我们GAT模型可用性的一个最终问题:类别不平衡。
4.4.2 解决模型性能中的类别不平衡问题
类别不平衡是GNN问题中的一个关键挑战,少数类(通常代表罕见但重要的实例,如欺诈行为)与多数类相比显著不足。在我们的数据集中,只有14.5%的节点被标记为欺诈,这使得模型从这些稀疏数据中有效学习变得具有挑战性。虽然较高的AUC分数可能表明整体表现良好,但它们可能具有误导性,掩盖了少数类的较差表现,而少数类在平衡评估中至关重要。更深入的分析揭示了一个关键的疏忽:类别不平衡显著影响了我们的精度和F1分数。
为应对这一挑战,已经专门为GNN开发了几种方法来解决类别不平衡问题。传统技术如合成少数类过采样技术(SMOTE)已经被调整为创建图特定的方法,如GraphSMOTE,它生成合成节点和边来平衡类别分布,而不破坏图结构。其他方法包括重采样技术(过采样和欠采样)、成本敏感学习、架构修改和专注于少数类特征的注意力机制[11, 12]。
尽管这些方法有助于提高模型性能,但它们也带来了独特的挑战,例如保持图的拓扑结构、维护节点之间的依赖关系以及确保可扩展性。最近的进展,如图-图神经网络(G2GNN),已经被开发出来,更有效地处理这些问题。通过理解和应用这些策略,我们可以增强GNN模型在现实应用中的鲁棒性和公平性,在这些应用中类别不平衡是常见问题。以前一节中的GATv2模型为例,我们将其F1、召回率和精度与XGBoost进行了比较,结果见表4.9。XGBoost表现更好,而GATv2在处理不平衡数据时表现不佳。
表4.9 比较本章中训练的GATv2与XGBoost模型的F1、召回率和精度
| 指标 | GATv2 | XGBoost |
|---|---|---|
| F1分数 | 0.254 | 0.734 |
| 精度 | 0.145 | 0.855 |
| 召回率 | 1 | 0.643 |
GATv2模型的表现反映了在类别极度不平衡的情况下常见的挑战。由于少数类仅占数据的14.5%,模型重点在于最大化召回率,达到了完美的召回率1.000。这表明模型能够正确识别每一个少数类实例,避免漏检潜在的重要案例。然而,这也以精度的显著降低为代价,精度仅为0.145。这表明,尽管GAT模型有效地检测到了所有真正的阳性,但它也将许多负类误分类为正类,从而导致了大量的假阳性。因此,F1分数(反映精度和召回率的综合指标)较低,仅为0.254,突显了模型在平衡检测与准确性方面的低效。
为了解决这一问题,我们实施了两种策略来缓解类别不平衡:SMOTE(如图4.11所示)和自定义重排方法(如图4.12所示)。
SMOTE被用来生成合成节点,反映原始数据集的平均度特征,并通过人工增强少数类的表示。重排方法采用了不同的策略,通过避免生成合成数据,确保每个训练批次中类别的平衡。具体做法是通过将多数类数据重新分配到各个批次中,使用BalancedNodeSampler类来实现这一点。该类保证每个批次中,来自多数类和少数类的节点数量相等。对于每个批次,采样器随机选择一个平衡的节点集合,提取相应的子图,并重新索引节点以保持一致性。图4.12展示了这一过程中的典型批次重分配情况。该类的代码见列表4.9。
列表4.9 BalancedNodeSampler类
class BalancedNodeSampler(BaseSampler):
def __init__(self, data, num_samples=None):
super().__init__()
self.data = data
self.num_samples = num_samples #1
def sample_from_nodes(self, index, **kwargs):
majority_indices = torch.\
where(self.data.y == 0)[0] #2
minority_indices = torch.\
where(self.data.y == 1)[0] #3
if self.num_samples is None:
batch_size = min(len(majority_indices),\
len(minority_indices)) #4
else:
batch_size = self.num_samples // 2
majority_sample = majority_indices[torch.randperm\
(len(majority_indices))[:batch_size]] #5
minority_sample = minority_indices[torch.randint\
(len(minority_indices), (batch_size,))]
batch_indices = torch.cat\
((majority_sample, minority_sample)) #6
mask = torch.zeros(self.data.num_nodes, dtype=torch.bool)
mask[batch_indices] = True #7
row, col = self.data.edge_index
mask_edges = mask[row] & mask[col] #8
sub_row = row[mask_edges]
sub_col = col[mask_edges]
new_index = torch.full((self.data.num_nodes,), -1, dtype=torch.long)
new_index[batch_indices] = \
torch.arange(batch_indices.size(0)) #9
sub_row = new_index[sub_row]
sub_col = new_index[sub_col]
return SamplerOutput(
node=batch_indices,
row=sub_row,
col=sub_col,
edge=None,
num_sampled_nodes=[len(batch_indices)],
metadata=(batch_indices, None)
)
#1 可选:定义每类的固定采样大小
#2 多数类的索引
#3 少数类的索引
#4 确定平衡批次的大小
#5 随机选择两个类的节点
#6 将两个类的样本合并为一个批次
#7 为采样的节点创建掩码
#8 过滤采样节点之间的边
#9 重新索引采样的节点
在这种情况下,SMOTE没有带来性能提升。因此,我们将重点关注应用重排方法的结果。表4.10中的指标显示,我们的干预不仅提高了模型的公平性,还通过更好地捕捉少数类来增强了模型的鲁棒性,而没有牺牲整体准确性。虽然重排方法的AUC未超过XGBoost(94.17%),但它很好地处理了类别不平衡问题,在F1、精度和召回率方面表现优越。
表4.10 比较使用类别重排方法训练的GATv2模型的F1、精度、召回率和AUC
| 指标 | 值 |
|---|---|
| 平均验证F1分数 | 0.809 |
| 平均验证精度 | 0.878 |
| 平均验证召回率 | 0.781 |
| 平均验证AUC | 0.914 |
4.4.3 在GAT和XGBoost之间做出选择
选择使用XGBoost还是GAT应该根据具体的用例需求和约束条件来决定。XGBoost提供了高效性和速度,这对于计算资源有限或需要快速模型训练的项目非常有利。然而,GAT提供了额外的好处,即深度集成节点关系数据,这对于那些节点间关系对于理解复杂数据模式至关重要的项目尤为重要。
GAT尤其有价值,因为它能够被集成到更广泛的深度学习框架中,提供增强的节点嵌入,封装丰富的上下文信息,从而使其适用于复杂的关系型数据集。
我们在解决类别不平衡方法上的探索显著地加深了我们对模型在现实场景中性能的理解。这些洞察对于有效开发稳健且有效的模型至关重要,尤其是在精度和召回率需要平衡的领域。在下一节中,我们将深入探讨GAT模型的基础概念。
4.5 引擎底层
在本节中,我们将讨论一些关于注意力和GAT模型的附加细节。这部分内容适合那些想了解底层原理的读者,如果你更感兴趣的是学习如何应用这些模型,可以跳过这一部分。我们将深入探讨GAT论文[8]中的方程式,并从更直观的角度解释注意力机制。
4.5.1 解释注意力和GAT模型
在本节中,我们将提供注意力机制的基础概述。我们将概念性地解释注意力、自注意力和多头注意力。然后,我们将GAT定位为卷积GNN的扩展。
概念1:各种注意力机制类型
注意力是过去十年中引入到深度学习中的最重要概念之一。它是现在著名的变换器(transformer)模型的基础,推动了许多生成模型突破,比如大型语言模型(LLMs)。注意力机制是指模型可以学习在训练过程中将额外的注意力集中在哪些方面[13, 14]。模型中有哪些不同类型的注意力机制呢?
注意力
假设你正在读一本小说,小说的情节不是线性的,而是跳跃式的,连接了各种角色、事件,甚至并行的故事情节。当你阅读有关特定角色的一章时,你会记得并考虑书中其他地方提到或出现过这个角色的部分。你对这个角色的理解在任何给定时刻都受到这些不同部分的影响。
在深度学习和GNN中,注意力起着类似的作用。当处理NLP问题中的一个句子时,注意力意味着模型能够学习邻近单词的重要性。对于GNN来说,在考虑图中的特定节点时,模型使用注意力来衡量邻居节点的重要性。这有助于模型决定哪些邻居节点在理解当前节点时最相关,就像你记住书中相关部分以更好地理解一个角色一样。
自注意力
假设你正在读一本提到多个角色和事件的句子,这些角色和事件之间可能有复杂的关系。要完全理解这个句子,你必须回想每个角色和事件是如何相互关联的,这些关联都在这个句子的范围内。你可能会发现自己更加关注那些对理解当前句子上下文至关重要的角色或事件。
对于使用自注意力的GNN来说,图中的每个节点不仅会考虑它的直接邻居,还会考虑它自身的特征和在图中的位置。通过这种方式,每个节点接收一个新的表示,这个表示受到自身和其他节点的加权上下文影响,有助于理解图中节点之间的复杂关系。
多头注意力
假设你是一个读书俱乐部的成员,大家都在读这本小说,并且每个俱乐部成员被要求专注于小说的不同方面——一个关注角色发展,另一个关注情节的转折,另一个则关注主题元素。当你们聚在一起讨论时,你获得了对这本书的多维理解。
同样,在GNN中,多头注意力允许模型有多个“头”或注意力机制,专注于邻居节点的不同方面或特征。这些不同的头可以学习图中不同的模式或关系,且它们的输出通常会被聚合,以形成对每个节点在大图中角色的更全面理解。
概念2:GAT作为卷积GNN的变种
GAT通过引入注意力机制扩展了卷积GNN。在传统的卷积GNN中,如GCN,在消息传递步骤中,来自所有邻居的贡献在聚合时都是平等加权的。然而,GAT在聚合函数中加入了注意力分数,用来加权这些贡献。这仍然是置换不变的(按设计),但比GCN中的求和操作更加具有描述性。
PyG实现
PyG提供了两种版本的GAT层。这两者的区别在于使用的注意力类型和注意力分数的计算:
- GATConv——基于Veličković的论文[1],该层使用自注意力来计算整个图的注意力分数。它也可以配置为使用多头注意力,从而使用多个“头”来专注于输入节点的不同方面。
- GATv2Conv——该层通过引入动态注意力改进了GATConv。在这里,自注意力分数在每一层的节点特定上下文中重新计算,使模型在学习如何加权节点表示时更加具有表达力,从而在每层GNN的消息传递步骤中处理节点表示。像GATConv一样,它支持多头注意力,从而更有效地捕获各种特征或方面。
与其他卷积GNN的权衡
在PyG中实现的GAT层由于使用了注意力机制,具有一定的优势。然而,也需要考虑性能的权衡。需要考虑的关键因素包括:
- 性能——GAT通常比标准的卷积GNN表现更好,因为它们可以集中关注最相关的特征。
- 训练时间——提高的性能伴随着更多的训练时间,因为计算注意力机制的复杂性增加。
- 可扩展性——计算成本也影响可扩展性,使得GAT不太适合非常大或稠密的图。
4.5.2 过度平滑
你已经学习了如何在消息传递步骤中改变聚合操作,加入更复杂的方法,如注意力机制。然而,在应用多轮消息传递时,通常会有性能退化的风险。这种现象被称为“过度平滑”,因为在多轮消息传递后,更新的特征可能会收敛到相似的值。图4.13展示了这一现象的一个示例。
如我们所知,消息传递发生在GNN的每一层。事实上,拥有更多层的GNN比拥有较少层的GNN更容易发生过度平滑现象。这也是为什么GNN通常比传统的深度学习模型更浅的原因之一。
过度平滑的另一个原因是在一个问题中存在显著的长程(按跳数计算)任务需要解决。例如,一个节点可能会受到远离的节点的影响。这也被称为拥有较大的“问题半径”。每当我们遇到一个图,其中的节点即使相距多跳,仍然能够对其他节点产生很大的影响时,就应该认为该问题的半径较大。例如,社交媒体网络可能会有较大的问题半径,因为某些个体,如名人,尽管彼此之间连接遥远,仍然能够影响其他个体。通常,当图足够大以至于存在远距离连接的节点时,这种情况才会发生。
通常来说,如果你认为某个问题可能面临过度平滑的风险,要小心你引入多少层到GNN中,即要注意其深度。然而,请注意,某些架构似乎比其他架构更不容易发生过度平滑。例如,GraphSAGE会对固定数量的邻居进行采样并聚合它们的信息,这种采样可以缓解过度平滑。另一方面,GCN由于没有这一采样过程,面临的风险更大,尽管注意力机制在一定程度上降低了这一风险,但GAT仍然可能受到过度平滑的影响,因为它的聚合仍然是局部的。
4.5.3 关键GAT方程概述
在本节中,我们将简要介绍Veličković等人[1]在GAT论文中给出的关键方程,并将它们与我们已经讨论的GAT概念联系起来。GAT通过使用注意力机制来学习在更新节点特征时哪些邻居节点更为重要。它们通过计算注意力分数(方程1-3),然后使用这些分数来加权和合并邻居节点的特征(方程4-6)。使用多头注意力增强了模型的表达力和鲁棒性,使其能够同时从多个角度学习。这种方法可能计算开销较大,但通常会提高GNN在节点分类、链路预测等任务上的性能。
注意力系数计算(方程4.1–4.3)
使用GAT的第一步是计算每对连接节点的注意力分数或系数。这些系数表示一个节点应该给其邻居多少“注意力”或重要性。原始的注意力分数[1]是这样计算的:
这里,eij表示从节点i到其邻居j的原始注意力分数: hi和hj是节点i和节点j的特征向量(表示)。 W是一个可学习的权重矩阵,用于将每个节点的特征线性转换到更高维的空间。 α是一个注意力机制(通常是一个神经网络),它计算每对节点的权重分数。 这个思想是评估节点i应该从节点j考虑多少信息。归一化的注意力系数[1]是这样计算的:
一旦我们得到原始分数eij,我们使用softmax函数对其进行归一化: αij表示归一化的注意力系数,量化了节点j的特征对节点i的重要性。 softmax确保给定节点i的所有注意力系数的总和为1,使得它们可以在不同节点之间进行比较。 以下是注意力系数[1]的详细计算:
这里,注意力机制α是通过一个具有参数a的单层前馈神经网络实现的。术语
涉及将节点i和节点j的转换特征向量进行拼接,然后应用线性变换,接着是非线性激活(泄漏修正线性单元 [leaky ReLU])。
节点表示更新(方程4.4–4.6)
在计算了注意力系数后,下一步是利用它们从邻居节点聚合信息,并使用注意力更新节点表示[1]:
这个方程计算了节点i的新表示hi':
- 术语
该术语表示邻居节点特征的加权和,其中每个特征向量的权重由其对应的注意力系数αij决定。 σ是一个非线性激活函数(如ReLU或sigmoid),它为模型引入非线性,帮助模型学习复杂的模式。 多头注意力机制[1]的计算公式为:
为了稳定学习过程,GAT使用了多头注意力机制,如前所述: 在这里,K个注意力头独立地计算不同的注意力系数集和相应的加权和。 所有头的结果被拼接在一起,形成一个更丰富、更具表现力的节点表示。 以下是最终层中多头注意力的平均计算方式[1]:
在网络的最终预测层中,我们不再将不同头的输出拼接在一起,而是取它们的平均值。这样可以减少最终输出的维度,并简化模型的预测过程。
总结
图注意力网络(GAT)是图神经网络(GNN)的一种特殊类型,它结合了注意力机制,在学习过程中专注于最相关的节点。 GAT在某些节点具有不成比例重要性的领域中表现优异,例如社交网络、欺诈检测和异常检测。 本章使用了一个来自Yelp评论的数据集,重点是检测芝加哥酒店和餐厅的虚假评论。评论被表示为节点,边表示共享特征(例如共同的作者或商家)。 GAT被应用于这个数据集,将节点(评论)分类为欺诈或合法。GAT模型相较于基准模型,如逻辑回归、XGBoost和图卷积网络(GCN),表现出改进。 由于GAT需要计算所有边的注意力分数,因此内存开销较大。为了解决这个问题,使用了PyTorch Geometric(PyG)中的NeighborLoader类进行小批量处理。 PyG中的GAT层,如GATConv和GATv2Conv,应用了不同类型的注意力来解决图学习问题。 可以采用SMOTE和类别重排等策略来解决类别不平衡问题。在我们的案例中,类别重排显著提高了模型性能。