获取多维NumPy数组中最大值的位置

97 投票
4 回答
86717 浏览
提问于 2025-04-16 03:23

我怎么才能找到一个多维NumPy数组中最大值的位置(索引)呢?

4 个回答

2

你可以简单地写一个函数(这个函数只适用于二维情况):

def argmax_2d(matrix):
    maxN = np.argmax(matrix)
    (xD,yD) = matrix.shape
    if maxN >= xD:
        x = maxN//xD
        y = maxN % xD
    else:
        y = maxN
        x = 0
    return (x,y)
7

(编辑)我之前提到的是一个已经被删除的旧答案。而被接受的答案是在我之后的。我同意,argmax比我的答案更好。

这样做是不是更容易读懂/更直观呢?

numpy.nonzero(a.max() == a)
(array([1]), array([0]))

或者,

numpy.argwhere(a.max() == a)
200

argmax() 方法应该能帮到你。

更新

(看了评论后)我觉得 argmax() 方法也适用于多维数组。链接的文档里有这个的例子:

>>> a = array([[10,50,30],[60,20,40]])
>>> maxindex = a.argmax()
>>> maxindex
3

更新 2

(感谢 KennyTM 的评论)你可以用 unravel_index(a.argmax(), a.shape) 来获取索引,结果会是一个元组:

>>> from numpy import unravel_index
>>> unravel_index(a.argmax(), a.shape)
(1, 0)

撰写回答