从npy文件加载稀疏数组

10 投票
3 回答
7088 浏览
提问于 2025-04-16 19:11

我正在尝试加载一个之前保存的稀疏数组。保存这个稀疏数组挺简单的,但读取它就麻烦了。使用scipy.load的时候,它返回的是一个零维数组,里面包裹着我的稀疏数组。

import scipy as sp
A = sp.load("my_array"); A
array(<325729x325729 sparse matrix of type '<type 'numpy.int8'>'
with 1497134 stored elements in Compressed Sparse Row format>, dtype=object)

为了得到一个稀疏矩阵,我必须把这个零维数组展开,或者使用sp.asarray(A)。这看起来真是个复杂的做法。Scipy难道不能聪明一点,直接识别出我加载的是一个稀疏数组吗?有没有更好的方法来加载稀疏数组呢?

3 个回答

1

对于大家对 mmwrite 答案的点赞,我很惊讶没有人尝试回答实际的问题。不过既然这个问题又被提起来了,我就来试试。

这段代码复现了提问者的情况:

In [90]: x=sparse.csr_matrix(np.arange(10).reshape(2,5))
In [91]: np.save('save_sparse.npy',x)
In [92]: X=np.load('save_sparse.npy')
In [95]: X
Out[95]: 
array(<2x5 sparse matrix of type '<type 'numpy.int32'>'
    with 9 stored elements in Compressed Sparse Row format>, dtype=object)
In [96]: X[()].A
Out[96]: 
array([[0, 1, 2, 3, 4],
       [5, 6, 7, 8, 9]])

In [93]: X[()].A
Out[93]: 
array([[0, 1, 2, 3, 4],
       [5, 6, 7, 8, 9]])
In [94]: x
Out[94]: 
<2x5 sparse matrix of type '<type 'numpy.int32'>'
    with 9 stored elements in Compressed Sparse Row format

用户 user4713166 提到的 [()] 并不是提取稀疏数组的“难方法”。

np.savenp.load 是专门用来处理 ndarrays(多维数组)的。但是稀疏矩阵并不是这种数组,也不是它的子类(像 np.matrix 那样)。看起来 np.save 会把非数组对象包裹在一个 object dtype array 中,并把它和对象的序列化形式一起保存。

当我尝试保存另一种不能被序列化的对象时,会在以下位置收到错误信息:

403  # We contain Python objects so we cannot write out the data directly.
404  # Instead, we will pickle it out with version 2 of the pickle protocol.

--> 405 pickle.dump(array, fp, protocol=2)

所以对于问题 Scipy 是否足够聪明,能理解它加载了一个稀疏数组吗?,答案是否定的。np.load 并不知道稀疏数组的存在。但是 np.save 在遇到不是数组的东西时会聪明地选择放弃,而 np.load 则会尽力处理它在文件中找到的内容。

至于保存和加载稀疏数组的其他方法,提到过 io.savemat,这是一个与 MATLAB 兼容的方法,我会优先选择这个。不过这个例子也表明你可以使用普通的 Python pickling。如果你需要保存特定的稀疏格式,这可能更好。而如果你能接受 [()] 的提取步骤,np.save 也不错。 :)


https://github.com/scipy/scipy/blob/master/scipy/io/matlab/mio5.py write_sparse - 稀疏数组以 csc 格式保存。它会保存头信息,以及 A.indices.astype('i4'))A.indptr.astype('i4'))A.data.real,还有可选的 A.data.imag


在快速测试中,我发现 np.save/load 可以处理所有稀疏格式,除了 dok,在这种情况下 load 会抱怨缺少 shape。否则我没有在稀疏文件中发现任何特殊的序列化代码。

6

我们可以用()作为索引来提取隐藏在0维数组中的对象:

A = sp.load("my_array")[()]

这看起来有点奇怪,但似乎有效,而且这是一个非常简短的解决方法。

15

在scipy.io里,有两个函数叫做 mmwritemmread,它们可以用来保存和加载稀疏矩阵,格式是Matrix Market。

scipy.io.mmwrite('/tmp/my_array',x)
scipy.io.mmread('/tmp/my_array').tolil()    

mmwritemmread 可能就是你所需要的全部。这两个函数经过了充分的测试,并且使用的是一个大家都熟悉的格式。

不过,接下来介绍的方法可能会稍微快一些:

我们可以把行和列的坐标以及数据保存为一维数组,使用npz格式。

import random
import scipy.sparse as sparse
import scipy.io
import numpy as np

def save_sparse_matrix(filename,x):
    x_coo=x.tocoo()
    row=x_coo.row
    col=x_coo.col
    data=x_coo.data
    shape=x_coo.shape
    np.savez(filename,row=row,col=col,data=data,shape=shape)

def load_sparse_matrix(filename):
    y=np.load(filename)
    z=sparse.coo_matrix((y['data'],(y['row'],y['col'])),shape=y['shape'])
    return z

N=20000
x = sparse.lil_matrix( (N,N) )
for i in xrange(N):
    x[random.randint(0,N-1),random.randint(0,N-1)]=random.randint(1,100)

save_sparse_matrix('/tmp/my_array',x)
load_sparse_matrix('/tmp/my_array.npz').tolil()

这里有一些代码,建议把稀疏矩阵保存为npz文件,这样可能比用mmwrite/mmread更快:

def using_np_savez():    
    save_sparse_matrix('/tmp/my_array',x)
    return load_sparse_matrix('/tmp/my_array.npz').tolil()

def using_mm():
    scipy.io.mmwrite('/tmp/my_array',x)
    return scipy.io.mmread('/tmp/my_array').tolil()    

if __name__=='__main__':
    for func in (using_np_savez,using_mm):
        y=func()
        print(repr(y))
        assert(x.shape==y.shape)
        assert(x.dtype==y.dtype)
        assert(x.__class__==y.__class__)    
        assert(np.allclose(x.todense(),y.todense()))

结果是

% python -mtimeit -s'import test' 'test.using_mm()'
10 loops, best of 3: 380 msec per loop

% python -mtimeit -s'import test' 'test.using_np_savez()'
10 loops, best of 3: 116 msec per loop

撰写回答