如何使用numpy.argmax从三维数组中提取值

0 投票
2 回答
40 浏览
提问于 2025-04-14 16:11

给定一个三维的numpy数组,我们可以用 numpy.argmax 来找出在第一个维度(也就是轴0)上最大值的索引。

那么,我该如何利用 argmax 的结果,从另一个形状相似的数组中提取这些最大值呢?

举个例子,假设我们有以下内容:

import numpy as np

ndarr1 = np.round(np.random.rand(4, 2, 3), 2)
ndarr2 = np.round(np.random.rand(4, 2, 3), 2)
argmax1 = np.argmax(ndarr1, axis=0)

ndarr1
# array([[[0.89, 0.79, 0.64],
#         [0.03, 0.53, 0.1 ]],
#
#        [[0.21, 0.76, 0.99],
#         [0.47, 0.08, 0.48]],
#
#        [[0.67, 0.94, 0.99],
#         [0.98, 0.75, 0.59]],
#
#        [[0.09, 0.96, 0.98],
#         [0.43, 0.98, 0.71]]])

argmax1
# array([[0, 3, 1],
#        [2, 3, 3]], dtype=int64)

ndarr2
# array([[[0.79, 0.72, 0.82],
#         [0.25, 0.7 , 0.56]],
#
#        [[0.46, 0.11, 0.31],
#         [0.55, 0.76, 0.13]],
#
#        [[0.09, 0.23, 0.35],
#         [0.3 , 0.42, 0.06]],
#
#        [[0.24, 0.1 , 0.92],
#         [0.82, 0.52, 0.7 ]]])

我应该调用哪个函数,才能得到通过使用 argmax1ndarr2 中提取出来的数组呢?

# array([[0.79, 0.1 , 0.31],
#        [0.3 , 0.52, 0.7 ]])

需要注意的是,我不能使用 numpy.amax,因为我需要用 argmax 的结果从一个与 ndarr1 形状相同的不同数组中提取值。

2 个回答

0
a, b = np.ogrid[0:2, 0:3] # 2nd and 3rd dimensions of the array
ndarray2[argmax1, a, b]

如果你想要一句话解决问题:

ndarray2[argmax1, *np.ogrid[0:2, 0:3]]

下面是对问题的回答。如果数组的形状是 (A, B, C, D),那么你可以这样写:

ndarray2[argmaz1, *np.ogrid[:B, :C, :D])

(这里的 0: 是多余的,可以不写。对于任何有三个或更多维度的数组都是这样。)

由于 ogrid 的定义有点特殊,如果原始数组只有两个维度,你就可以省略 *。使用单个切片参数的 ogrid 会返回一个数组,而不是一个单独的数组列表。

如果你有一个数组,但不知道它的大小或维度,你就需要手动构建切片:

slices = [slice(0, i) for i in ndarray2.shape[1:]]
grid = np.ogrid[*slices]
if len(slices) == 1:  # make array into singleton list
    grid = [grid]
ndarray2[argmax1, *grid]
2

根据文档,你可以在扩展了argmax1的维度后使用np.take_along_axis

max2 = np.take_along_axis(ndarr2, np.expand_dims(argmax1, axis=0), axis=0)
print(max2)
# [[[0.79 0.1  0.31]
#   [0.3  0.52 0.7 ]]]

如果在创建argmax1时设置了keepdims=True,那么就可以避免手动扩展维度,也就是说:

argmax1 = np.argmax(ndarr1, axis=0, keepdims=True)
max2 = np.take_along_axis(ndarr2, argmax1, axis=0)

撰写回答