如何使用numpy.argmax从三维数组中提取值
给定一个三维的numpy数组,我们可以用 numpy.argmax
来找出在第一个维度(也就是轴0)上最大值的索引。
那么,我该如何利用 argmax
的结果,从另一个形状相似的数组中提取这些最大值呢?
举个例子,假设我们有以下内容:
import numpy as np
ndarr1 = np.round(np.random.rand(4, 2, 3), 2)
ndarr2 = np.round(np.random.rand(4, 2, 3), 2)
argmax1 = np.argmax(ndarr1, axis=0)
ndarr1
# array([[[0.89, 0.79, 0.64],
# [0.03, 0.53, 0.1 ]],
#
# [[0.21, 0.76, 0.99],
# [0.47, 0.08, 0.48]],
#
# [[0.67, 0.94, 0.99],
# [0.98, 0.75, 0.59]],
#
# [[0.09, 0.96, 0.98],
# [0.43, 0.98, 0.71]]])
argmax1
# array([[0, 3, 1],
# [2, 3, 3]], dtype=int64)
ndarr2
# array([[[0.79, 0.72, 0.82],
# [0.25, 0.7 , 0.56]],
#
# [[0.46, 0.11, 0.31],
# [0.55, 0.76, 0.13]],
#
# [[0.09, 0.23, 0.35],
# [0.3 , 0.42, 0.06]],
#
# [[0.24, 0.1 , 0.92],
# [0.82, 0.52, 0.7 ]]])
我应该调用哪个函数,才能得到通过使用 argmax1
从 ndarr2
中提取出来的数组呢?
# array([[0.79, 0.1 , 0.31],
# [0.3 , 0.52, 0.7 ]])
需要注意的是,我不能使用 numpy.amax
,因为我需要用 argmax
的结果从一个与 ndarr1
形状相同的不同数组中提取值。
2 个回答
0
a, b = np.ogrid[0:2, 0:3] # 2nd and 3rd dimensions of the array
ndarray2[argmax1, a, b]
如果你想要一句话解决问题:
ndarray2[argmax1, *np.ogrid[0:2, 0:3]]
下面是对问题的回答。如果数组的形状是 (A, B, C, D),那么你可以这样写:
ndarray2[argmaz1, *np.ogrid[:B, :C, :D])
(这里的 0:
是多余的,可以不写。对于任何有三个或更多维度的数组都是这样。)
由于 ogrid
的定义有点特殊,如果原始数组只有两个维度,你就可以省略 *
。使用单个切片参数的 ogrid
会返回一个数组,而不是一个单独的数组列表。
如果你有一个数组,但不知道它的大小或维度,你就需要手动构建切片:
slices = [slice(0, i) for i in ndarray2.shape[1:]]
grid = np.ogrid[*slices]
if len(slices) == 1: # make array into singleton list
grid = [grid]
ndarray2[argmax1, *grid]
2
根据文档,你可以在扩展了argmax1
的维度后使用np.take_along_axis
。
max2 = np.take_along_axis(ndarr2, np.expand_dims(argmax1, axis=0), axis=0)
print(max2)
# [[[0.79 0.1 0.31]
# [0.3 0.52 0.7 ]]]
如果在创建argmax1
时设置了keepdims=True
,那么就可以避免手动扩展维度,也就是说:
argmax1 = np.argmax(ndarr1, axis=0, keepdims=True)
max2 = np.take_along_axis(ndarr2, argmax1, axis=0)