TEMPORAL GRAPH NETWORKS FOR DEEP LEARNING ON DYNAMIC GRAPHS
目录
1、问题描述
TGN在TGAT的基础上增加一个记忆模块,以更好的利用之前的交互信息。
2、TGN
这个图比较抽象,按照程序的执行逻辑重新总结一下。
2.1 初始化memory/存储上一个batch更新后的memory
初始化memory,每个节点i都分配一个全0的向量:;2.6更新后的memory也存储在这里
初始化last_update_t,每个节点i都分配一个最近更新过的timestamp;2.6更新后的last_update_t也存储在这里
2.2 将memory和节点原始feature相加得到输入ndata-feat
2.3 embedding(TGAT)并计算loss
模型为TGAT的多头自注意力机制:TGAT阅读笔记,学习得到节点的embedding。
计算loss并进行反向传播。
2.4 message
首先把当前memory和last_update_t作为当前batch子图的节点的memory和时间戳,按照TGAT来计算每一条边上的消息:
2.5 aggregate
挑选最近的一条交互边(边时间戳最大)作为当前节点聚合后的message(也可以使用sum等传统聚合器)。
2.6 update并存储(回到2.1继续执行)
经过RNN/GRU单元的一步,得到新的隐藏表示,作为新的memory,
更新memory和last_update_t,回到2.1继续执行下一个batch。
| 问题 | 解释 |
|---|---|
| 为什么要使用记忆模块? | memory总是包含节点最近一次交互的消息等信息,起到了记忆功能,节点后面的交互和最近一次的交互有联系memory可以加到节点特征上作为输入,一定程度上起到了特征增强和的作用 |
| 为什么有了memory还要再学习一个embedding?为什么不直接用memory进行预测? | 如果只用memory作为节点表示的话,有可能会出现陈旧性记忆的问题,比如,用户好长时间不登录,这样,如果用户再次登录,如果其邻居比较活跃,也可以通过其邻居来快速学习当前的表示。 |
3、实验
3.1 数据集
用的JODIE的两个开源数据集和一个Twitter的私有工业数据集。
- Reddit和Wikipedia数据集描述见JODIE:github.com/srijankr/jo…
- Industrial数据集来自Twitter
3.2 baseline
静态方法:
- GAE
- VGAE
- Deepwalk
- node2vec
- GAT
- GraphSAGE
动态方法:
3.3 评估指标
- 动态连接预测:AP
- 动态节点分类:AUC
3.4 实验结果
和baseline的比较:
动态连接预测:
动态节点分类:
消融实验:
- 采样的邻居个数比较少的时候,有memory性能显著提升