如何在多维数组中找到k个最小数的索引?
我想创建一个数组,里面包含另一个数组中最小的k个值的索引:
import heapq
import numpy as np
a= np.array([[1, 3, 5, 2, 3],
[7, 6, 5, 2, 4],
[2, 0, 5, 6, 4]])
[t[0] for t in heapq.nsmallest(2,enumerate(a[1]),lambda(t):t[1])]
===[3, 4]
但是这样做失败了:
[t[0] for t in heapq.nsmallest(2,enumerate(a.all()),lambda(t):t[1])]
Traceback (most recent call last):
File "<pyshell#19>", line 1, in <module>
[t[0] for t in heapq.nsmallest(2,enumerate(a.all()),lambda(t):t[1])]
TypeError: 'numpy.bool_' object is not iterable
2 个回答
2
你的问题出在 a.all()
这部分:
[t[0] for t in heapq.nsmallest(2,enumerate(a.all()),lambda(t):t[1])]
这个方法会检查你数组里所有元素的真假值,也就是说,如果有一个元素是 False
(比如你有一个0),那么结果就是 False
。
如果你的数组不大,相比于 k 的话,你可以用 .argsort
来获取值。这里我会选择每一行中两个最大的元素的位置:
print a.argsort()[:,:2]
array([[0, 3],
[3, 4],
[1, 0]])
如果你想找到全局最小值的位置,先把数组压平:
a.flatten().argsort()[:2]
如果数组非常大,你可以使用 np.argpartition
来获得更好的性能,这个方法只会进行部分排序。
1
你可以使用 numpy.ndenumerate
结合一个堆,或者像David建议的那样进行部分排序:
a = np.array([[1, 3, 5, 2, 3],
[7, 6, 5, 2, 4],
[2, 0, 5, 6, 4]])
heap = [(v, k) for k,v in numpy.ndenumerate(npa)]
heapq.heapify(heap)
heapq.nsmallest(10, heap) # for k = 10
这样你就得到了:
[(0, (2, 1)),
(1, (0, 0)),
(2, (0, 3)),
(2, (1, 3)),
(2, (2, 0)),
(3, (0, 1)),
(3, (0, 4)),
(4, (1, 4)),
(4, (2, 4)),
(5, (0, 2))]