如何在Numba向量化签名中指定元组?

2024-05-23 18:40:44 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在定义一个函数,并希望使用Numba向量化来加速它,使用cuda。我的签名有问题。函数将返回float64值。我想传递两个float64值,它将被矢量化,另外还有一个9元组的float64值,它将是标量。在

这是我的函数头:

from numba import vectorize

@vectorize(['float64(float64, float64, UniTuple(float64, 9))'], target='cuda')
def fn_vec(E, L, fparams):
    # calculations... 
    return result

但这会产生一个错误:

^{pr2}$

我尝试过很多变体,包括(float64,…,float64)来代替UniTuple(),但都无法正常工作。我该怎么做?在


Tags: 函数fromimporttarget定义def矢量化cuda
1条回答
网友
1楼 · 发布于 2024-05-23 18:40:44

How do I specify a tuple in a Numba Vectorize signature?

numba.vectorize函数中,不能使用元组。这是因为vectorize将这些类型的数组的代码矢量化。在

因此,使用float, float, tuple签名可以创建一个函数,该函数需要两个包含浮点的数组和一个包含元组的数组。问题是包含元组的数组没有数据类型——如果使用结构化数组而不是包含元组的数组,它可以工作,但我没有尝试过。在

How do I specify a tuple in a Numba jit signature?

在numba签名中指定UniTuple的正确方法是使用numba.types.containers.UniTuple。在您的情况下:

nb.types.containers.UniTuple(nb.types.float64, 9)

所以正确的签名应该是这样的:

^{pr2}$

我经常避免显式地键入numba函数,但当我这样做时,我发现使用numba.typeof非常有用,例如:

>>> nb.typeof((1.0, ) * 9)
tuple(float64 x 9)

>>> type(nb.typeof((1.0, ) * 9))
numba.types.containers.UniTuple

>>> help(type(nb.typeof((1.0, ) * 9)))  # I shortened the result:
Help on class UniTuple in module numba.types.containers:

class UniTuple(BaseAnonymousTuple, _HomogeneousTuple, numba.types.abstract.Sequence)
 |  UniTuple(*args, **kwargs)
 |  
 |  Type class for homogeneous tuples.
 |  
 |  Methods defined here:
 |  
 |  __init__(self, dtype, count)
 |      Initialize self.  See help(type(self)) for accurate signature.

所以所有的信息都在那里:它是numba.types.containes.UniTuple,你用两个参数实例化它,dtype(这里是float64)和数字(在本例中是9)。在

In case you wanted to vectorize over the float arrays only

如果不想为元组参数对函数进行矢量化,只需在另一个函数中创建一个矢量化函数并在那里调用它:

import numba as nb
import numpy as np

def func(E, L, fparams):
    @nb.vectorize(['float64(float64, float64)'])
    def fn_vec(e, l):
        return e + l + fparams[1]  # just to illustrate that the tuple is available
    return fn_vec(E, L)

这使得元组在vectorized函数中可用。然而,它必须创建内部函数,并在每次调用外部函数时编译它,因此这实际上可能会比较慢。我也不确定这是否能与target="cuda"一起工作,您可能需要自己测试一下。在

相关问题 更多 >