我正在定义一个函数,并希望使用Numba向量化来加速它,使用cuda。我的签名有问题。函数将返回float64值。我想传递两个float64值,它将被矢量化,另外还有一个9元组的float64值,它将是标量。在
这是我的函数头:
from numba import vectorize
@vectorize(['float64(float64, float64, UniTuple(float64, 9))'], target='cuda')
def fn_vec(E, L, fparams):
# calculations...
return result
但这会产生一个错误:
^{pr2}$我尝试过很多变体,包括(float64,…,float64)来代替UniTuple(),但都无法正常工作。我该怎么做?在
在
numba.vectorize
函数中,不能使用元组。这是因为vectorize
将这些类型的数组的代码矢量化。在因此,使用
float, float, tuple
签名可以创建一个函数,该函数需要两个包含浮点的数组和一个包含元组的数组。问题是包含元组的数组没有数据类型——如果使用结构化数组而不是包含元组的数组,它可以工作,但我没有尝试过。在在numba签名中指定
UniTuple
的正确方法是使用numba.types.containers.UniTuple
。在您的情况下:所以正确的签名应该是这样的:
^{pr2}$我经常避免显式地键入numba函数,但当我这样做时,我发现使用
numba.typeof
非常有用,例如:所以所有的信息都在那里:它是
numba.types.containes.UniTuple
,你用两个参数实例化它,dtype
(这里是float64
)和数字(在本例中是9
)。在如果不想为元组参数对函数进行矢量化,只需在另一个函数中创建一个矢量化函数并在那里调用它:
这使得元组在
vectorize
d函数中可用。然而,它必须创建内部函数,并在每次调用外部函数时编译它,因此这实际上可能会比较慢。我也不确定这是否能与target="cuda"
一起工作,您可能需要自己测试一下。在相关问题 更多 >
编程相关推荐