获取多维NumPy数组中最大值的位置
我怎么才能找到一个多维NumPy数组中最大值的位置(索引)呢?
4 个回答
2
你可以简单地写一个函数(这个函数只适用于二维情况):
def argmax_2d(matrix):
maxN = np.argmax(matrix)
(xD,yD) = matrix.shape
if maxN >= xD:
x = maxN//xD
y = maxN % xD
else:
y = maxN
x = 0
return (x,y)
7
(编辑)我之前提到的是一个已经被删除的旧答案。而被接受的答案是在我之后的。我同意,argmax
比我的答案更好。
这样做是不是更容易读懂/更直观呢?
numpy.nonzero(a.max() == a)
(array([1]), array([0]))
或者,
numpy.argwhere(a.max() == a)
200
argmax()
方法应该能帮到你。
更新
(看了评论后)我觉得 argmax()
方法也适用于多维数组。链接的文档里有这个的例子:
>>> a = array([[10,50,30],[60,20,40]])
>>> maxindex = a.argmax()
>>> maxindex
3
更新 2
(感谢 KennyTM 的评论)你可以用 unravel_index(a.argmax(), a.shape)
来获取索引,结果会是一个元组:
>>> from numpy import unravel_index
>>> unravel_index(a.argmax(), a.shape)
(1, 0)