加速NumPy的where函数

5 投票
1 回答
3663 浏览
提问于 2025-04-17 15:22

我想从一个一维数字数组中找出所有超过某个阈值的值的索引。这个数组的长度大约有1e9

我在NumPy中使用的方法是这样的:

idxs = where(data>threshold) 

这个方法花了超过20分钟,这太慢了,不能接受。我该怎么加快这个函数的速度?或者,有没有更快的替代方法?

具体来说,这个速度是在一台运行10.6.7的Mac OS X上,1.86 GHz的Intel处理器,4GB内存,什么都不做的情况下测出来的。

1 个回答

7

试试使用一个掩码数组。这样可以创建一个相同数据的视图。

所以语法应该是:

 b=a[a>threshold]

b 不是一个新数组(和 where 不一样),而是一个视图,显示那些在索引中符合布尔条件的元素。

举个例子:

import numpy as np
import time

a=np.random.random_sample(int(1e9))

t1=time.time()
b=a[a>0.5]
print(time.time()-t1,'seconds')

在我的机器上,这会打印出 22.389815092086792 秒


编辑

我也试了 np.where,速度一样快。我有点怀疑:你是在从数组中删除这些值吗?

撰写回答