Storing a numpy sparse matrix in HDF5 (PyTables)

Posted 2024-05-29 03:55:01


I'm having trouble storing a numpy csr_matrix with PyTables. I get this error:

TypeError: objects of type ``csr_matrix`` are not supported in this context, sorry; supported objects are: NumPy array, record or scalar; homogeneous list or tuple, integer, float, complex or string

My code:

f = tables.openFile(path,'w')

atom = tables.Atom.from_dtype(self.count_vector.dtype)
ds = f.createCArray(f.root, 'count', atom, self.count_vector.shape)
ds[:] = self.count_vector
f.close()

Any ideas?

Thanks


3 Answers

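DaveP's answer is almost right, but it can cause problems for very sparse matrices: if the last column(s) or row(s) are empty, they are dropped when the matrix is rebuilt. So, to be sure everything works, the "shape" attribute must be stored too.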
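A minimal sketch of the problem, using made-up values and only SciPy (no HDF5 involved). As far as I can tell, when no shape is given SciPy infers the dimensions from the arrays themselves; for a CSR matrix the row count can still be recovered from indptr, so here it is the trailing empty columns that go missing (for CSC it would be the rows):

```python
import numpy as np
from scipy import sparse

# The last two columns and the last row are entirely zero
dense = np.array([[1, 0, 0, 0],
                  [0, 2, 0, 0],
                  [0, 0, 0, 0]])
m = sparse.csr_matrix(dense)

# Rebuilding from the three arrays alone leaves SciPy to infer the shape
guessed = sparse.csr_matrix((m.data, m.indices, m.indptr))

# Passing the stored shape restores the original dimensions
exact = sparse.csr_matrix((m.data, m.indices, m.indptr), shape=m.shape)

print(m.shape, guessed.shape, exact.shape)
```

The inferred shape comes out narrower than the original, while the version built with the explicit shape matches it.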

This is the code I regularly use:

import tables as tb
from numpy import array
from scipy import sparse

def store_sparse_mat(m, name, store='store.h5'):
    msg = "This code only works for csr matrices"
    assert(m.__class__ == sparse.csr.csr_matrix), msg
    with tb.openFile(store,'a') as f:
        for par in ('data', 'indices', 'indptr', 'shape'):
            full_name = '%s_%s' % (name, par)
            try:
                n = getattr(f.root, full_name)
                n._f_remove()
            except AttributeError:
                pass

            arr = array(getattr(m, par))
            atom = tb.Atom.from_dtype(arr.dtype)
            ds = f.createCArray(f.root, full_name, atom, arr.shape)
            ds[:] = arr

def load_sparse_mat(name, store='store.h5'):
    with tb.openFile(store) as f:
        pars = []
        for par in ('data', 'indices', 'indptr', 'shape'):
            pars.append(getattr(f.root, '%s_%s' % (name, par)).read())
    m = sparse.csr_matrix(tuple(pars[:3]), shape=pars[3])
    return m

Adapting it to csc matrices is trivial.
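A rough sketch of what that adaptation could look like (my assumption, not part of the original answer): a csc_matrix exposes the same four attributes, so only the type check and the constructor change. The function names store_sparse_csc and load_sparse_csc are made up for illustration, and this version uses the PyTables 3.x method names (open_file, create_carray, remove_node) rather than the 2.x names in the code above.

```python
import numpy as np
import tables as tb
from scipy import sparse

ATTRS = ('data', 'indices', 'indptr', 'shape')

def store_sparse_csc(m, name, store='store.h5'):
    """Hypothetical csc variant of store_sparse_mat (PyTables 3.x names)."""
    msg = "This code only works for csc matrices"
    assert isinstance(m, sparse.csc_matrix), msg
    with tb.open_file(store, 'a') as f:
        for par in ATTRS:
            full_name = '%s_%s' % (name, par)
            # Drop a node left over from an earlier call, if there is one
            if full_name in f.root:
                f.remove_node(f.root, full_name)
            arr = np.array(getattr(m, par))
            atom = tb.Atom.from_dtype(arr.dtype)
            ds = f.create_carray(f.root, full_name, atom, arr.shape)
            ds[:] = arr

def load_sparse_csc(name, store='store.h5'):
    """Hypothetical csc variant of load_sparse_mat."""
    with tb.open_file(store) as f:
        pars = [getattr(f.root, '%s_%s' % (name, par)).read()
                for par in ATTRS]
    # Convert the stored shape array back to a plain tuple of ints
    shape = tuple(int(n) for n in pars[3])
    return sparse.csc_matrix(tuple(pars[:3]), shape=shape)
```

Calling store_sparse_csc(m, 'demo') followed by load_sparse_csc('demo') should give back a matrix with the same shape and entries. I have only thought this through for matrices with at least one nonzero entry.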

I have updated Pietro Battiston's excellent answer for Python 3.6 and PyTables 3.x, since some PyTables function names changed in the upgrade from 2.x.

import numpy as np
from scipy import sparse
import tables

def store_sparse_mat(M, name, filename='store.h5'):
    """
    Store a csr matrix in HDF5

    Parameters
    ----------
    M : scipy.sparse.csr_matrix
        sparse matrix to be stored

    name: str
        node prefix in HDF5 hierarchy

    filename: str
        HDF5 filename
    """
    # isinstance avoids the scipy.sparse.csr submodule path, which newer SciPy releases deprecate
    assert isinstance(M, sparse.csr_matrix), 'M must be a csr matrix'
    with tables.open_file(filename, 'a') as f:
        for attribute in ('data', 'indices', 'indptr', 'shape'):
            full_name = f'{name}_{attribute}'

            # remove existing nodes
            try:
                n = getattr(f.root, full_name)
                n._f_remove()
            except AttributeError:
                pass

            # add nodes
            arr = np.array(getattr(M, attribute))
            atom = tables.Atom.from_dtype(arr.dtype)
            ds = f.create_carray(f.root, full_name, atom, arr.shape)
            ds[:] = arr

def load_sparse_mat(name, filename='store.h5'):
    """
    Load a csr matrix from HDF5

    Parameters
    ----------
    name: str
        node prefix in HDF5 hierarchy

    filename: str
        HDF5 filename

    Returns
    -------
    M : scipy.sparse.csr_matrix
        loaded sparse matrix
    """
    with tables.open_file(filename) as f:

        # get nodes
        attributes = []
        for attribute in ('data', 'indices', 'indptr', 'shape'):
            attributes.append(getattr(f.root, f'{name}_{attribute}').read())

    # construct sparse matrix
    M = sparse.csr_matrix(tuple(attributes[:3]), shape=attributes[3])
    return M

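A CSR matrix can be fully reconstructed from its data, indices and indptr attributes. These are just regular numpy arrays, so there should be no problem storing them as 3 separate arrays in PyTables and then passing them back to the csr_matrix constructor. See the scipy docs.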

Edit: Pietro's answer points out that the shape member should be stored as well.
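For illustration, here is a small example of what those attributes hold and how they go back into the constructor. The matrix values are made up, and nothing here touches HDF5:

```python
import numpy as np
from scipy import sparse

m = sparse.csr_matrix(np.array([[0, 5, 0],
                                [7, 0, 0]]))

# Each attribute is an ordinary numpy.ndarray
for attr in ('data', 'indices', 'indptr'):
    value = getattr(m, attr)
    print(attr, type(value).__name__, value.tolist())

# Passing the three arrays plus the shape back rebuilds the same matrix
rebuilt = sparse.csr_matrix((m.data, m.indices, m.indptr), shape=m.shape)
```

Here data holds the nonzero values row by row, indices holds their column numbers, and indptr marks where each row starts and ends within the other two arrays.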
