为什么Cython比矢量化的NumPy慢?

12 投票
3 回答
3312 浏览
提问于 2025-04-18 10:22

考虑一下下面这段Cython代码:

cimport cython
cimport numpy as np
import numpy as np

@cython.boundscheck(False)
@cython.wraparound(False)
def test_memoryview(double[:] a, double[:] b):
    cdef int i
    for i in range(a.shape[0]):
        a[i] += b[i]

@cython.boundscheck(False)
@cython.wraparound(False)
def test_numpy(np.ndarray[double, ndim=1] a, np.ndarray[double, ndim=1] b):
    cdef int i
    for i in range(a.shape[0]):
        a[i] += b[i]

def test_numpyvec(a, b):
    a += b

def gendata(nb=40000000):
    a = np.random.random(nb)
    b = np.random.random(nb)
    return a, b

在解释器中运行这段代码后(经过几次运行以预热缓存),得到的结果是:

In [14]: %timeit -n 100 test_memoryview(a, b)
100 loops, best of 3: 148 ms per loop

In [15]: %timeit -n 100 test_numpy(a, b)
100 loops, best of 3: 159 ms per loop

In [16]: %timeit -n 100 test_numpyvec(a, b)
100 loops, best of 3: 124 ms per loop

# See answer below :
In [17]: %timeit -n 100 test_raw_pointers(a, b)
100 loops, best of 3: 129 ms per loop

我尝试了不同大小的数据集,发现向量化的NumPy函数总是比编译后的Cython代码运行得快,而我原本以为Cython的性能应该和向量化的NumPy差不多。

我是不是在Cython代码中漏掉了什么优化?NumPy是不是用了一些东西(比如BLAS)来让这些简单的操作运行得更快?我能否提高这段代码的性能?

更新:原始指针版本的性能似乎和NumPy差不多。所以显然使用内存视图或NumPy索引时会有一些额外的开销。

3 个回答

2

一个小改动可以让速度稍微快一点,那就是指定步幅:

def test_memoryview_inorder(double[::1] a, double[::1] b):
    cdef int i
    for i in range(a.shape[0]):
        a[i] += b[i]
3

在我的电脑上,差别没有那么大,但我几乎可以通过改变numpy和内存视图的函数来消除这个差别,像这样:

@cython.boundscheck(False)
@cython.wraparound(False)
def test_memoryview(double[:] a, double[:] b):
    cdef int i, n=a.shape[0]
    for i in range(n):
        a[i] += b[i]

@cython.boundscheck(False)
@cython.wraparound(False)
def test_numpy(np.ndarray[double] a, np.ndarray[double] b):
    cdef int i, n=a.shape[0]
    for i in range(n):
        a[i] += b[i]

然后,当我从Cython编译C输出时,我使用了 -O3-march=native 这两个标志。这似乎表明,时间差异是由于使用了不同的编译器优化。

我使用的是64位版本的MinGW和NumPy 1.8.1。你的结果可能会因为你使用的包版本、硬件、平台和编译器的不同而有所变化。

如果你在使用IPython笔记本的Cython魔法,可以通过将 %%cython 替换为 %%cython -f -c=-O3 -c=-march=native 来强制更新额外的编译器标志。

如果你在为你的cython模块使用标准的setup.py,你可以在创建传递给 distutils.setup 的Extension对象时指定 extra_compile_args 参数。

注意:我在指定NumPy数组的类型时去掉了 ndim=1 这个标志,因为它并不是必要的。这个值默认就是1。

10

另一种选择是使用原始指针(以及全局指令来避免重复写@cython...):

#cython: wraparound=False
#cython: boundscheck=False
#cython: nonecheck=False

#...

cdef ctest_raw_pointers(int n, double *a, double *b):
    cdef int i
    for i in range(n):
        a[i] += b[i]

def test_raw_pointers(np.ndarray[double, ndim=1] a, np.ndarray[double, ndim=1] b):
    ctest_raw_pointers(a.shape[0], &a[0], &b[0])

撰写回答