CSR Matrix存储结构
参考:scipy.sparse.csr_matrix — SciPy v1.8.0 Manual
CSR Matrix的存储结构包含三列数据:
- Index Pointers:表示数据索引的偏置,该列表中每个元素表示“当前行最后一个数据的索引”相对“上一行行最后一个数据的索引”的偏移量(差值)
- Indices:列表中每个元素对应一个数据的列索引
- Data:列表中每个元素对应一个数据的值
以上图为例:
- Index Pointers以0开头,前两个元素(0和2)表示第一行有2个数据。接下来第二和第三个元素(2和3)表示第二行有3-2=1个数据。以此类推,第i和第i+1个元素(j和k)表示稀疏矩阵第i行有k-j个数据
- Indices中的第一个元素0表示稀疏矩阵中第一个数据的列索引为0,第二个元素2表示稀疏矩阵中的第二个数据的列索引为2。以此类推,第i个元素(r)表示稀疏矩阵中第i个数据的列索引为r
- Data中的第一个元素8表示稀疏矩阵中第一个数据的值为8,同样第i个元素(v)表示稀疏矩阵中第i个数据的值为v
代码示例:
from scipy.sparse import csr_matrix
import numpy as np
def construct_csr(data):
indptr = [0]
col_indeces = []
values = []
for row_index, line in enumerate(data):
row_value_num = 0
for col_index, value in enumerate(line):
value = int(value)
if value > 0:
row_value_num += 1
col_indeces.append(col_index)
values.append(value)
indptr.append(indptr[-1] + row_value_num)
row_num = row_index + 1
col_num = len(line)
return csr_matrix((values, col_indeces, indptr), shape=(row_num, col_num))
if __name__ == '__main__':
d_A = [[1,0,3], [0,5,7], [0,0,9], [2,4,0]]
s_A = csr_matrix(np.array(d_A))
s_B = construct_csr(d_A)
print(f's_A:\n{s_A}\n', )
print(f's_B:\n{s_B}\n', )
print(s_A.toarray()==s_B.toarray())
执行结果:
s_A:
(0, 0) 1
(0, 2) 3
(1, 1) 5
(1, 2) 7
(2, 2) 9
(3, 0) 2
(3, 1) 4
s_B:
(0, 0) 1
(0, 2) 3
(1, 1) 5
(1, 2) 7
(2, 2) 9
(3, 0) 2
(3, 1) 4
[[ True True True]
[ True True True]
[ True True True]
[ True True True]]