在Numpy中,如何根据元素绝对值的最大值选择元素?

2024-04-25 16:33:31 发布

您现在位置:Python中文网/ 问答频道 /正文

基本上,我想对任意维的Numpy数组和指定任意轴执行以下Python等效操作:

max(array, key=abs)

即,基于最大绝对值选择元素(类似于array.max(axis=axis)仅沿特定轴选择最大值的方式)

例如(absmax是所需的函数):

array = np.array([
    [ 5,  8,  2],
    [-7,  3,  0],
    [-2, -4, -1],
])
absmax(array, axis=0)  # [-7,  8,  2]
absmax(array, axis=1)  # [ 8, -7, -4]

我提出了以下实现,但感觉相当笨拙:

def absmax(a, *, axis):
    dims = list(a.shape)
    dims.pop(axis)
    indices = np.ogrid[tuple(slice(0, d) for d in dims)]
    argmax = np.abs(a).argmax(axis=axis)
    indices.insert((len(a.shape) + axis) % len(a.shape), argmax)
    return a[tuple(indices)]

因此,我想知道是否有更好/更简洁的方法来实现此功能


Tags: keynumpylennpabs数组arraymax
2条回答

在搜索紧凑性时,这里有一个保持昏暗的-

def absmax(a, axis):
    s = np.array(a.shape)
    s[axis] = -1
    return np.take_along_axis(a,np.abs(a).argmax(axis).reshape(s),axis=axis)

样本运行-

In [67]: a
Out[67]: 
array([[ 5,  8,  2],
       [-7,  3,  0],
       [-2, -4, -1]])

In [68]: absmax(a, axis=0)
Out[68]: array([[-7,  8,  2]])

In [69]: absmax(a, axis=1)
Out[69]: 
array([[ 8],
       [-7],
       [-4]])

如果额外的昏暗外观让人感到困扰,请在输出中添加重塑步骤:

out = np.take_along_axis(a,np.abs(a).argmax(axis).reshape(s),axis=axis)
return out.reshape(np.delete(s,axis))

示例在同一输入数组上运行-

In [89]: absmax(a, axis=0)
Out[89]: array([-7,  8,  2])

In [90]: absmax(a, axis=1)
Out[90]: array([ 8, -7, -4])

也许更简单的方法是使用np.take_along_axis()实现接受lambda_max()参数的key函数:

def lambda_max(arr, axis=None, key=None, keepdims=False):
    if callable(key):
        idxs = np.argmax(key(arr), axis)
        if axis is not None:
            idxs = np.expand_dims(idxs, axis)
            result = np.take_along_axis(arr, idxs, axis)
            if not keepdims:
                result = np.squeeze(result, axis=axis)
            return result
        else:
            return arr.flatten()[idxs]
    else:
        return np.amax(arr, axis)

这可以按如下方式使用:

print(lambda_max(array, 0, np.abs))
# [-7  8  2]
print(lambda_max(array, 1, np.abs))
# [ 8 -7 -4]
print(lambda_max(array, None, np.abs))
# 8

相关问题 更多 >