Python,使用多处理来进一步加速cython函数

2024-06-13 02:11:30 发布

您现在位置:Python中文网/ 问答频道 /正文

此处显示的代码经过简化,但会触发相同的PicklingError。我知道有很多关于什么可以腌制和什么不能腌制的讨论,但我确实从中找到了解决办法。在

我用以下函数编写了一个简单的cython脚本:

def pow2(int a) : 
    return a**2 

编译工作正常,我可以用python脚本调用这个函数。在

enter image description here

但是,我想知道如何在多处理中使用这个函数

^{pr2}$

给了我一个错误: enter image description here

dtw是包的名称,fast是快速.pyx. 在

我怎样才能避开这个问题? 提前谢谢


Tags: 函数代码脚本名称returndef错误cython
1条回答
网友
1楼 · 发布于 2024-06-13 02:11:30

您可以使用OpenMP包装器prange,而不是使用multiprocessing,这意味着由于酸洗过程而在磁盘上写入数据。在您的情况下,您可以使用如下所示。在

  • 注意使用x*x代替x**2,避免函数调用pow(x, 2)):
  • 使用double指针将数组的一部分传递给每个线程
  • size % num_threads != 0时,最后一个线程接受更多值

代码:

#cython: wraparound=False
#cython: boundscheck=False
#cython: cdivision=True
#cython: nonecheck=False
#cython: profile=False
import numpy as np
cimport numpy as np
from cython.parallel import prange

cdef void cpow2(int size, double *inp, double *out) nogil:
    cdef int i
    for i in range(size):
        out[i] = inp[i]*inp[i]

def pow2(np.ndarray[np.float64_t, ndim=1] inp,
         np.ndarray[np.float64_t, ndim=1] out,
         int num_threads=4):
    cdef int thread
    cdef np.ndarray[np.int32_t, ndim=1] sub_sizes, pos
    size = np.shape(inp)[0]
    sub_sizes = np.zeros(num_threads, np.int32) + size//num_threads
    pos = np.zeros(num_threads, np.int32)
    sub_sizes[num_threads-1] += size % num_threads
    pos[1:] = np.cumsum(sub_sizes)[:num_threads-1]
    for thread in prange(num_threads, nogil=True, chunksize=1,
                         num_threads=num_threads, schedule='static'):
        cpow2(sub_sizes[thread], &inp[pos[thread]], &out[pos[thread]])

def main():
    a = np.arange(642312323).astype(np.float64)
    pow2(a, out=a, num_threads=4)

相关问题 更多 >