PET-AI解读 | rs-fRMI的GNN和TCN建模(模型构建细节)

1,428 阅读4分钟
  • 相关论文:A deep graph neural network architecture for modelling spatio-temporal dynamics in resting-state functional MRI data
  • 相关repo:github.com/tjiagoM/spa…
  • 笔记人:陈亦新

主函数中生成了这样的模型:

model = SpatioTemporalModel(run_cfg=run_cfg,
                                encoding_model=None
                                ).to(run_cfg['device_run'])

这个SpatioTemporalModel非常的长,和以前解读工程一样,我们只看forward函数就行,下面片段中的注释为我的理解:

class SpatioTemporalModel(nn.Module):
    def forward(self, data):
        # 这里的三个数据,和我们在上一小节讲解的一致
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
   
        if self.multimodal_size > 0:
            xn, x = x[:, :self.multimodal_size], x[:, self.multimodal_size:]
            xn = self.multimodal_lin(xn)
            xn = self.activation(xn)
            xn = self.multimodal_batch(xn)
            xn = F.dropout(xn, p=self.dropout, training=self.training)

        # Processing temporal part
        if self.conv_strategy != ConvStrategy.NONE:
            # 这里似乎是吧LSTM也理解为Conv了
            if self.conv_strategy == ConvStrategy.LSTM:
                # 采用LSTM作为特征提取的方法
                x = x.view(-1, self.num_time_length, 1)
                # 可以见下面的LSTM-补充1,就是用0初始化LSTM的隐含特征和cell state
                h0, c0 = self.init_lstm_hidden(x)
                # 可见下面LSTM-补充2,一个LSTM模块
                x, (_, _) = self.temporal_conv(x, (h0, c0))
                x = x.contiguous()
            else:
                # 不是LSTM,那么就是卷积策略了。这里卷积策略包含了一般的1D卷积,也包含了TCN的1D卷积模型。可见下方CNN-补充1和TCN-补充1
                x = x.view(-1, 1, self.num_time_length)
                x = self.temporal_conv(x)

            # Concatenating for the final embedding per node
            # 这个变量self.size_before_lin_temporal的数值,卷积通道x时间序列长度。这时候卷积通道数已经放大了8倍,时间序列长度已经下采样了4次,变成原来的16分之1了。
            x = x.view(x.size()[0], self.size_before_lin_temporal)
            # 是一个全连接层,也可能从_get_lin_temporal函数中得到的组件,详情可以看到下面的方法_get_lin_temporal
            x = self.lin_temporal(x)
            x = self.activation(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        elif self.encoding_strategy == EncodingStrategy.STATS:
        # 全连接层self.stats_lin+1D BN层
            x = self.stats_lin(x)
            x = self.activation(x)
            x = self.stats_batch(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        elif self.encoding_strategy == EncodingStrategy.VAE3layers:
        # 这个也简单,就是VAE自编码器来做的特征提取
            mu, logvar = self.encoder_model.encode(x)
            x = self.encoder_model.reparameterize(mu, logvar)
        elif self.encoding_strategy == EncodingStrategy.AE3layers:
        # 和上面类似,是autoENcoder的
            x = self.encoder_model.encode(x)

        if self.multimodal_size > 0:
            x = torch.cat((xn, x), dim=1)
        # 到这一步的时候,我们的x是已经从ts当中提取好的特征。
        # 图网络用了两个经典中的经典,GAT和GCN。GCN我之前有一篇ISBI的论文用的就是这个,后来就没再看过了。嘎嘎
        if self.sweep_type in [SweepType.GAT, SweepType.GCN]:
        # 总之,图网络的特征提取,其实和transformer的attention map非常类似。这里在宏观讲述模型结构的时候,暂时先不细讲,之后在仔细的考虑TCN和GNN的代码实现细节。
            if self.edge_weights:
                # 这个带上edge-weights的概念,也就是会输入两个节点之间的连接的强弱。
                x = self.gnn_conv1(x, edge_index, edge_weight=edge_attr.view(-1))
            else:
                # 没有edgeweights的概念的,则是,仅仅告诉模型这两个节点有连接有关系,但是并不会进一步的去诉说强弱
                x = self.gnn_conv1(x, edge_index)
            x = self.activation(x)
            x = F.dropout(x, training=self.training)
            # 看来这里的图网络,也是一个非常浅层的,只有1层或者2层的网络。
            if self.num_gnn_layers == 2:
                if self.edge_weights:
                    x = self.gnn_conv2(x, edge_index, edge_weight=edge_attr.view(-1))
                else:
                    x = self.gnn_conv2(x, edge_index)
                x = self.activation(x)
                x = F.dropout(x, training=self.training)
        # 此外,作者还考虑了叫做PNANodeModel的特征提取器
        elif self.sweep_type == SweepType.META_NODE:
            x = self.meta_layer(x, edge_index, edge_attr)
        # 此外,作者还考虑了叫做MetaLayer的特征提取器
        elif self.sweep_type == SweepType.META_EDGE_NODE:
            x, edge_attr, _ = self.meta_layer(x, edge_index, edge_attr)
        # 这里就是和上一章节讲解的graph pool的方式,有平均,相加和DiffPool
        if self.pooling == PoolingStrategy.MEAN:
            x = global_mean_pool(x, data.batch)
        elif self.pooling == PoolingStrategy.ADD:
            x = global_add_pool(x, data.batch)
        elif self.pooling in [PoolingStrategy.DIFFPOOL, PoolingStrategy.DP_MAX, PoolingStrategy.DP_ADD, PoolingStrategy.DP_MEAN, PoolingStrategy.DP_IMPROVED]:
        # 我们还记得上一章遗留了一个问题,就是DiffPool只能处理稠密邻接矩阵,而咱们的是稀疏的。所以转换的方式在这里,可见下面的to_dense_ad部分
            adj_tmp = pyg_utils.to_dense_adj(edge_index, data.batch, edge_attr=edge_attr)
            if edge_attr is not None: # Because edge_attr only has 1 feature per edge
                adj_tmp = adj_tmp[:, :, :, 0]
            x_tmp, batch_mask = pyg_utils.to_dense_batch(x, data.batch)
            # self.diff_pool就是DiffPool这个组件,下一小节继续细讲
            x, link_loss, ent_loss = self.diff_pool(x_tmp, adj_tmp, batch_mask)

            x = F.dropout(x, p=self.dropout, training=self.training)
            x = self.activation(self.pre_final_linear(x))
        elif self.pooling == PoolingStrategy.CONCAT:
            x, _ = to_dense_batch(x, data.batch)
            x = x.view(-1, self.NODE_EMBED_SIZE * self.num_nodes)
            x = self.activation(self.pre_final_linear(x))

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.final_linear(x)

        if self.final_sigmoid:
            return torch.sigmoid(x) if self.pooling not in [PoolingStrategy.DIFFPOOL, PoolingStrategy.DP_MAX, PoolingStrategy.DP_ADD, PoolingStrategy.DP_MEAN, PoolingStrategy.DP_IMPROVED] else (
                torch.sigmoid(x), link_loss, ent_loss)
        else:
            return x if self.pooling not in [PoolingStrategy.DIFFPOOL, PoolingStrategy.DP_MAX, PoolingStrategy.DP_ADD, PoolingStrategy.DP_MEAN, PoolingStrategy.DP_IMPROVED] else (x, link_loss, ent_loss)

对于上述代码段的补充扩展:

  • LSTM-补充1

            def init_lstm_hidden(x):
                h0 = torch.zeros(run_cfg['tcn_depth'], x.size(0), run_cfg['tcn_hidden_units'])
                c0 = torch.zeros(run_cfg['tcn_depth'], x.size(0), run_cfg['tcn_hidden_units'])
                return [t.to(x.device) for t in (h0, c0)]

  • LSTM-补充2
self.temporal_conv = nn.LSTM(input_size=1,
                                         hidden_size=run_cfg['tcn_hidden_units'],
                                         num_layers=run_cfg['tcn_depth'],
                                         dropout=dropout_perc,
                                         batch_first=True)
  • CNN-补充1
stride = 2
            padding = 3
            self.size_before_lin_temporal = self.channels_conv * 8 * self.final_feature_size
            self.lin_temporal = nn.Linear(self.size_before_lin_temporal, self.NODE_EMBED_SIZE - self.multimodal_size)

            self.conv1d_1 = nn.Conv1d(1, self.channels_conv, 7, padding=padding, stride=stride)
            self.conv1d_2 = nn.Conv1d(self.channels_conv, self.channels_conv * 2, 7, padding=padding, stride=stride)
            self.conv1d_3 = nn.Conv1d(self.channels_conv * 2, self.channels_conv * 4, 7, padding=padding, stride=stride)
            self.conv1d_4 = nn.Conv1d(self.channels_conv * 4, self.channels_conv * 8, 7, padding=padding, stride=stride)
            self.batch1 = BatchNorm1d(self.channels_conv)
            self.batch2 = BatchNorm1d(self.channels_conv * 2)
            self.batch3 = BatchNorm1d(self.channels_conv * 4)
            self.batch4 = BatchNorm1d(self.channels_conv * 8)

            self.temporal_conv = nn.Sequential(self.conv1d_1, self.activation, self.batch1, nn.Dropout(dropout_perc),
                                               self.conv1d_2, self.activation, self.batch2, nn.Dropout(dropout_perc),
                                               self.conv1d_3, self.activation, self.batch3, nn.Dropout(dropout_perc),
                                               self.conv1d_4, self.activation, self.batch4, nn.Dropout(dropout_perc))

            self.init_weights()
  • TCN-补充1
#self.size_before_lin_temporal = self.channels_conv * 8 * self.final_feature_size
            #self.lin_temporal = nn.Linear(self.size_before_lin_temporal, self.NODE_EMBED_SIZE - self.multimodal_size)
            if run_cfg['tcn_hidden_units'] == 8:
                self.size_before_lin_temporal = self.channels_conv * (2 ** (run_cfg['tcn_depth'] - 1)) * self.num_time_length
            else:
                self.size_before_lin_temporal = run_cfg['tcn_hidden_units'] * self.num_time_length

            self.lin_temporal = self._get_lin_temporal(run_cfg)

            tcn_layers = []
            for i in range(run_cfg['tcn_depth']):
                if run_cfg['tcn_hidden_units'] == 8:
                    tcn_layers.append(self.channels_conv * (2 ** i) )
                else:
                    tcn_layers.append(run_cfg['tcn_hidden_units'])

            self.temporal_conv = TemporalConvNet(1,
                                                 tcn_layers,
                                                 kernel_size=run_cfg['tcn_kernel'],
                                                 dropout=self.dropout,
                                                 norm_strategy=run_cfg['tcn_norm_strategy'])
  • _get_lin_temporal
def _get_lin_temporal(self, run_cfg):
        if run_cfg['tcn_final_transform_layers'] == 1:
            lin_temporal = nn.Linear(self.size_before_lin_temporal,
                                          self.NODE_EMBED_SIZE - self.multimodal_size)
        elif run_cfg['tcn_final_transform_layers'] == 2:
            lin_temporal = nn.Sequential(
                nn.Linear(self.size_before_lin_temporal, int(self.size_before_lin_temporal / 2)),
                self.activation, nn.Dropout(self.dropout),
                nn.Linear(int(self.size_before_lin_temporal / 2), self.NODE_EMBED_SIZE - self.multimodal_size))
        elif run_cfg['tcn_final_transform_layers'] == 3:
            lin_temporal = nn.Sequential(
                nn.Linear(self.size_before_lin_temporal, int(self.size_before_lin_temporal / 2)),
                self.activation, nn.Dropout(self.dropout),
                nn.Linear(int(self.size_before_lin_temporal / 2), int(self.size_before_lin_temporal / 3)),
                self.activation, nn.Dropout(self.dropout),
                nn.Linear(int(self.size_before_lin_temporal / 3), self.NODE_EMBED_SIZE - self.multimodal_size))

        return lin_temporal
  • to_dense_adj
import torch_geometric.utils as pyg_utils
pyg_utils.to_dense_adj

这个方法的目的是:Converts batched sparse adjacency matrices given by edge indices and edge attributes to a single dense batched adjacency matrix。

官方文档的介绍地址在:torch_geometric.utils.to_dense_adj — pytorch_geometric documentation (pytorch-geometric.readthedocs.io)

综上所述,就是时间序列在这个模型当中经过的全部过程。先是对时间序列进行编码,也就是抽取特征。抽取之后,选择合适的图网络再此进行特征提取。最后使用DiffPool进行特征整合。