比较numbacompiled函数中的字符串

2024-05-23 17:25:30 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在寻找比较使用numbajit(非python模式,python3)编译的python函数中字符串的最佳方法。在

用例如下:

import numba as nb

@nb.jit(nopython = True, cache = True)
def foo(a, t = 'default'):
    if t == 'awesome':
        return(a**2)
    elif t == 'default':
        return(a**3)
    else:
        ...

但是,返回以下错误:

^{pr2}$

我尝试使用字节,但没成功。在

谢谢!在


莫里斯指出了这个问题Python: can numba work with arrays of strings in nopython mode?,但我关注的是原生python,而不是numba支持的numpy子集。在


Tags: 方法函数字符串importtruedefaultreturnas
2条回答

我建议接受@MSeifert的答案,但作为解决此类问题的另一种选择,请考虑使用enum。在

在python中,字符串通常被用作一种枚举,并且numba对枚举有内置的支持,因此可以直接使用它们。在

import enum

class FooOptions(enum.Enum):
    AWESOME = 1
    DEFAULT = 2

import numba

@numba.njit
def foo(a, t=FooOptions.DEFAULT):
    if t == FooOptions.AWESOME:
        return a**2
    elif t == FooOptions.DEFAULT:
        return a**2
    else:
        return a

foo(10, FooOptions.AWESOME)
Out[5]: 100

对于较新的numba版本(0.41.0及更高版本)

Numba(自0.41.0版起)支持^{} in nopython mode,问题中所写的代码将“正常工作”。但是,对于您的例子,比较字符串比您的操作慢得多,因此如果您想在numba函数中使用字符串,请确保开销是值得的。在

import numba as nb

@nb.njit
def foo_string(a, t):
    if t == 'awesome':
        return(a**2)
    elif t == 'default':
        return(a**3)
    else:
        return a

@nb.njit
def foo_int(a, t):
    if t == 1:
        return(a**2)
    elif t == 0:
        return(a**3)
    else:
        return a

assert foo_string(100, 'default') == foo_int(100, 0)
%timeit foo_string(100, 'default')
# 2.82 µs ± 45.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit foo_int(100, 0)
# 213 ns ± 10.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

在您的例子中,使用字符串的代码要慢10倍以上。在

由于您的函数功能不多,因此使用Python而不是numba进行字符串比较会更好更快:

^{pr2}$

这仍然比纯整数版本慢一点,但比在numba函数中使用字符串快10倍。在

但是如果在numba函数中做了大量的数值运算,字符串比较开销就不重要了。但是简单地将numba.njit放在一个函数上,特别是如果它不做很多数组操作或数字运算,它不会自动更快!在

对于旧的numba版本(0.41.0之前的版本):

Numba不支持nopython模式下的字符串。在

documentation

2.6.2. Built-in types

2.6.2.1. int, bool [...]

2.6.2.2. float, complex [...]

2.6.2.3. tuple [...]

2.6.2.4. list [...]

2.6.2.5. set [...]

2.6.2.7. bytes, bytearray, memoryview

The bytearray type and, on Python 3, the bytes type support indexing, iteration and retrieving the len().

[...]

所以字符串根本不受支持,字节也不支持相等性检查。在

但是,您可以传入bytes并对其进行迭代。这样就可以编写自己的比较函数:

import numba as nb

@nb.njit
def bytes_equal(a, b):
    if len(a) != len(b):
        return False
    for char1, char2 in zip(a, b):
        if char1 != char2:
            return False
    return True

不幸的是,下一个问题是numba不能“减少”字节,所以不能直接在函数中硬编码字节。但是字节基本上只是整数,bytes_equal函数适用于numba支持的所有类型,这些类型都有一个长度,可以迭代。因此,您可以简单地将它们存储为列表:

import numba as nb

@nb.njit
def foo(a, t):
    if bytes_equal(t, [97, 119, 101, 115, 111, 109, 101]):
        return a**2
    elif bytes_equal(t, [100, 101, 102, 97, 117, 108, 116]):
        return a**3
    else:
        return a

或者作为全局数组(谢谢@chrisb-见评论):

import numba as nb
import numpy as np

AWESOME = np.frombuffer(b'awesome', dtype='uint8')
DEFAULT = np.frombuffer(b'default', dtype='uint8')

@nb.njit
def foo(a, t):
    if bytes_equal(t, AWESOME):
        return a**2
    elif bytes_equal(t, DEFAULT):
        return a**3
    else:
        return a

两者都能正常工作:

>>> foo(10, b'default')
1000
>>> foo(10, b'awesome')
100
>>> foo(10, b'awe')
10

但是,不能将字节数组指定为默认值,因此需要显式地提供t变量。这样做也让人觉得很不舒服。在

我的观点是:只需在普通函数中执行if t == ...检查,并在ifs中调用专门的numba函数。在Python中字符串比较非常快,只需将数学/数组密集的内容包装在numba函数中:

import numba as nb

@nb.njit
def awesome_func(a):
    return a**2

@nb.njit
def default_func(a):
    return a**3

@nb.njit
def other_func(a):
    return a

def foo(a, t='default'):
    if t == 'awesome':
        return awesome_func(a)
    elif t == 'default':
        return default_func(a)
    else:
        return other_func(a)

但是要确保你真的需要numba来完成这些功能。有时候普通的Python/NumPy就足够快了。只需分析numba解决方案和Python/NumPy解决方案,看看numba是否能使它更快。:)

相关问题 更多 >