使用numpy高效地进行数组阈值过滤

86 投票
2 回答
115089 浏览
提问于 2025-04-17 05:34

我需要过滤一个数组,去掉那些低于某个阈值的元素。我的代码现在是这样的:

threshold = 5
a = numpy.array(range(10)) # testing data
b = numpy.array(filter(lambda x: x >= threshold, a))

问题是,这样做会创建一个临时列表,使用了一个带有lambda函数的过滤器(速度慢)。

因为这个操作其实很简单,也许有numpy的函数可以更高效地完成这个任务,但我找不到。

我想,另一种方法是先对数组进行排序,找到阈值的索引,然后从那个索引开始返回一个切片,但即使这样对小数据量来说可能会更快(而且其实也不明显),随着输入数据量的增加,这种方法的效率肯定会下降。

更新:我也做了一些测量,当输入有100,000,000个条目时,排序加切片的速度还是比纯Python的过滤器快两倍。

r = numpy.random.uniform(0, 1, 100000000)

%timeit test1(r) # filter
# 1 loops, best of 3: 21.3 s per loop

%timeit test2(r) # sort and slice
# 1 loops, best of 3: 11.1 s per loop

%timeit test3(r) # boolean indexing
# 1 loops, best of 3: 1.26 s per loop

2 个回答

0

你还可以使用 np.where 来找到条件为真的位置,然后进行更复杂的索引操作。

import numpy as np
b = a[np.where(a >= threshold)]

np.where 的一个实用功能是可以用来替换值(比如说,替换那些不符合标准的值)。例如,a[a <= 5] = 0 是直接修改了 a 的内容,而 np.where 则会返回一个新数组,形状和原来的相同,但某些值可能会被改变。

a = np.array([3, 7, 2, 6, 1])
b = np.where(a >= 5, a, 0)       # array([0, 7, 0, 6, 0])

在性能方面,它也非常出色。

a, threshold = np.random.uniform(0,1,100000000), 0.5

%timeit a[a >= threshold]
# 1.22 s ± 92.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit a[np.where(a >= threshold)]
# 1.34 s ± 258 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
114

b = a[a>threshold] 这样写应该可以

我测试了如下:

import numpy as np, datetime
# array of zeros and ones interleaved
lrg = np.arange(2).reshape((2,-1)).repeat(1000000,-1).flatten()

t0 = datetime.datetime.now()
flt = lrg[lrg==0]
print datetime.datetime.now() - t0

t0 = datetime.datetime.now()
flt = np.array(filter(lambda x:x==0, lrg))
print datetime.datetime.now() - t0

我得到了

$ python test.py
0:00:00.028000
0:00:02.461000

http://docs.scipy.org/doc/numpy/user/basics.indexing.html#boolean-or-mask-index-arrays

撰写回答