如何定义带有ndarray参数和/或输出的numba类?

2024-05-23 18:55:30 发布

您现在位置:Python中文网/ 问答频道 /正文

我试过了:

import numpy as np
import numba

@numba.jit
class foo(object):
    @numba.void(numba.int32)
    def __init__(self, somenum):
        self.somenumarray = np.arange(somenum)

    @numba.jit('f8[:](f8[:])')
    def somemethod1(self, a):
        return self.somenumarray + a

使用@numba.double[:](numba.double[:])方法修饰符会导致错误。在


Tags: importselfnumpyobjectfoodefasnp
1条回答
网友
1楼 · 发布于 2024-05-23 18:55:30

这可以使用numba.FunctionType

import numpy as np
import numba

bar = numba.FunctionType(return_type=numba.f8[:], args=[numba.f8[:]])

@numba.jit
class foo(object):
    @numba.FunctionType(return_type=numba.void, args=[numba.int32])
    def __init__(self, somenum):
        self.somenumarray = np.arange(somenum)

    @bar
    def somemethod1(self, a):
        return self.somenumarray + a

您可以稍后执行以下操作:

^{pr2}$

相关问题 更多 >