加速NumPy的where函数
我想从一个一维数字数组中找出所有超过某个阈值的值的索引。这个数组的长度大约有1e9
。
我在NumPy
中使用的方法是这样的:
idxs = where(data>threshold)
这个方法花了超过20分钟,这太慢了,不能接受。我该怎么加快这个函数的速度?或者,有没有更快的替代方法?
具体来说,这个速度是在一台运行10.6.7的Mac OS X上,1.86 GHz的Intel处理器,4GB内存,什么都不做的情况下测出来的。
1 个回答
7
试试使用一个掩码数组。这样可以创建一个相同数据的视图。
所以语法应该是:
b=a[a>threshold]
b 不是一个新数组(和 where 不一样),而是一个视图,显示那些在索引中符合布尔条件的元素。
举个例子:
import numpy as np
import time
a=np.random.random_sample(int(1e9))
t1=time.time()
b=a[a>0.5]
print(time.time()-t1,'seconds')
在我的机器上,这会打印出 22.389815092086792 秒
编辑
我也试了 np.where,速度一样快。我有点怀疑:你是在从数组中删除这些值吗?