如何在Numba中重现np.repeat()的axis=2(在最后一个维度重复数组)

1 投票
2 回答
72 浏览
提问于 2025-04-14 18:26

我正在尝试把我的代码改成适合Numba使用的版本,但我总是遇到关于轴参数的错误(因为它不被支持)。具体来说,我需要在轴=2的情况下使用np.repeat()函数,或者更一般地说,我想知道如何在最后一个维度上重复数组。

在numpy中,我的代码是:

original = np.random.rand(1000,1)
no_repeats = 10
big_original = np.repeat(np.expand_dims((5)*original, axis=2), no_repeats, axis=2)

我该如何以Numba友好的方式重写这个代码呢?

我尝试使用np.dstack:

expanded_original = np.expand_dims((5)*original, axis=2)
big_original = np.dstack([expanded_original]*no_repeats)

但当然,列表不是一个支持的数据类型。我该如何以最有效的方式来实现这个呢?

2 个回答

0

如果输入数据总是二维的,或者你说的轴=2是指“在最后一个维度之后”,那么这个方法是可以用的:

import numpy as np
import numba


original = np.random.rand(1000,1)
no_repeats = 10
big_original = np.repeat(np.expand_dims(original, axis=2), no_repeats, axis=2)


@numba.njit()
def repeatnumba(original,no_repeats):
  repeat=original.repeat(no_repeats).reshape(*original.shape,no_repeats )
  return repeat

big_numba = repeatnumba(original,no_repeats)

print(np.allclose(big_original, big_numba))

这个方法的结果和你用numpy写的代码是一样的。请注意,这个方法依赖于你期望的结果的最后一个维度是2。如果你的实际维度不同,你可能需要使用 np.transpose,并提供一个维度的列表,比如:

repeat=np.transpose(repeat,(0,1,-1,2)) 

如果你的输入是三维的,但你仍然想在第二个维度上重复。

如果你是指“在最后一个维度上”,请考虑更新你的标题和问题描述,这样问题可能会对更多人有帮助。

1

我不太清楚你具体想做什么,但我猜你是想在一个用 @njit 编译的 numba 函数里重现 big_original 这个数组,对吧?

我这样做:

@njit
def repeat_original(original, no_repeats):
    big_original = np.zeros((*original.shape, no_repeats))
    for i in range(big_original.shape[-1]):
        big_original[...,i] = (5)*original
    return big_original
repeat_original(original, no_repeats)

如果这不是你期待的答案,请更详细地说明你的问题(比如 expandedGradientMatrix 是什么)以及你希望得到的结果。

撰写回答