Loading

CSR Matrix存储结构

参考:scipy.sparse.csr_matrix — SciPy v1.8.0 Manual

CSR Matrix的存储结构包含三列数据:

  • Index Pointers:表示数据索引的偏置,该列表中每个元素表示“当前行最后一个数据的索引”相对“上一行行最后一个数据的索引”的偏移量(差值)
  • Indices:列表中每个元素对应一个数据的列索引
  • Data:列表中每个元素对应一个数据的值

以上图为例:

  • Index Pointers以0开头,前两个元素(0和2)表示第一行有2个数据。接下来第二和第三个元素(2和3)表示第二行有3-2=1个数据。以此类推,第i和第i+1个元素(j和k)表示稀疏矩阵第i行有k-j个数据
  • Indices中的第一个元素0表示稀疏矩阵中第一个数据的列索引为0,第二个元素2表示稀疏矩阵中的第二个数据的列索引为2。以此类推,第i个元素(r)表示稀疏矩阵中第i个数据的列索引为r
  • Data中的第一个元素8表示稀疏矩阵中第一个数据的值为8,同样第i个元素(v)表示稀疏矩阵中第i个数据的值为v

代码示例:

from scipy.sparse import csr_matrix
import numpy as np

def construct_csr(data):
    indptr = [0]
    col_indeces = []
    values = []
    for row_index, line in enumerate(data):
        row_value_num = 0
        for col_index, value in enumerate(line):
            value = int(value)
            if value > 0:
                row_value_num += 1
                col_indeces.append(col_index)
                values.append(value)
        indptr.append(indptr[-1] + row_value_num)
    row_num = row_index + 1
    col_num = len(line)
    return csr_matrix((values, col_indeces, indptr), shape=(row_num, col_num))

if __name__ == '__main__':
    d_A = [[1,0,3], [0,5,7], [0,0,9], [2,4,0]]
    s_A = csr_matrix(np.array(d_A))
    s_B = construct_csr(d_A)
    print(f's_A:\n{s_A}\n', )
    print(f's_B:\n{s_B}\n', )
    print(s_A.toarray()==s_B.toarray())

执行结果:

s_A:
  (0, 0)        1
  (0, 2)        3
  (1, 1)        5
  (1, 2)        7
  (2, 2)        9
  (3, 0)        2
  (3, 1)        4

s_B:
  (0, 0)        1
  (0, 2)        3
  (1, 1)        5
  (1, 2)        7
  (2, 2)        9
  (3, 0)        2
  (3, 1)        4

[[ True  True  True]
 [ True  True  True]
 [ True  True  True]
 [ True  True  True]]
posted @ 2022-05-03 20:20  云野Winfield  阅读(732)  评论(0编辑  收藏  举报