【论文笔记】Deep Graph Infomax

655 阅读7分钟

Deep Graph Infomax

论文链接:Arxiv

论文代码:DGI

前置知识

论文中引入了互信息的概念,具体推导与证明不在这里讨论。

DGI将原始的图G=(X,A)生成了负样本图G'=(X',A),其中X代表特征矩阵,X'代表打乱行顺序的特征矩阵X,A代表邻接矩阵。

论文基于互信息的理论做了如下优化:将正样本图G与负样本图G'通过GCN编码分别得到正负embedding:H与H',并将H通过readout function聚合后得到特征s,s与H和H'同时输入判别器计算归属概率,其中H中每个节点embedding与s的概率要接近1,H'中每个节点embedding与s的概率要接近0。

论文阅读

论文理论知识较为硬核,但实际实现不困难,主要分为了几个部分:图编码器、Readout function、判别器以及优化策略,下面会逐一介绍。

图编码器

论文提出了一种无监督的图学习方法,首先针对给定的图G=(X,A),X代表节点的原始特征矩阵,A代表邻接矩阵,这一部分的目标就是学习一个编码器R,将N×N的特征矩阵编码为N×F的嵌入层,其中F为嵌入向量的维度。

在这里,由于构建了负样本图,因此论文顺带提及了构建负样本的方法,通过一个干扰函数C,将X的行索引打乱,生成了X',相对于将这些节点的特征向量又随机分配给了节点。

正负样本图共享一个GCN网络,分别得到H与H'。

Readout function

其实就是一个聚合函数,目的就是为了得到论文中描述的Global Information,而与之相对应的是Local Information。

事实上,Local Information就是每个节点学习到的embedding H,而将所有节点embedding聚合起来就得到了Global Information,聚合方法就是Readout function。

至于Readout function有多复杂,论文也给出了很干脆的解释:

image.png

就是一个加权平均再加上SIGMOD函数,而在实验部分额外做了说明:

While we have found this readout to perform the best across all our experiments, we assume that 
its power will diminish with the increase in graph size, and in those cases, more sophisticated
readout architectures such as set2vec (Vinyals et al., 2015) or DiffPool (Ying et al., 2018b) 
are likely to be more appropriate

虽然我们发现这个Readout function在我们所有的实验中表现最好,但我们假设它的能力会随着图大小的增加而减弱,在这些情况下,更复杂的Readout function,如 set2vec (Vinyals et al., 2015) 或 DiffPool (Ying et al., 2018b) 可能更合适。

判别器

判别器其实是一个双线性层,引用的也是前人的工作,可以将两个向量映射成一个分数,具体如下:

image.png

其中W是要学习的权重矩阵。

优化策略

