Cython能加速对象数组迭代吗?

9 投票
1 回答
3382 浏览
提问于 2025-04-16 05:51

我想用cython来加速以下代码:

class A(object):
    cdef fun(self):
        return 3


class B(object):
    cdef fun(self):
        return 2

def test():
    cdef int x, y, i, s = 0
    a = [ [A(), B()], [B(), A()]]
    for i in xrange(1000):
        for x in xrange(2):
            for y in xrange(2):
                s += a[x][y].fun()
    return s

我想到的解决方案大概是这样的:

def test():
    cdef int x, y, i, s = 0
    types = [ [0, 1], [1, 0]]
    data = [[...], [...]]
    for i in xrange(1000):
        for x in xrange(2):
            for y in xrange(2):
                if types[x,y] == 0:
                   s+= A(data[x,y]).fun()
                else:
                   s+= B(data[x,y]).fun() 
    return s

基本上,C++中的解决办法是创建一个指向某个基类的指针数组,这个基类里有一个虚方法fun(),这样你就可以很快地遍历这个数组。请问有没有办法用python/cython来实现这个?

顺便问一下,使用numpy的二维数组(类型为object_)会比使用python列表更快吗?

1 个回答

7

看起来像这样的代码能让速度提高大约20倍:

import numpy as np
cimport numpy as np
cdef class Base(object):
    cdef int fun(self):
        return -1

cdef class A(Base):
    cdef int fun(self):
        return 3


cdef class B(Base):
    cdef int fun(self):
        return 2

def test():
    bbb = np.array([[A(), B()], [B(), A()]], dtype=np.object_)
    cdef np.ndarray[dtype=object, ndim=2] a = bbb

    cdef int i, x, y
    cdef int s = 0
    cdef Base u

    for i in xrange(1000):
        for x in xrange(2):
            for y in xrange(2):
                u = a[x,y]                
                s += u.fun()
    return s

它甚至会检查A和B是否是从Base继承来的,可能在发布版本中有办法关闭这个检查,从而获得额外的速度提升。

编辑:可以通过以下方式去掉这个检查:

u = <Base>a[x,y]

撰写回答