参数包含2d numpy数组的python函数的显式签名

2024-04-26 00:11:05 发布

您现在位置:Python中文网/ 问答频道 /正文

我想使用@jit或{}来加速我的python代码,这里解释:http://nbviewer.ipython.org/gist/harrism/f5707335f40af9463c43

然而,该页上的示例是针对纯python函数的,而我的函数是在一个类中的,并且基于进一步的搜索,似乎为了使用类函数,我必须提供函数的显式签名。在

我以前没有使用过签名,但现在我了解了如何将它们用于具有简单参数的函数。但我一直在为复杂的参数(比如2D数组)编写它们。在

下面是我需要显式签名的函数。 我真的不知道除了@void之外还能写什么。。。在

""" Function: train
    Input parameters:
    #X =  shape: [n_samples, n_features]
    #y = classes corresponding to X , y's shape: [n_samples]
    #H = int, number of boosting rounds
    Returns: None
    Trains the model based on the training data and true classes
    """
    #@autojit
    #@void
    def train(self, X, y, H):
           # function code below
           # do lots of stuff...

编辑

鉴于我的参数类型,我尝试了以下方法:

^{pr2}$

但得到了以下错误:

Traceback (most recent call last):
  File "C:\Users\app\Documents\Python Scripts\gbc_carclassify.py", line 18, in <module>
    import gentleboost_c_class as gbc
  File "C:\Users\app\Documents\Python Scripts\gentleboost_c_class.py", line 20, in <module>
    @jit
  File "C:\Users\app\Anaconda\lib\site-packages\numba\decorators.py", line 272, in jit
    return jit_extension_class(cls, kws, env)
  File "C:\Users\app\Anaconda\lib\site-packages\numba\exttypes\entrypoints.py", line 20, in jit_extension_class
    return jitclass.create_extension(env, py_class, translator_kwargs)
  File "C:\Users\app\Anaconda\lib\site-packages\numba\exttypes\jitclass.py", line 98, in create_extension
    ext_type = typesystem.jit_exttype(py_class)
  File "C:\Users\app\Anaconda\lib\site-packages\numba\typesystem\types.py", line 55, in __call__
    return type.__call__(self, *args)
  File "C:\Users\app\Anaconda\lib\site-packages\numba\exttypes\types\extensiontype.py", line 37, in __init__
    assert isinstance(py_class, type), ("Must be a new-style class "
AssertionError: Must be a new-style class (inherit from 'object')

编辑2

我已经将类的开头改为添加(object),所以现在看起来像这样:

import numba
from numba import jit, autojit, int_, void, float_

@jit
class GentleBoostC(object):
        def __init__(self):
        # init function

        @void(float_[:,:],int_[:],int_)
        def train(self, X, y, H): # this is the function I want to speed up
        # do stuff

但现在我得到了一个错误:

C:\Users\app\Anaconda\lib\site-packages\numba\exttypes\validators.py:74: UserWarning: Constructor for class 'GentleBoostC' has no signature, assuming arguments have type 'object'
  ext_type.py_class.__name__)
Traceback (most recent call last):
  File "C:\Users\app\Documents\Python Scripts\gbc_carclassify.py", line 18, in <module>
    import gentleboost_c_class as gbc
  File "C:\Users\app\Documents\Python Scripts\gentleboost_c_class.py", line 21, in <module>
    class GentleBoostC(object):
  File "C:\Users\app\Anaconda\lib\site-packages\numba\decorators.py", line 272, in jit
    return jit_extension_class(cls, kws, env)
  File "C:\Users\app\Anaconda\lib\site-packages\numba\exttypes\entrypoints.py", line 20, in jit_extension_class
    return jitclass.create_extension(env, py_class, translator_kwargs)
  File "C:\Users\app\Anaconda\lib\site-packages\numba\exttypes\jitclass.py", line 110, in create_extension
    extension_compiler.infer()
  File "C:\Users\app\Anaconda\lib\site-packages\numba\exttypes\compileclass.py", line 112, in infer
    self.type_infer_methods()
  File "C:\Users\app\Anaconda\lib\site-packages\numba\exttypes\compileclass.py", line 145, in type_infer_methods
    self.type_infer_method(method)
  File "C:\Users\app\Anaconda\lib\site-packages\numba\exttypes\compileclass.py", line 121, in type_infer_method
    **self.flags)
  File "C:\Users\app\Anaconda\lib\site-packages\numba\pipeline.py", line 133, in compile2
    func_ast = functions._get_ast(func)
  File "C:\Users\app\Anaconda\lib\site-packages\numba\functions.py", line 89, in _get_ast
    ast.PyCF_ONLY_AST | flags, True)
  File "C:\Users\app\Documents\Python Scripts\gentleboost_c_class.py", line 1
    def train(self, X, y, H):
    ^
IndentationError: unexpected indent

我想我没有缩进错误。。。在将object添加到类之前,我对这个完全相同的代码没有任何问题。在


Tags: inpyapplibpackageslineextensionsite
1条回答
网友
1楼 · 发布于 2024-04-26 00:11:05

使用数据类型上的切片语法来表示数组。所以你的例子可能看起来像:

from numba import void, int_, float_, jit

...

@jit
class YourClass(object):

    ...

    @void(float_[:, :], int_[:], int_)
    def train(self, X, y, H):
         # X is typed as a 2D float array and y as a 1D int array.
         pass

相关问题 更多 >

    热门问题