The most efficient way to select a subset of rows from a scipy CSR sparse matrix

Published 2024-04-19 14:59:18


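Given a column index i and a scipy CSR sparse matrix X, I want the mean vector of all rows that have a 1 in column i (in my code below, any value > 0).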

This is the solution I came up with, but it seems quite slow to me:

import numpy as np

def mean_of_matching_rows(X, i):
    # a) create a mask of rows: True if column i is > 0, False otherwise
    mask = (X[:, i] > 0).transpose().toarray()[0]

    # b) now get the indices of these rows as a list
    indices_of_row = list(np.where(mask > 0)[0])
    if len(indices_of_row) == 0:
        return

    # c) use the indices of these rows to create the mean vector
    mean_vector = X[indices_of_row,].mean(axis=0)
    return mean_vector

Is there a way to make this more efficient or more readable?

Edit: I would like to avoid calling toarray() on the entire matrix.
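To make the target concrete, here is the intended result on a tiny hand-made matrix (the values are illustrative, not from the question), computed the slow dense way as a reference. Densifying is fine for a toy example but is exactly what the question wants to avoid on real data.

```python
import numpy as np
from scipy import sparse

# Illustrative 4x3 matrix; column i = 1 is positive in rows 1 and 2
X = sparse.csr_matrix(np.array([[1, 0, 2],
                                [0, 1, 0],
                                [3, 1, 0],
                                [0, 0, 4]], dtype=float))
i = 1

dense = X.toarray()                  # dense reference only; avoid on large matrices
selected = dense[dense[:, i] > 0]    # keeps rows [0, 1, 0] and [3, 1, 0]
mean_vector = selected.mean(axis=0)  # column-wise average of the kept rows
print(mean_vector)
```

Any faster method should reproduce this vector, here [1.5, 1.0, 0.0].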


2 answers

I think this should be enough:

Y = X.toarray()
# axis=0 averages the selected rows column by column
MeanVec = Y[Y[:,i] > 0].mean(axis=0)

Edit: to avoid densifying the whole matrix, convert only column i into a boolean row mask:

X[(X.getcol(i) > 0).toarray().ravel()].mean(axis=0)
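The averaging axis is easy to mix up. On a scipy sparse matrix, mean(axis=0) collapses the rows and yields the averaged row vector the question asks for, while mean(axis=1) yields one mean per row. Here is a quick shape check on a toy matrix (the values are arbitrary):

```python
import numpy as np
from scipy import sparse

X = sparse.csr_matrix(np.arange(12, dtype=float).reshape(4, 3))

col_means = X.mean(axis=0)  # average over rows: one value per column, shape (1, 3)
row_means = X.mean(axis=1)  # average over columns: one value per row, shape (4, 1)
print(col_means.shape, row_means.shape)
```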

Here are three relatively fast solutions that work directly on the CSR arrays (indptr, indices, data):

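from scipy import sparse
import numpy as np

# All functions below read the globals `a` (CSR matrix) and `i` (column index)
# defined further down. They assume canonical CSR and positive stored values,
# so "stored in column i" matches OP's "> 0" test.


def pp():
    # per row: max of (indices == i) over the row's slice -> does the row hit column i?
    # caveat: reduceat does not handle empty rows (it returns the next element
    # instead of an empty reduction, or raises IndexError for trailing empty rows)
    m = np.maximum.reduceat(a.indices==i,a.indptr[:-1])
    cnt = np.count_nonzero(m)
    # expand the per-row mask to one flag per stored entry
    m = m.repeat(np.diff(a.indptr))
    # sum the kept entries per column, then divide by the number of selected rows
    return np.bincount(a.indices[m],a.data[m],a.shape[1])/cnt

def qq():
    # row of each stored entry in column i: locate its position within indptr
    idx = a.indptr.searchsorted(*(a.indices==i).nonzero(),"right")-1
    # gather those rows' column indices and values, then sum per column
    # (np.concatenate raises ValueError if no row qualifies)
    return np.bincount(
        np.concatenate([a.indices[a.indptr[r]:a.indptr[r+1]] for r in idx]),
        np.concatenate([a.data[a.indptr[r]:a.indptr[r+1]] for r in idx]),
        a.shape[1]) / len(idx)

def mm():
    # a @ one-hot(i) extracts column i densely; nonzero entries mark the selected rows
    idx = (a@(np.arange(a.shape[1])==i))!=0
    # (indicator / count) @ a is the average of the selected rows
    return idx/np.count_nonzero(idx)@a

def OP():
    # a) create a mask of rows: True if column i is > 0, False otherwise
    mask = (a[:, i] > 0).transpose().toarray()[0]

    # b) now get the indices of these rows as a list
    indices_of_row = list(np.where(mask > 0)[0])
    if len(indices_of_row) == 0:
        return

    # c) use the indices of these rows to create the mean vector
    return a[indices_of_row,].mean(axis=0)

from timeit import timeit

n = 1000
a = sparse.random(n,n, format="csr")  # default density 0.01, values in [0, 1)
i = np.random.randint(0,n)

# timeit returns total seconds for 1000 calls, i.e. milliseconds per call
print("mask  ",timeit(pp,number=1000),"ms")
print("concat",timeit(qq,number=1000),"ms")
print("matmul",timeit(mm,number=1000),"ms")
print("OP    ",timeit(OP,number=1000),"ms")

assert np.allclose(pp(),OP())
assert np.allclose(qq(),OP())
assert np.allclose(mm(),OP())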

Sample run:

mask   0.08981675305403769 ms
concat 0.04179211403243244 ms
matmul 0.14177833893336356 ms
OP     0.9761617160402238 ms
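np.maximum.reduceat, used in pp, does not handle empty rows. For an empty row it returns the element at that row's start offset instead of an empty reduction, and it raises IndexError for trailing empty rows. The following is a minimal sketch of the same mask idea that instead gives every stored entry its row id with np.repeat, so empty rows are simply skipped. It assumes canonical CSR (no duplicate entries) and mirrors the question's early return by returning None when no row qualifies. The helper name mean_of_rows_with_col is hypothetical.

```python
import numpy as np
from scipy import sparse


def mean_of_rows_with_col(a, i):
    """Hypothetical helper: mean of the rows of CSR matrix `a` whose column `i` is > 0.

    Assumes canonical CSR (no duplicate entries). Returns None if no row qualifies.
    """
    a = a.tocsr()
    n_rows = a.shape[0]
    # row id of every stored entry; rows with zero stored entries contribute nothing
    row_of_entry = np.repeat(np.arange(n_rows), np.diff(a.indptr))
    # stored entries that sit in column i and are positive (OP's "> 0" test)
    hit = (a.indices == i) & (a.data > 0)
    rows = np.zeros(n_rows, dtype=bool)
    rows[row_of_entry[hit]] = True
    cnt = np.count_nonzero(rows)
    if cnt == 0:
        return None
    # keep the entries of the selected rows and sum them per column
    keep = rows[row_of_entry]
    sums = np.bincount(a.indices[keep], weights=a.data[keep], minlength=a.shape[1])
    return sums / cnt


# Rows 1 and 4 are empty; the trailing empty row would make the reduceat version fail
a_demo = sparse.csr_matrix(np.array([[1, 0, 2],
                                     [0, 0, 0],
                                     [3, 1, 0],
                                     [0, 1, 4],
                                     [0, 0, 0]], dtype=float))
print(mean_of_rows_with_col(a_demo, 1))  # averages rows [3, 1, 0] and [0, 1, 4]
```

The result is a plain 1-D ndarray like pp's, not the 1×n np.matrix that OP's version returns.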
