Cython能加速对象数组迭代吗?
我想用cython来加速以下代码:
class A(object):
cdef fun(self):
return 3
class B(object):
cdef fun(self):
return 2
def test():
cdef int x, y, i, s = 0
a = [ [A(), B()], [B(), A()]]
for i in xrange(1000):
for x in xrange(2):
for y in xrange(2):
s += a[x][y].fun()
return s
我想到的解决方案大概是这样的:
def test():
cdef int x, y, i, s = 0
types = [ [0, 1], [1, 0]]
data = [[...], [...]]
for i in xrange(1000):
for x in xrange(2):
for y in xrange(2):
if types[x,y] == 0:
s+= A(data[x,y]).fun()
else:
s+= B(data[x,y]).fun()
return s
基本上,C++中的解决办法是创建一个指向某个基类的指针数组,这个基类里有一个虚方法fun()
,这样你就可以很快地遍历这个数组。请问有没有办法用python/cython来实现这个?
顺便问一下,使用numpy的二维数组(类型为object_)会比使用python列表更快吗?
1 个回答
7
看起来像这样的代码能让速度提高大约20倍:
import numpy as np
cimport numpy as np
cdef class Base(object):
cdef int fun(self):
return -1
cdef class A(Base):
cdef int fun(self):
return 3
cdef class B(Base):
cdef int fun(self):
return 2
def test():
bbb = np.array([[A(), B()], [B(), A()]], dtype=np.object_)
cdef np.ndarray[dtype=object, ndim=2] a = bbb
cdef int i, x, y
cdef int s = 0
cdef Base u
for i in xrange(1000):
for x in xrange(2):
for y in xrange(2):
u = a[x,y]
s += u.fun()
return s
它甚至会检查A和B是否是从Base继承来的,可能在发布版本中有办法关闭这个检查,从而获得额外的速度提升。
编辑:可以通过以下方式去掉这个检查:
u = <Base>a[x,y]