我需要一个通用函数f(array, axis, indices)
来指定numpy数组中的任意轴。这里
array
是任意维数的numpy数组axis
是指定数组维度的元组indices
是一个元组,指定上述轴的索引例如,如果我有一个6维数组A
,函数f(A, (0,3,4), (20, 70, 3))
的值将是
A[20, :, :, 70, 3, :]
我怀疑可以通过以下方式使用np.take来实现这一点
def f_take(A, axis, indices):
A1 = A.copy()
# Make sure we iterate over axis in descending order
descAxIdx = np.flip(np.argsort(axis))
descAxis = np.array(axis)[descAxIdx]
descIndices = np.array(indices)[descAxIdx]
for ax, ind in zip(descAxis, descIndices):
A1 = np.take(A1, ind, ax)
return A1
此函数是否已存在于numpy
中?我可以使用我写的f_take
,但是速度对我来说是个问题,所以如果有纯编译的东西(没有python循环),那就太好了
这可以简单地实现为:
相关问题 更多 >
编程相关推荐