在scipy中移除/设置稀疏矩阵的非零对角元素

7 投票
1 回答
2396 浏览
提问于 2025-04-17 23:58

假设我想从一个 scipy.sparse.csr_matrix 中去掉对角线。有没有什么有效的方法可以做到这一点?我看到在 sparsetools 模块里有一些 C 函数可以用来获取对角线。

根据其他的 StackOverflow 回答 这里这里,我现在的做法是这样的:

def csr_setdiag_val(csr, value=0):
    """Set all diagonal nonzero elements
    (elements currently in the sparsity pattern)
    to the given value. Useful to set to 0 mostly.
    """
    if csr.format != "csr":
        raise ValueError('Matrix given must be of CSR format.')
    csr.sort_indices()
    pointer = csr.indptr
    indices = csr.indices
    data = csr.data
    for i in range(min(csr.shape)):
        ind = indices[pointer[i]: pointer[i + 1]]
        j =  ind.searchsorted(i)
        # matrix has only elements up until diagonal (in row i)
        if j == len(ind):
            continue
        j += pointer[i]
        # in case matrix has only elements after diagonal (in row i)
        if indices[j] == i:
            data[j] = value

然后我接着做

csr.eliminate_zeros()

这样做是我能做到的最好方法吗?还是我需要自己写 Cython 代码?

1 个回答

3

根据@hpaulj的评论,我创建了一个IPython Notebook,可以在nbviewer上查看。这个Notebook展示了在提到的所有方法中,以下这个是最快的(假设mat是一个稀疏的CSR矩阵):

mat - scipy.sparse.dia_matrix((mat.diagonal()[scipy.newaxis, :], [0]), shape=(one_dim, one_dim))

撰写回答