numpy:用索引数组高效求和

11 投票
3 回答
705 浏览
提问于 2025-04-18 07:47

假设我有两个矩阵 M 和 N(这两个矩阵的列数都大于1)。我还有一个索引矩阵 I,它有两列——一列对应 M,另一列对应 N。N 的索引是唯一的,但 M 的索引可能会出现多次。我想进行的操作是:

for i,j in w:
  M[i] += N[j]

有没有比用 for 循环更有效的方法来完成这个操作呢?

3 个回答

1

为了更清楚,我们先定义一下

>>> m_ind, n_ind = w.T

然后这个 for 循环

for i, j in zip(m_ind, n_ind):
    M[i] += N[j]

会更新条目 M[np.unique(m_ind)]。写入的值是 N[n_ind],这些值必须根据 m_ind 来分组。(其实,n_ind 的存在和 m_ind 无关,简单来说,你可以直接把 N = N[n_ind]。)恰好有一个 SciPy 类可以做到这一点:scipy.sparse.csr_matrix

示例数据:

>>> m_ind, n_ind = array([[0, 0, 1, 1], [2, 3, 0, 1]])
>>> M = np.arange(2, 6)
>>> N = np.logspace(2, 5, 4)

这个 for 循环的结果是 M 变成了 [110002 1103 4 5]。我们用 csr_matrix 也能得到同样的结果。正如我之前说的,n_ind 不重要,所以我们先把它去掉。

>>> N = N[n_ind]
>>> from scipy.sparse import csr_matrix
>>> update = csr_matrix((N, m_ind, [0, len(N)])).toarray()

CSR 构造函数会根据需要的值和索引来构建一个矩阵;它的第三个参数是一个压缩列索引,这意味着值 N[0:len(N)] 的索引是 m_ind[0:len(N)]。重复的值会被相加:

>>> update
array([[ 110000.,    1100.]])

这个矩阵的形状是 (1, len(np.unique(m_ind))),可以直接加进去:

>>> M[np.unique(m_ind)] += update.ravel()
>>> M
array([110002,   1103,      4,      5])
3

这里提到的 m_ind, n_ind = w.T 是在把一个叫做 `w` 的东西进行转置,转置就是把行和列互换。接下来,M += np.bincount(m_ind, weights=N[n_ind], minlength=len(M)) 这行代码的意思是:把 `m_ind` 这个数组中的每个值当作索引,去统计 `N[n_ind]` 中对应位置的权重,最后把结果加到 `M` 里。这里的 `minlength=len(M)` 是为了确保结果的长度和 `M` 一样。

14

为了完整起见,在numpy版本大于等于1.8时,你还可以使用np.addat方法:

In [8]: m, n = np.random.rand(2, 10)

In [9]: m_idx, n_idx = np.random.randint(10, size=(2, 20))

In [10]: m0 = m.copy()

In [11]: np.add.at(m, m_idx, n[n_idx])

In [13]: m0 += np.bincount(m_idx, weights=n[n_idx], minlength=len(m))

In [14]: np.allclose(m, m0)
Out[14]: True

In [15]: %timeit np.add.at(m, m_idx, n[n_idx])
100000 loops, best of 3: 9.49 us per loop

In [16]: %timeit np.bincount(m_idx, weights=n[n_idx], minlength=len(m))
1000000 loops, best of 3: 1.54 us per loop

除了明显的性能劣势,它还有几个优点:

  1. np.bincount会把权重转换为双精度浮点数,而.at方法会使用你数组的原生类型。这使得处理复杂数字时,.at成为最简单的选择。
  2. np.bincount只是把权重加在一起,而at方法可以用于所有的ufuncs(通用函数),所以你可以反复进行multiply(乘法)、logical_and(逻辑与)或者你想做的任何操作。

但对于你的使用场景来说,np.bincount可能是更合适的选择。

撰写回答