找到非零元素的索引并按值分组

4 投票
2 回答
2349 浏览
提问于 2025-04-18 03:18

我写了一段Python代码,这段代码可以接收一个numpy矩阵作为输入,然后返回一个根据对应值分组的索引列表(也就是说,output[3]会返回所有值为3的索引)。不过,我对写向量化代码不太了解,所以只能用ndenumerate来实现。这个操作大约花了9秒钟,速度太慢了。

我想到的第二个方法是使用numpy.nonzero,代码如下:

for i in range(1, max_value):
   current_array = np.nonzero(input == i)
   # save in an array

这个方法花了5.5秒,算是有了不错的提升,但还是有点慢。有没有办法不使用循环,或者有更优化的方法来获取每个值对应的索引对呢?

2 个回答

1

如果你愿意多用一点内存,可以通过广播的方式来进行向量化:

import numpy as np
input = np.random.randint(1,max_value, 100)
indices = np.arange(1, max_value)

matches = input == indices[:,np.newaxis]  # broadcasts across each index

这样,对于每个索引 i,匹配的结果就可以简单地用 np.nonzero(matches[i]) 来获取。

3

这里有一个复杂度为O(n log n)的算法来解决你的问题。显而易见,简单的循环方法复杂度是O(n),所以当数据量很大的时候,这种简单的方法会比较慢:

>>> a = np.random.randint(3, size=10)
>>> a
array([1, 2, 2, 0, 1, 0, 2, 2, 1, 1])

>>> index = np.arange(len(a))
>>> sort_idx = np.argsort(a)
>>> cnt = np.bincount(a)
>>> np.split(index[sort_idx], np.cumsum(cnt[:-1]))
[array([3, 5]), array([0, 4, 8, 9]), array([1, 2, 6, 7])]

具体速度还得看你的数据大小,不过对于比较大的数据集来说,这个算法的速度还是相当快的:

In [1]: a = np.random.randint(1000, size=1e6)

In [2]: %%timeit
   ...: indices = np.arange(len(a))
   ...: sort_idx = np.argsort(a)
   ...: cnt = np.bincount(a)
   ...: np.split(indices[sort_idx], np.cumsum(cnt[:-1]))
   ...: 
10 loops, best of 3: 140 ms per loop

撰写回答