稀疏矩阵的Scipy中numpy where的等价物

4 投票
3 回答
5845 浏览
提问于 2025-04-18 16:06

我在找一个和 numpy.where 类似的东西,可以用在 scipy 提供的稀疏矩阵上(scipy.sparse)。有没有什么方法可以让你像用 if-then-else 语句那样处理这些矩阵呢?

更新
更具体一点:我需要 where 作为一个 if-then-else 的向量化函数,也就是说,在一些任务中,比如说,对于矩阵 A 中每个等于 K 的值,在矩阵 B 中放一个对应的值,否则放在 C 中。你可以用类似 find 的方法来找出满足条件的那些位置的索引,然后再取反找到剩下的那些,但对于稀疏矩阵来说,难道没有更简洁的方法吗?

3 个回答

1

这是我用 np.where 的替代方法,专门针对稀疏矩阵,使用了 find 函数:

def where(mask, val, arr):
    """ Sparse `where` """
    out = arr.copy()
    rows, cols, _ = find(mask)
    for r, c in zip(rows, cols):
        out[r, c] = val
    return out
7

你可以使用 scipy.sparse.find 这个函数(http://docs.scipy.org/doc/scipy-0.9.0/reference/generated/scipy.sparse.find.html)。这个函数会返回稀疏矩阵中所有非负的值。你可以根据某些条件来使用它,比如:

 import scipy.sparse as sp
 A = sp.csr_matrix([[1, 2, 0], [0, 0, 3], [4, 0, 5]])
 B = A > 2 #the condition
 indexes = sp.find(B)
4

这里有一个函数,它的功能是复制 np.where,当 condxy 是大小相同的稀疏矩阵时。

def where1(cond, x):
    # elements of x where cond
    row, col, data = sparse.find(cond) # effectively the coo format
    data = np.ones(data.shape, dtype=x.dtype)
    zs = sparse.coo_matrix((data, (row, col)), shape=cond.shape)
    xx = x.tocsr()[row, col][0]
    zs.data[:] = xx
    zs = zs.tocsr()
    zs.eliminate_zeros()
    return zs

def where2(cond, y):
    # elements of y where not cond
    row, col, data = sparse.find(cond)
    zs = y.copy().tolil() # faster for this than the csr format
    zs[row, col] = 0
    zs = zs.tocsr()
    zs.eliminate_zeros()
    return zs

def where(cond, x, y):
    # like np.where but with sparse matrices
    ws1 = where1(cond, x)
    # ws2 = where1(cond==0, y) # cond==0 is likely to produce a SparseEfficiencyWarning
    ws2 = where2(cond, y)
    ws = ws1 + ws2
    # test against np.where
    w = np.where(cond.A, x.A, y.A)
    assert np.allclose(ws.A, w)
    return ws

m,n, d = 100,90, 0.5
cs = sparse.rand(m,n,d)
xs = sparse.rand(m,n,d)
ys = sparse.rand(m,n,d)
print where(cs, xs, ys).A

即使在弄清楚如何编写 where1 之后,我还需要进一步思考,如何在不产生警告的情况下处理 not 的部分。虽然它没有稠密的 where 那么通用或快速,但它展示了以这种方式构建稀疏矩阵时所涉及的复杂性。

值得注意的是

np.where(cond) == np.nonzero(cond) # see doc

xs.nonzero() == (xs.row, xs.col) # for coo format
sparse.find(xs) == (row, col, data)

带有 x 和 y 的 np.where 相当于:

[xv if c else yv for (c,xv,yv) in zip(condition,x,y)]  # see doc

C语言的代码可能是通过 nditer 来实现的,这个功能上类似于 zip,可以逐个遍历输入和输出的所有元素。如果输出接近稠密(例如 y=2),那么 np.where 的速度会比这个稀疏的替代方案快。

撰写回答