找到非零元素的索引并按值分组
我写了一段Python代码,这段代码可以接收一个numpy矩阵作为输入,然后返回一个根据对应值分组的索引列表(也就是说,output[3]会返回所有值为3的索引)。不过,我对写向量化代码不太了解,所以只能用ndenumerate来实现。这个操作大约花了9秒钟,速度太慢了。
我想到的第二个方法是使用numpy.nonzero,代码如下:
for i in range(1, max_value):
current_array = np.nonzero(input == i)
# save in an array
这个方法花了5.5秒,算是有了不错的提升,但还是有点慢。有没有办法不使用循环,或者有更优化的方法来获取每个值对应的索引对呢?
2 个回答
1
如果你愿意多用一点内存,可以通过广播的方式来进行向量化:
import numpy as np
input = np.random.randint(1,max_value, 100)
indices = np.arange(1, max_value)
matches = input == indices[:,np.newaxis] # broadcasts across each index
这样,对于每个索引 i
,匹配的结果就可以简单地用 np.nonzero(matches[i])
来获取。
3
这里有一个复杂度为O(n log n)的算法来解决你的问题。显而易见,简单的循环方法复杂度是O(n),所以当数据量很大的时候,这种简单的方法会比较慢:
>>> a = np.random.randint(3, size=10)
>>> a
array([1, 2, 2, 0, 1, 0, 2, 2, 1, 1])
>>> index = np.arange(len(a))
>>> sort_idx = np.argsort(a)
>>> cnt = np.bincount(a)
>>> np.split(index[sort_idx], np.cumsum(cnt[:-1]))
[array([3, 5]), array([0, 4, 8, 9]), array([1, 2, 6, 7])]
具体速度还得看你的数据大小,不过对于比较大的数据集来说,这个算法的速度还是相当快的:
In [1]: a = np.random.randint(1000, size=1e6)
In [2]: %%timeit
...: indices = np.arange(len(a))
...: sort_idx = np.argsort(a)
...: cnt = np.bincount(a)
...: np.split(indices[sort_idx], np.cumsum(cnt[:-1]))
...:
10 loops, best of 3: 140 ms per loop