如何获取NumPy数组中N个最大值的索引?

815 投票
21 回答
760953 浏览
提问于 2025-04-16 22:43

NumPy 提供了一种方法,可以通过 np.argmax 来获取数组中最大值的索引。

我想要类似的功能,但希望能返回前 N 个最大值的索引。

举个例子,如果我有一个数组 [1, 3, 2, 4, 5],那么 nargmax(array, n=3) 会返回索引 [4, 3, 1],这些索引对应的元素是 [5, 4, 3]

21 个回答

82

更简单一点:

idx = (-arr).argsort()[:n]

这里的 n 是最大值的数量。

513

我能想到的最简单的方法是:

>>> import numpy as np
>>> arr = np.array([1, 3, 2, 4, 5])
>>> arr.argsort()[-3:][::-1]
array([4, 3, 1])

这个方法需要对整个数组进行排序。我在想,numpy有没有提供一种内置的部分排序的方法;到目前为止,我还没有找到。

如果这个解决方案速度太慢(特别是对于小的 n),那么可以考虑用 Cython 编写一些代码来解决这个问题。

997

较新的NumPy版本(1.8及以上)有一个叫做 argpartition 的函数,可以用来处理这个问题。如果你想找出四个最大的元素的索引,可以这样做:

>>> a = np.array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])
>>> a
array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])

>>> ind = np.argpartition(a, -4)[-4:]
>>> ind
array([1, 5, 8, 0])

>>> top4 = a[ind]
>>> top4
array([4, 9, 6, 9])

这个函数和 argsort 不一样,它在最坏情况下的运行时间是线性的,但返回的索引并不是排序好的。你可以通过计算 a[ind] 来看到这一点。如果你也需要这些索引是排序好的,可以在之后再进行排序:

>>> ind[np.argsort(a[ind])]
array([1, 8, 5, 0])

以这种方式获取前 k 个元素并按顺序排列的时间复杂度是 O(n + k log k)。

撰写回答