基于np.d的cython多点版本

2021-09-27 05:08:10 发布

您现在位置:Python中文网/ 问答频道 /正文

我必须用cython优化python代码。 代码中含有大量的向量和矩阵的点积,这是优化过程的关键。你知道吗

产品呈链状,如: A.B.C.D.e.公司 其中e可以是向量,a,B,C,D是矩阵。 我用numpy来表示所有的对象。你知道吗

最初,我使用numpy.linalg.multi_dot。我读到here多点可以比np.dot公司因为它试图优化操作顺序。 然而,在我的例子中,矩阵很小,使用多个点的速度比使用一个链的速度慢10倍np.dot公司. 你知道吗

然而,multi_dot的语法很好,代码中有成千上万的行做不同数量的矩阵和向量的乘积。重新写这些行np.dot公司对于其他开发人员来说,这会降低可读性。你知道吗

基于np.dot公司很简单,但我想把它简单化。事实上,我不知道在编译时作为multi_dot参数给出的矩阵的数量,这迫使我调用python解释程序。 下面是我写的代码:

@cython.boundscheck(False)
@cython.wraparound=False
cpdef mymultidot(np_array_list):
    cdef int n=len(np_array_list)
    R=np_array_list[0]
    for i in range(1,n):
        R=np.dot(R,np_array_list[i])

有没有办法减少这个函数中的python调用?特别是,在执行时不推断np\u数组\u列表的类型?你知道吗

也许可以在编译时推断出np\u数组\u list的类型,比如(np.N阵列[np.浮点数,ndim=2])?但不知道名单的大小将使之变得困难。你知道吗

< P> > Cython有一种方法,有相同的名字,但不是像C/C++一样的参数吗?所以我可以创造

mymultidot( np.ndarray[np.float_t, ndim=2], np.ndarray[np.float_t, ndim=2], np.ndarray[np.float_t, ndim=2])

mymultidot( np.ndarray[np.float_t, ndim=2], np.ndarray[np.float_t, ndim=2])

mymultidot( np.ndarray[np.float_t, ndim=2], np.ndarray[np.float_t, ndim=1])

。。。等等?你知道吗

编写代码中现有的所有版本需要很长时间,但所有多点参数在编译时都是已知的。你知道吗

提前谢谢。你知道吗