使用numpy高效地进行数组阈值过滤
我需要过滤一个数组,去掉那些低于某个阈值的元素。我的代码现在是这样的:
threshold = 5
a = numpy.array(range(10)) # testing data
b = numpy.array(filter(lambda x: x >= threshold, a))
问题是,这样做会创建一个临时列表,使用了一个带有lambda函数的过滤器(速度慢)。
因为这个操作其实很简单,也许有numpy的函数可以更高效地完成这个任务,但我找不到。
我想,另一种方法是先对数组进行排序,找到阈值的索引,然后从那个索引开始返回一个切片,但即使这样对小数据量来说可能会更快(而且其实也不明显),随着输入数据量的增加,这种方法的效率肯定会下降。
更新:我也做了一些测量,当输入有100,000,000个条目时,排序加切片的速度还是比纯Python的过滤器快两倍。
r = numpy.random.uniform(0, 1, 100000000)
%timeit test1(r) # filter
# 1 loops, best of 3: 21.3 s per loop
%timeit test2(r) # sort and slice
# 1 loops, best of 3: 11.1 s per loop
%timeit test3(r) # boolean indexing
# 1 loops, best of 3: 1.26 s per loop
2 个回答
0
你还可以使用 np.where
来找到条件为真的位置,然后进行更复杂的索引操作。
import numpy as np
b = a[np.where(a >= threshold)]
np.where
的一个实用功能是可以用来替换值(比如说,替换那些不符合标准的值)。例如,a[a <= 5] = 0
是直接修改了 a
的内容,而 np.where
则会返回一个新数组,形状和原来的相同,但某些值可能会被改变。
a = np.array([3, 7, 2, 6, 1])
b = np.where(a >= 5, a, 0) # array([0, 7, 0, 6, 0])
在性能方面,它也非常出色。
a, threshold = np.random.uniform(0,1,100000000), 0.5
%timeit a[a >= threshold]
# 1.22 s ± 92.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit a[np.where(a >= threshold)]
# 1.34 s ± 258 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
114
b = a[a>threshold]
这样写应该可以
我测试了如下:
import numpy as np, datetime
# array of zeros and ones interleaved
lrg = np.arange(2).reshape((2,-1)).repeat(1000000,-1).flatten()
t0 = datetime.datetime.now()
flt = lrg[lrg==0]
print datetime.datetime.now() - t0
t0 = datetime.datetime.now()
flt = np.array(filter(lambda x:x==0, lrg))
print datetime.datetime.now() - t0
我得到了
$ python test.py
0:00:00.028000
0:00:02.461000
http://docs.scipy.org/doc/numpy/user/basics.indexing.html#boolean-or-mask-index-arrays