至此,模型所有的准备工作已经完成,就可以开始训练了。

  • 首先,通过干扰函数C生成负样本,(X',A)=C(X,A)
  • 通过图编码器学习正负样本的embedding层H与H',H=GCN(X,A)、H'=GCN(X',A)
  • 再由Readout function获取Global Representation,s=R(H)
  • 将s分别与H和H'输入判别器,得到逻辑概率分数logits1和logits2,代表正负样本中节点分配到s的概率,logits1=D(H,s)、logits2=D(H',s)
  • 由互信息理论可以得到,logits1要接近1,logits2要接近0,因此在这里做二分类,用交叉熵损失做反向传播,优化上述网络中的参数

这样,等到模型收敛后,就可以拿学习到的embedding做下游任务了。

实验部分

论文中提供了node classification实验结果,在得到embedding之后,通过直接使用这些embedding来训练和测试简单的线性(逻辑回归)分类器得到最终分类结果,数据就不在这里展示了,可以移步原文阅读。

代码实现

数据读取

数据集采用的是cora、citeseer和pubmed,读取方式较为固定,这里直接贴出代码

def load_data(dataset):
    """
    读取数据
    :param dataset: 数据集名称
    :return: 邻接权重矩阵、特征矩阵、节点类型、训练集、验证集、测试集
    """
    names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
    objects = []
    for name in names:
        with open('data/{}/ind.{}.{}'.format(dataset, dataset, name), 'rb') as f:
            if sys.version_info > (3, 0):
                objects.append(pkl.load(f, encoding='latin1'))
            else:
                objects.append(pkl.load(f))

    x, y, tx, ty, allx, ally, graph = tuple(objects)
    test_idx_reorder = parse_index_file('data/{}/ind.{}.test.index'.format(dataset, dataset))
    test_idx_range = np.sort(test_idx_reorder)

    if dataset == 'citeseer':
        # Fix citeseer dataset (there are some isolated nodes in the graph)
        # Find isolated nodes, add them as zero-vecs into the right position
        test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder) + 1)
        tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
        tx_extended[test_idx_range - min(test_idx_range), :] = tx
        tx = tx_extended
        ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
        ty_extended[test_idx_range - min(test_idx_range), :] = ty
        ty = ty_extended

    features = sp.vstack((allx, tx)).tolil()
    features[test_idx_reorder, :] = features[test_idx_range, :]
    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))

    labels = np.vstack((ally, ty))
    labels[test_idx_reorder, :] = labels[test_idx_range, :]

    idx_test = test_idx_range.tolist()
    idx_train = range(len(y))

    return adj, features, labels, idx_train, idx_test

数据读取部分还包括了处理特征、正则化以及稀疏矩阵处理等等步骤,详情请阅读源码,这里不再赘述,由于图表示学习下游任务主要包括node classification以及link prediction,因此数据读取会返回以下参数:

  • X:处理后的特征矩阵
  • A:处理后的邻接矩阵
  • Y:节点标签
  • id_train:训练集节点id
  • id_test:测试集节点id
  • train_edges:训练集边集
  • train_edges_false:训练集边集负例
  • val_edges:验证集边集
  • val_edges_false:验证集边集负例
  • test_edges:测试集边集
  • test_edges_false:测试集边集负例

模型实现

模型采用Pytorch实现,具体代码如下:

import torch.nn as nn


class DGI(nn.Module):
    def __init__(self, n_input, hidden, activation):
        """
        :param n_input: feature size
        :param hidden: gcn hidden layer size
        :param activation: active function
        """
        super(DGI, self).__init__()
        self.gcn = GCN(n_input, hidden, activation)
        self.read = AvgReadout()
        self.sigm = nn.Sigmoid()
        self.disc = Discriminator(hidden)

    def forward(self, pos, neg, adj, sparse):
        """
        :param pos: positive feature
        :param neg: negative feature
        :param adj: adj matrix
        :param sparse: is sparse
        :return: probability
        """
        h_1 = self.gcn(pos, adj, sparse)

        c = self.read(h_1)
        c = self.sigm(c)

        h_2 = self.gcn(neg, adj, sparse)

        ret = self.disc(c, h_1, h_2)
        return ret

    def embed(self, x, adj, sparse):
        """
        when model is trained, get gcn hidden layer output
        :param x: feature
        :param adj: adj matrix
        :param sparse: is sparse
        :return:
        """
        h_1 = self.gcn(x, adj, sparse)

        return h_1.detach()

在forward函数中,h_1表示正样本图的embedding,c表示Readout function聚合得到的Global Representation,h_2表示负样本图的embedding,ret表示过判别器之后正负样本节点的概率分数。

同时,模型提供了embed方法,在测试状态下,可以直接获得图的embedding表示。

GCN层的实现如下:

import torch
import torch.nn as nn


class GCN(nn.Module):
    def __init__(self, in_ft, out_ft, act, bias=True):
        super(GCN, self).__init__()
        self.fc = nn.Linear(in_ft, out_ft, bias=False)
        self.act = nn.PReLU() if act == 'prelu' else act

        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(out_ft))
            self.bias.data.fill_(0.0)
        else:
            self.register_parameter('bias', None)

        for m in self.modules():
            self.weights_init(m)

    @staticmethod
    def weights_init(m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, seq, adj, sparse=False):
        seq_fts = self.fc(seq)
        if sparse:
            out = torch.spmm(adj, torch.squeeze(seq_fts, 0))
        else:
            out = torch.bmm(adj, seq_fts)
        if self.bias is not None:
            out += self.bias

        return self.act(out)

