使用多维索引的Numpy-take

2024-05-14 21:44:34 发布

您现在位置:Python中文网/ 问答频道 /正文

我需要一个通用函数f(array, axis, indices)来指定numpy数组中的任意轴。这里

  • array是任意维数的numpy数组
  • axis是指定数组维度的元组
  • indices是一个元组,指定上述轴的索引

例如,如果我有一个6维数组A,函数f(A, (0,3,4), (20, 70, 3))的值将是

A[20, :, :, 70, 3, :]

我怀疑可以通过以下方式使用np.take来实现这一点

def f_take(A, axis, indices):
    A1 = A.copy()

    # Make sure we iterate over axis in descending order
    descAxIdx = np.flip(np.argsort(axis))
    descAxis = np.array(axis)[descAxIdx]
    descIndices = np.array(indices)[descAxIdx]

    for ax, ind in zip(descAxis, descIndices):
        A1 = np.take(A1, ind, ax)
    return A1

此函数是否已存在于numpy中?我可以使用我写的f_take,但是速度对我来说是个问题,所以如果有纯编译的东西(没有python循环),那就太好了


Tags: 函数innumpya1np数组axarray
1条回答
网友
1楼 · 发布于 2024-05-14 21:44:34

这可以简单地实现为:

import numpy as np

def f(a, axes, indices):
    a = np.asarray(a)
    slices = tuple(indices[axes.index(i)] if i in axes else slice(None)
                   for i in range(a.ndim))
    return a[slices]

相关问题 更多 >

    热门问题