如何获取NumPy数组中N个最大值的索引?
NumPy 提供了一种方法,可以通过 np.argmax
来获取数组中最大值的索引。
我想要类似的功能,但希望能返回前 N
个最大值的索引。
举个例子,如果我有一个数组 [1, 3, 2, 4, 5]
,那么 nargmax(array, n=3)
会返回索引 [4, 3, 1]
,这些索引对应的元素是 [5, 4, 3]
。
21 个回答
82
更简单一点:
idx = (-arr).argsort()[:n]
这里的 n 是最大值的数量。
513
我能想到的最简单的方法是:
>>> import numpy as np
>>> arr = np.array([1, 3, 2, 4, 5])
>>> arr.argsort()[-3:][::-1]
array([4, 3, 1])
这个方法需要对整个数组进行排序。我在想,numpy
有没有提供一种内置的部分排序的方法;到目前为止,我还没有找到。
如果这个解决方案速度太慢(特别是对于小的 n
),那么可以考虑用 Cython 编写一些代码来解决这个问题。
997
较新的NumPy版本(1.8及以上)有一个叫做 argpartition
的函数,可以用来处理这个问题。如果你想找出四个最大的元素的索引,可以这样做:
>>> a = np.array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])
>>> a
array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])
>>> ind = np.argpartition(a, -4)[-4:]
>>> ind
array([1, 5, 8, 0])
>>> top4 = a[ind]
>>> top4
array([4, 9, 6, 9])
这个函数和 argsort
不一样,它在最坏情况下的运行时间是线性的,但返回的索引并不是排序好的。你可以通过计算 a[ind]
来看到这一点。如果你也需要这些索引是排序好的,可以在之后再进行排序:
>>> ind[np.argsort(a[ind])]
array([1, 8, 5, 0])
以这种方式获取前 k 个元素并按顺序排列的时间复杂度是 O(n + k log k)。