鹦鹉和Numba有什么不同?因为我在某些NumPy表达式上没看到任何改进

5 投票
1 回答
1546 浏览
提问于 2025-04-18 07:10

我想知道有没有人能告诉我,parakeet和Numba jit之间有哪些主要区别?我之所以好奇,是因为我在比较Numexpr、Numba和parakeet这几个工具,而对于这个特定的表达式(我原本以为在Numexpr上表现会非常好,因为它在文档中提到过)

所以结果是

enter image description here

我测试的函数(通过timeit - 每个函数至少重复3次,每次10轮)

import numpy as np
import numexpr as ne
from numba import jit as numba_jit
from parakeet import jit as para_jit


def numpy_complex_expr(A, B):
    return(A*B-4.1*A > 2.5*B)

def numexpr_complex_expr(A, B):
    return ne.evaluate('A*B-4.1*A > 2.5*B')

@numba_jit
def numba_complex_expr(A, B):
    return A*B-4.1*A > 2.5*B

@para_jit
def parakeet_complex_expr(A, B):
    return A*B-4.1*A > 2.5*B

如果你想在自己的机器上核对结果,也可以查看这个IPython笔记本

如果有人在想Numba是否安装正确……我觉得是的,它在我之前的基准测试中表现得很正常:

enter image description here

1 个回答

5

在你当前使用的Numba版本中,对于使用@jit函数的ufuncs(通用函数)支持还不完全。不过,你可以使用@vectorize,而且它的速度更快。

import numpy as np
from numba import jit, vectorize
import numexpr as ne

def numpy_complex_expr(A, B):
    return(A*B+4.1*A > 2.5*B)

def numexpr_complex_expr(A, B):
    return ne.evaluate('A*B+4.1*A > 2.5*B')

@jit
def numba_complex_expr(A, B):
    return A*B+4.1*A > 2.5*B

@vectorize(['u1(float64, float64)'])
def numba_vec(A,B):
    return A*B+4.1*A > 2.5*B

n = 1000
A = np.random.rand(n,n)
B = np.random.rand(n,n)

这里是一些时间测试的结果:

%timeit numba_complex_expr(A,B)
1 loops, best of 3: 49.8 ms per loop

%timeit numpy_complex_expr(A,B)
10 loops, best of 3: 43.5 ms per loop

%timeit numexpr_complex_expr(A,B)
100 loops, best of 3: 3.08 ms per loop

%timeit numba_vec(A,B)
100 loops, best of 3: 9.8 ms per loop

如果你想充分利用Numba的功能,那么你需要展开任何向量化的操作。

@jit
def numba_unroll2(A, B):
    C = np.empty(A.shape, dtype=np.uint8)
    for i in xrange(A.shape[0]):
        for j in xrange(A.shape[1]):
            C[i,j] = A[i,j]*B[i,j] + 4.1*A[i,j] > 2.5*B[i,j]

    return C

%timeit numba_unroll2(A,B)
100 loops, best of 3: 5.96 ms per loop

另外要注意的是,如果你把numexpr使用的线程数设置为1,你会发现它的主要速度优势在于它是并行处理的。

ne.set_num_threads(1)
%timeit numexpr_complex_expr(A,B)
100 loops, best of 3: 8.87 ms per loop

默认情况下,numexpr使用ne.detect_number_of_cores()来确定线程数。在我机器上的第一次测试中,它使用了8个线程。

撰写回答