采用的是原始GCN卷积,并没有做改动,Readout function实现如下:

class AvgReadout(nn.Module):
    def __init__(self):
        super(AvgReadout, self).__init__()

    def forward(self, seq):
        return torch.mean(seq, 1)

如之前所描述的一样,采用的是加权平均的方法,判别器实现如下:

class Discriminator(nn.Module):
    def __init__(self, n_h):
        super(Discriminator, self).__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 1)

        for m in self.modules():
            self.weights_init(m)

    @staticmethod
    def weights_init(m):
        if isinstance(m, nn.Bilinear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, c, h_pl, h_mi):
        c_x = torch.unsqueeze(c, 1)
        c_x = c_x.expand_as(h_pl)

        sc_1 = torch.squeeze(self.f_k(h_pl, c_x), 1)
        sc_2 = torch.squeeze(self.f_k(h_mi, c_x), 1)

        logits = torch.cat((sc_1, sc_2), 0)

        return logits

其主要核心在于学习self.f_k这一双线性层参数。

训练方法

训练代码与Pytorch其他模型类型,这里只贴出核心代码:

model = DGI(ft_size, conf.hid_units, conf.non_linearity)
optimiser = torch.optim.Adam(model.parameters(), lr=conf.lr, weight_decay=conf.l2_coef)

b_loss = torch.nn.BCEWithLogitsLoss()

for epoch in range(conf.nb_epochs):
    model.train()
    optimiser.zero_grad()
    idx = np.random.permutation(nb_nodes)

    shuffle_fts = ft[idx, :]
    lbl_1 = torch.ones(nb_nodes)
    lbl_2 = torch.zeros(nb_nodes)
    lbl = torch.cat((lbl_1, lbl_2), 0)

    if torch.cuda.is_available():
        shuffle_fts = shuffle_fts.cuda()
        lbl = lbl.cuda()

    logits = model(ft, shuffle_fts, adjacent, conf.sparse)
    loss = b_loss(logits, lbl)
    print('Epoch: {}, Loss:{}'.format(epoch, loss.item()))

    loss.backward()
    optimiser.step()

代码中干扰函数的具体实现就是idx = np.random.permutation(nb_nodes)这一句话,对节点id做了一次全排列来打乱顺序,得到负样本特征矩阵shuffle_ftslbl_1lbl_2代表理想情况下正样本图节点和负样本节点分配到Global Representation的概率,分别为全1和全0,模型得到的概率分数logits应该靠近这一概率,因此得到二分类交叉熵损失loss

测试方法

如论文中所描述,这里训练了一个逻辑回归分类器来对embedding进行node classification,代码如下:

embeds = model.embed(x, adj, sparse)
for i in range(50):
    log = LogReg(conf.hid_units, nb_classes)
    opt = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=0.0)

    for _ in range(100):
        log.train()
        opt.zero_grad()

        logits = log(train_embeds)
        loss = x_loss(logits, train_pred)

        loss.backward()
        opt.step()

    logits = log(test_embeds)
    pred = torch.argmax(logits, dim=1)
    acc = torch.sum(pred == test_pred).float() / test_pred.shape[0]
    accuracy.append(acc * 100)
    print('Epoch:{}, Accuracy: {}'.format(i, acc.item()))
    tot += acc

acc = torch.stack(accuracy)
print('Average accuracy: {}, Mean Accuracy: {}, Std Accuracy: {}'.format((tot / 50).item(), acc.mean().item(),
                                                                         acc.std(dim=0).item()))

总结

一篇读起来简单,但是实用性非常强的论文,互信息的应用也对我之后的硕士论文有了很大的启发,还是期望以后能够多到这种价值高的论文。