TGAT实验记录

948 阅读1分钟

代码运行

github.com/ytchx1999/I…

github.com/StatsDLMath…

详细的数据集信息:github.com/srijankr/jo…

cd Inductive-representation-learning-on-temporal-graphs/

# data
cd processed/
wget http://snap.stanford.edu/jodie/reddit.csv  # reddit数据集
wget http://snap.stanford.edu/jodie/wikipedia.csv  # Wikipedia数据集
explore.ipynb # 探索数据集格式

# run
cd scripts/
# wikipedia
bash train_edge_wiki.sh

实现细节

数据格式

image.png

user和item统一编号,不区分不同类型的节点。每一行数据中的时间戳既代表src和dst节点的时间,也代表边发生的时间。

只给了边特征,节点特征没给,全零填充。

关于采样和minibatch

每个batch对“边” 进行正常采样和负采样:

  • 正常采样,抽取数据集中出现过的边的src和dst以及timestamp
  • 负采样,保留上一步的src,随机构造fake_dst,但是注意,dst也是带有时间戳的(在其他边中出现过)

优化思路

学习src_embed和dst_embed以及fake_dst_embed,连接预测转化为二分类问题,BCELoss。

  • concat(src_embed, dst_embed) --> FFN --> pos_prob
  • concat(src_embed, fake_dst_embed) --> FFN --> neg_prob

time encoding

ϕ(t,ti)=(cos(w1(tti)+b1),cos(w2(tti)+b2),...,cos(wd(tti)+bd))Rd\phi(t, t_i) = (cos(w_1(t-t_i)+b_1), cos(w_2(t-t_i)+b_2), ..., cos(w_d(t-t_i)+b_d)) \in R^d

t表示当前src节点的时间戳,t_i一般表示邻居dst节点的时间戳。

调试

{
    // 使用 IntelliSense 了解相关属性。 
    // 悬停以查看现有属性的描述。
    // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python: 当前文件",
            "type": "python",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "cwd": "${fileDirname}",
            "justMyCode": false,
            "args": [
            // wikipedia
            "-d", "wikipedia",
            "--bs", "200",
            "--uniform",
            "--n_degree", "20",
            "--agg_method", "attn",
            "--attn_mode", "prod",
            "--gpu", "0",
            "--n_head", "2",
            "--prefix", "hello_world",
            "--n_epoch", "1",
            ]
        }
    ]
}