如何在多维数组中找到k个最小数的索引?

1 投票
2 回答
524 浏览
提问于 2025-04-18 10:37

我想创建一个数组,里面包含另一个数组中最小的k个值的索引:

import heapq
import numpy as np

a= np.array([[1, 3, 5, 2, 3],
       [7, 6, 5, 2, 4],
       [2, 0, 5, 6, 4]])

[t[0] for t in heapq.nsmallest(2,enumerate(a[1]),lambda(t):t[1])]
===[3, 4]

但是这样做失败了:

[t[0] for t in heapq.nsmallest(2,enumerate(a.all()),lambda(t):t[1])]

Traceback (most recent call last):
  File "<pyshell#19>", line 1, in <module>
    [t[0] for t in heapq.nsmallest(2,enumerate(a.all()),lambda(t):t[1])]
TypeError: 'numpy.bool_' object is not iterable

2 个回答

2

你的问题出在 a.all() 这部分:

[t[0] for t in heapq.nsmallest(2,enumerate(a.all()),lambda(t):t[1])]

这个方法会检查你数组里所有元素的真假值,也就是说,如果有一个元素是 False(比如你有一个0),那么结果就是 False

如果你的数组不大,相比于 k 的话,你可以用 .argsort 来获取值。这里我会选择每一行中两个最大的元素的位置:

print a.argsort()[:,:2]

array([[0, 3],
       [3, 4],
       [1, 0]])

如果你想找到全局最小值的位置,先把数组压平:

a.flatten().argsort()[:2]

如果数组非常大,你可以使用 np.argpartition 来获得更好的性能,这个方法只会进行部分排序。

1

你可以使用 numpy.ndenumerate 结合一个堆,或者像David建议的那样进行部分排序:

a = np.array([[1, 3, 5, 2, 3],
       [7, 6, 5, 2, 4],
       [2, 0, 5, 6, 4]])
heap = [(v, k) for k,v in numpy.ndenumerate(npa)]
heapq.heapify(heap)
heapq.nsmallest(10, heap) # for k = 10

这样你就得到了:

[(0, (2, 1)),
 (1, (0, 0)),
 (2, (0, 3)),
 (2, (1, 3)),
 (2, (2, 0)),
 (3, (0, 1)),
 (3, (0, 4)),
 (4, (1, 4)),
 (4, (2, 4)),
 (5, (0, 2))]

撰写回答