在numpy数组中查找第n个最小元素

28 投票
3 回答
43391 浏览
提问于 2025-04-17 23:12

我需要在一个一维的 numpy.array 中找到第 n 小的元素。

举个例子:

a = np.array([90,10,30,40,80,70,20,50,60,0])

我想要得到第 5 小的元素,所以我期望的结果是 40

我现在的解决方案是这样的:

result = np.max(np.partition(a, 5)[:5])

不过,先找出 5 个最小的元素,然后再从中取最大的那个,感觉有点麻烦。有没有更好的方法呢?我是不是漏掉了什么可以直接达到我目标的函数?

虽然有一些类似标题的问题,但我没有找到能回答我问题的内容。

编辑:

我应该一开始就提到,性能对我来说非常重要;所以,虽然 heapq 的解决方案不错,但对我来说不太适用。

import numpy as np
import heapq

def find_nth_smallest_old_way(a, n):
    return np.max(np.partition(a, n)[:n])

# Solution suggested by Jaime and HYRY    
def find_nth_smallest_proper_way(a, n):
    return np.partition(a, n-1)[n-1]

def find_nth_smallest_heapq(a, n):
    return heapq.nsmallest(n, a)[-1]
#    
n_iterations = 10000

a = np.arange(1000)
np.random.shuffle(a)

t1 = timeit('find_nth_smallest_old_way(a, 100)', 'from __main__ import find_nth_smallest_old_way, a', number = n_iterations)
print 'time taken using partition old_way: {}'.format(t1)    
t2 = timeit('find_nth_smallest_proper_way(a, 100)', 'from __main__ import find_nth_smallest_proper_way, a', number = n_iterations)
print 'time taken using partition proper way: {}'.format(t2) 
t3 = timeit('find_nth_smallest_heapq(a, 100)', 'from __main__ import find_nth_smallest_heapq, a', number = n_iterations)  
print 'time taken using heapq : {}'.format(t3)

结果:

time taken using partition old_way: 0.255564928055
time taken using partition proper way: 0.129678010941
time taken using heapq : 7.81094002724

3 个回答

2

你不需要调用 numpy.max() 这个函数:

def nsmall(a, n):
    return np.partition(a, n)[n]
5

你可以使用 heapq.nsmallest 这个功能:

>>> import numpy as np
>>> import heapq
>>> 
>>> a = np.array([90,10,30,40,80,70,20,50,60,0])
>>> heapq.nsmallest(5, a)[-1]
40
43

如果我没有理解错的话,你想做的事情是:

>>> a = np.array([90,10,30,40,80,70,20,50,60,0])
>>> np.partition(a, 4)[4]
40

np.partition(a, k) 这个函数会把数组 a 中第 k+1 小的元素放到 a[k] 的位置,所有比它小的值会放到 a[:k] 里,而比它大的值会放到 a[k+1:] 里。需要注意的是,由于数组是从0开始计数的,所以第五个元素的索引是4。

撰写回答