Numpy数组形状未被numba识别为int

2024-06-16 08:57:34 发布

您现在位置:Python中文网/ 问答频道 /正文

我想基于一些数组生成一个numpy矩阵,并使用jit或理想情况下的njit来加速此生成。如果nopython=False(启用nopython后失败),它将继续发送以下2个警告,我无法理解:

:14: NumbaWarning: Compilation is falling back to object mode WITH looplifting enabled because Function "process_stuffs" failed type inference due to: No conversion from array(int32, 2d, C) to array(int64, 2d, A) for 'inp', defined at None

File "", line 23: def process_stuffs(output,inp,route1, route2, zoneidx):

input_pallets, _ = inp.shape
^

During: typing of argument at (23)

File "", line 23: def process_stuffs(output,inp,route1, route2, zoneidx):

input_pallets, _ = inp.shape
^

@jit(nopython=False, :14: NumbaWarning: Compilation is falling back to object mode WITHOUT looplifting enabled because Function "process_stuffs" failed type inference due to: Cannot determine Numba type of <class 'numba.core.dispatcher.LiftedLoop'>

File "", line 25: def process_stuffs(output,inp,route1, route2, zoneidx):

for minute in range(input_pallets):
^

@jit(nopython=False, C:\Anaconda3\envs\dev38\lib\site-packages\numba\core\object_mode_passes.py:151: NumbaWarning: Function "process_stuffs" was compiled in object mode without forceobj=True, but has lifted loops.

虽然函数确实使用复杂类型,但在确定inp数组的长度时,它在非常开始时失败,然后它不想生成循环,尽管我看到了很多示例

我试图通过使用locals指定类型来纠正错误,但正如您所看到的,这没有帮助

这是一个最低限度的工作代码:

zoneidx=Dict.empty(key_type=types.unicode_type,value_type=types.int8)
zoneidx["A"]=np.int8(0)
zoneidx["B"]=np.int8(1)
zoneidx["C"]=np.int8(2)
zoneidx["D"]=np.int8(3)
zoneidx["E"]=np.int8(4)


output = np.zeros(shape=(110,5),dtype=np.int64)
inp = np.random.randint(0,2,size=(100,2))
route1 = np.random.choice(list('ABCDE'),size=10)
route2 = np.random.choice(list('ABCDE'),size=10)

@jit(nopython=False,
     locals={'input_pallets':numba.int64,
             'step':numba.int64,
             'inp':numba.types.int64[:,:],
             'route1':numba.types.unicode_type[:],
             'route2':numba.types.unicode_type[:],
             'output':numba.types.int64[:,:]})
def process_stuffs(output,inp,route1, route2, zoneidx):

    input_pallets, _ = inp.shape

    for minute in range(input_pallets):
        prod1, prod2 = inp[minute]
        if prod1+prod2 <1:
            continue

        if prod1:
            routing = route1
            number_of_pallets = prod1
            number_of_steps = route1.shape[0]
        else:
            routing = route2
            number_of_pallets = prod2
            number_of_steps = route2.shape[0]
        for step in range(number_of_steps):
            zone = routing[step]
            output[minute+step,zoneidx[zone]]+=number_of_pallets

    return output




numba.__version__ == 0.53.1
numpy.__version__ == 1.19.2

我的代码有什么问题

注意:我对代码输出的正确性不感兴趣,我知道一旦“路由1”被“inp”激活,“路由2”就会被忽略。我只是想把它编译一下


Tags: ofinputoutputtypenpprocesspalletsinp
1条回答
网友
1楼 · 发布于 2024-06-16 08:57:34

警告信息具有误导性。事实上,输入的类型确实没有正确给出,这与.shape方法无关

我的解决方案是使用numba.typeof函数告诉它应该使用什么类型。例如,“inp”应为int32,而不是64。“unichr”是预期的,而不是unicode

以下是我的示例的工作版本:

zoneidx=Dict.empty(key_type=numba.typeof(route1).dtype,value_type=types.int8)
zoneidx["A"]=np.int8(0)
zoneidx["B"]=np.int8(1)
zoneidx["C"]=np.int8(2)
zoneidx["D"]=np.int8(3)
zoneidx["E"]=np.int8(4)


output = np.zeros(shape=(110,5),dtype=np.int64)
inp = np.random.randint(0,2,size=(100,2))
route1 = np.random.choice(list('ABCDE'),size=10)
route2 = np.random.choice(list('ABCDE'),size=10)

@jit(nopython=False,
     locals={'input_pallets':numba.int64,
             'step':numba.int64,
             'inp':numba.types.int32[:,:],
             'route1':numba.typeof(route1),
             'route2':numba.typeof(route1),
             'output':numba.types.int64[:,:]})
def process_stuffs(output,inp,route1, route2, zoneidx):

    input_pallets, _ = inp.shape

    for minute in range(input_pallets):
        prod1, prod2 = inp[minute]
        if prod1+prod2 <1:
            continue

        if prod1:
            routing = route1
            number_of_pallets = prod1
            number_of_steps = route1.shape[0]
        else:
            routing = route2
            number_of_pallets = prod2
            number_of_steps = route2.shape[0]
        for step in range(number_of_steps):
            zone = routing[step]
            output[minute+step,zoneidx[zone]]+=number_of_pallets

    return output

相关问题 更多 >