如何在Numba中重现np.repeat()的axis=2(在最后一个维度重复数组)
我正在尝试把我的代码改成适合Numba使用的版本,但我总是遇到关于轴参数的错误(因为它不被支持)。具体来说,我需要在轴=2的情况下使用np.repeat()函数,或者更一般地说,我想知道如何在最后一个维度上重复数组。
在numpy中,我的代码是:
original = np.random.rand(1000,1)
no_repeats = 10
big_original = np.repeat(np.expand_dims((5)*original, axis=2), no_repeats, axis=2)
我该如何以Numba友好的方式重写这个代码呢?
我尝试使用np.dstack:
expanded_original = np.expand_dims((5)*original, axis=2)
big_original = np.dstack([expanded_original]*no_repeats)
但当然,列表不是一个支持的数据类型。我该如何以最有效的方式来实现这个呢?
2 个回答
0
如果输入数据总是二维的,或者你说的轴=2是指“在最后一个维度之后”,那么这个方法是可以用的:
import numpy as np
import numba
original = np.random.rand(1000,1)
no_repeats = 10
big_original = np.repeat(np.expand_dims(original, axis=2), no_repeats, axis=2)
@numba.njit()
def repeatnumba(original,no_repeats):
repeat=original.repeat(no_repeats).reshape(*original.shape,no_repeats )
return repeat
big_numba = repeatnumba(original,no_repeats)
print(np.allclose(big_original, big_numba))
这个方法的结果和你用numpy写的代码是一样的。请注意,这个方法依赖于你期望的结果的最后一个维度是2。如果你的实际维度不同,你可能需要使用 np.transpose
,并提供一个维度的列表,比如:
repeat=np.transpose(repeat,(0,1,-1,2))
如果你的输入是三维的,但你仍然想在第二个维度上重复。
如果你是指“在最后一个维度上”,请考虑更新你的标题和问题描述,这样问题可能会对更多人有帮助。
1
我不太清楚你具体想做什么,但我猜你是想在一个用 @njit
编译的 numba 函数里重现 big_original
这个数组,对吧?
我这样做:
@njit
def repeat_original(original, no_repeats):
big_original = np.zeros((*original.shape, no_repeats))
for i in range(big_original.shape[-1]):
big_original[...,i] = (5)*original
return big_original
repeat_original(original, no_repeats)
如果这不是你期待的答案,请更详细地说明你的问题(比如 expandedGradientMatrix
是什么)以及你希望得到的结果。