用Cython优化NumPy

5 投票
2 回答
4337 浏览
提问于 2025-04-16 13:51

我现在正在尝试优化我用纯Python写的代码。这段代码大量使用了NumPy,因为我在处理NumPy数组。下面是我将最简单的一个类转换成Cython的代码。这个类的功能就是对两个NumPy数组进行乘法运算。代码如下:

bendingForces = self.matrixPrefactor * membraneHeight

我想问的是,我该如何优化这段代码,因为当我查看“cython -a”生成的C代码时,发现里面有很多NumPy的调用,看起来效率不高。

import numpy as np
cimport numpy as np
ctypedef np.float64_t dtype_t
ctypedef np.complex128_t cplxtype_t
ctypedef Py_ssize_t index_t

    cdef class bendingForcesClass( object ):
        cdef dtype_t bendingRigidity
        cdef np.ndarray matrixPrefactor
        cdef np.ndarray bendingForces

        def __init__( self, dtype_t bendingRigidity, np.ndarray[dtype_t, ndim=2] waveNumbersNorm ):
            self.bendingRigidity = bendingRigidity
            self.matrixPrefactor = -self.bendingRigidity * waveNumbersNorm**2

        cpdef np.ndarray calculate( self, np.ndarray membraneHeight ) :
            cdef np.ndarray bendingForces
            bendingForces = self.matrixPrefactor * membraneHeight
            return bendingForces

我想到的一个办法是使用两个for循环,逐个遍历数组的元素。也许我可以利用编译器来优化这个过程,使用SIMD操作?我尝试过,虽然可以编译,但结果很奇怪,而且速度非常慢。这是替代函数的代码:

cpdef np.ndarray calculate( self, np.ndarray membraneHeight ) :

    cdef index_t index1, index2 # corresponds to: cdef Py_ssize_t index1, index2
    for index1 in range( self.matrixSize ):
        for index2 in range( self.matrixSize ):
            self.bendingForces[ index1, index2 ] = self.matrixPrefactor.data[ index1, index2 ] * membraneHeight.data[ index1, index2 ]
    return self.bendingForces

不过,正如我所说,这段代码真的很慢,而且结果也不如预期。那么我到底哪里做错了呢?有什么好的方法可以优化这个过程,去掉NumPy的调用操作吗?

2 个回答

0

你可以通过使用

for index1 from 0 <= index1 < max1:

来加快这个速度,而不是使用一个我不确定类型的范围。

你有没有查看过 这个链接这个链接 呢?

9

对于简单的矩阵乘法,NumPy 的代码已经在本地完成了循环和乘法,所以在 Cython 中要超越它是很难的。Cython 很适合用来替换 Python 中的循环,让它们在 Cython 中运行。你代码运行得比 NumPy 慢的一个原因是,每次你在数组中查找索引时,

self.bendingForces[ index1, index2 ] = self.matrixPrefactor.data[ index1, index2 ] * membraneHeight.data[ index1, index2 ]

它会进行更多的计算,比如检查索引是否有效。如果你把索引转换成无符号整数,可以在函数前加上 @cython.boundscheck(False) 这个装饰器来提高效率。

想了解更多关于加速 Cython 代码的细节,可以查看这个 教程

撰写回答