pytorch 邻接矩阵转稀疏矩阵 (dense matrix to coo matrix)

774 阅读2分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。


如何将一个dense矩阵,比如:一个 N×NN \times N 的邻接矩阵a 转成 pytorch的sparse coo 矩阵。

方法一:

import numpy as np
import torch

a = np.array([[0, 1.2, 0],[2, 3.1, 0],[0.5, 0, 0]])
idx = a.nonzero() # (row, col)
data = a[idx]

# to torch tensor
idx_t = torch.LongTensor(np.vstack(idx))
data_t = torch.FloatTensor(data)
coo_a = torch.sparse_coo_tensor(idx_t, data_t, a.shape)
print(coo_a)

方法二:

import scipy.sparse as sp
import numpy as np
import torch

a = np.array([[0, 1.2, 0],[2, 3.1, 0],[0.5, 0, 0]])
coo_np = sp.coo_matrix(a)
data = coo_np.data
idx_t = torch.LongTensor(np.vstack((coo_np.row, coo_np.col)))
data_t = torch.FloatTensor(data)
coo_a=torch.sparse_coo_tensor(idx_t,data_t,a.shape)

方法三(从tensor到sparse tensor):

import numpy as np
import torch

a = torch.tensor([[0, 1.2, 0],[2, 3.1, 0],[0.5, 0, 0]])
idx = torch.nonzero(a).T  # 这里需要转置一下
data = a[idx[0],idx[1]]

coo_a = torch.sparse_coo_tensor(idx, data, a.shape)
print(coo_a)

如果想转回dense的tensor

那么可以直接调用自带的方法:

coo_a.to_dense()

什么是COO,CSR矩阵

COO矩阵简介:

    • 也称为“ijv”或“triplet”格式

      • 三个 NumPy 数组:row、col、data
      • data[i]是(row[i], col[i])位置的值
      • 允许重复条目
      • (具有.data属性的_data_matrix稀疏矩阵类 )的子类
  • 构造稀疏矩阵的快速格式
    • 构造函数接受:

      • 密集矩阵(数组)
      • 稀疏矩阵
      • 形状元组(创建空矩阵)
      • (数据, ij)元组
  • 与 CSR/CSC 格式之间的快速转换
  • 快速矩阵*向量(sparsetools)
    • 快速简单的逐项操作

      • 直接操作数据数组(快速 NumPy 机器)
  • 没有切片,没有算术(直接)
    • 利用:

      • 促进稀疏格式之间的快速转换

      • 转换为其他格式(通常是 CSR 或 CSC)时,重复的条目被汇总在一起

        • 促进有限元矩阵的有效构造

CSR矩阵

Compressed Sparse Row. For fast row slicing, faster matrix vector products

    • 面向行

        • 三个 NumPy 数组:indices、indptr、data

          • 索引是列索引的数组
          • data是对应的非零值的数组
          • indptr指向索引和数据中的行开始
          • 长度为n_row + 1,最后一项 = 值数 = 索引和数据的长度
          • 第i行的非零值是data[indptr[i]:indptr[i+1]] 和列索引indices[indptr[i]:indptr[i+1]]
          • item (i, j)可以作为data[indptr[i]+k]访问,其中k是j在indices[indptr[i]:indptr[i+1]]中的位置
        • _cs_matrix(通用 CSR/CSC 功能)的子类

          • (具有.data属性的_data_matrix稀疏矩阵类 )的子类
  • 快速矩阵向量积和其他算术(sparsetools)
    • 构造函数接受:

      • 密集矩阵(数组)
      • 稀疏矩阵
      • 形状元组(创建空矩阵)
      • (数据, ij)元组
      • (数据,索引,indptr)元组
  • 高效的行切片,面向行的操作
  • 缓慢的列切片,对稀疏结构进行昂贵的更改
    • 利用:

      • 实际计算(大多数线性求解器支持这种格式)