在scipy中移除/设置稀疏矩阵的非零对角元素
假设我想从一个 scipy.sparse.csr_matrix
中去掉对角线。有没有什么有效的方法可以做到这一点?我看到在 sparsetools
模块里有一些 C
函数可以用来获取对角线。
根据其他的 StackOverflow 回答 这里 和 这里,我现在的做法是这样的:
def csr_setdiag_val(csr, value=0):
"""Set all diagonal nonzero elements
(elements currently in the sparsity pattern)
to the given value. Useful to set to 0 mostly.
"""
if csr.format != "csr":
raise ValueError('Matrix given must be of CSR format.')
csr.sort_indices()
pointer = csr.indptr
indices = csr.indices
data = csr.data
for i in range(min(csr.shape)):
ind = indices[pointer[i]: pointer[i + 1]]
j = ind.searchsorted(i)
# matrix has only elements up until diagonal (in row i)
if j == len(ind):
continue
j += pointer[i]
# in case matrix has only elements after diagonal (in row i)
if indices[j] == i:
data[j] = value
然后我接着做
csr.eliminate_zeros()
这样做是我能做到的最好方法吗?还是我需要自己写 Cython
代码?
1 个回答
3
根据@hpaulj的评论,我创建了一个IPython Notebook,可以在nbviewer上查看。这个Notebook展示了在提到的所有方法中,以下这个是最快的(假设mat
是一个稀疏的CSR矩阵):
mat - scipy.sparse.dia_matrix((mat.diagonal()[scipy.newaxis, :], [0]), shape=(one_dim, one_dim))