在numpy数组中查找第n个最小元素
我需要在一个一维的 numpy.array
中找到第 n 小的元素。
举个例子:
a = np.array([90,10,30,40,80,70,20,50,60,0])
我想要得到第 5 小的元素,所以我期望的结果是 40
。
我现在的解决方案是这样的:
result = np.max(np.partition(a, 5)[:5])
不过,先找出 5 个最小的元素,然后再从中取最大的那个,感觉有点麻烦。有没有更好的方法呢?我是不是漏掉了什么可以直接达到我目标的函数?
虽然有一些类似标题的问题,但我没有找到能回答我问题的内容。
编辑:
我应该一开始就提到,性能对我来说非常重要;所以,虽然 heapq
的解决方案不错,但对我来说不太适用。
import numpy as np
import heapq
def find_nth_smallest_old_way(a, n):
return np.max(np.partition(a, n)[:n])
# Solution suggested by Jaime and HYRY
def find_nth_smallest_proper_way(a, n):
return np.partition(a, n-1)[n-1]
def find_nth_smallest_heapq(a, n):
return heapq.nsmallest(n, a)[-1]
#
n_iterations = 10000
a = np.arange(1000)
np.random.shuffle(a)
t1 = timeit('find_nth_smallest_old_way(a, 100)', 'from __main__ import find_nth_smallest_old_way, a', number = n_iterations)
print 'time taken using partition old_way: {}'.format(t1)
t2 = timeit('find_nth_smallest_proper_way(a, 100)', 'from __main__ import find_nth_smallest_proper_way, a', number = n_iterations)
print 'time taken using partition proper way: {}'.format(t2)
t3 = timeit('find_nth_smallest_heapq(a, 100)', 'from __main__ import find_nth_smallest_heapq, a', number = n_iterations)
print 'time taken using heapq : {}'.format(t3)
结果:
time taken using partition old_way: 0.255564928055
time taken using partition proper way: 0.129678010941
time taken using heapq : 7.81094002724
3 个回答
2
你不需要调用 numpy.max()
这个函数:
def nsmall(a, n):
return np.partition(a, n)[n]
5
你可以使用 heapq.nsmallest
这个功能:
>>> import numpy as np
>>> import heapq
>>>
>>> a = np.array([90,10,30,40,80,70,20,50,60,0])
>>> heapq.nsmallest(5, a)[-1]
40
43
如果我没有理解错的话,你想做的事情是:
>>> a = np.array([90,10,30,40,80,70,20,50,60,0])
>>> np.partition(a, 4)[4]
40
np.partition(a, k)
这个函数会把数组 a
中第 k+1
小的元素放到 a[k]
的位置,所有比它小的值会放到 a[:k]
里,而比它大的值会放到 a[k+1:]
里。需要注意的是,由于数组是从0开始计数的,所以第五个元素的索引是4。