如何在Cython中声明具有通用浮点类型的ndarray
在Cython中,声明一个可以处理浮点数和双精度浮点数的numpy数组,最好的方法是什么呢?
我觉得用内存视图可能不行,因为数据类型很重要。但是对于ndarray,有没有办法给它一个通用的浮点类型,这样还能享受到Cython的快速呢?
通常我会这样做:
def F( np.ndarray A):
A += 10
我还看到有:
def F( np.ndarray[np.float32_t, ndim=2] A):
A += 10
但这样又会给类型指定一个位数。我也考虑过根据位数(32位或64位)在函数内部创建一个内存视图。
任何想法都非常感谢。
非常感谢你关于floating
类型的建议。我试了这样:
import numpy as np
cimport numpy as np
import cython
cimport cython
from libc.math cimport sqrt, abs
from cython cimport floating
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def Rot_Matrix(np.ndarray[floating, ndim=3] Fit_X,
np.ndarray[floating, ndim=3] Ref_X,
weight = None):
cdef:
unsigned int t, T = Fit_X.shape[0]
unsigned int n, N = Fit_X.shape[1]
np.ndarray[floating, ndim=3] Rot = np.empty((T,3,3))
return Rot
当我用两个np.float32的数组调用这个函数时,出现了错误:
ValueError: Buffer dtype mismatch, expected 'float' but got 'double'
如果我不在Rot
的类型定义中使用类型定义,让它读取np.ndarray[floating, ndim=3] Rot = np.empty((T,3,3))
,那么我就能得到ndarray,效果很好。你能告诉我我哪里做错了吗?
2 个回答
融合类型只能在函数声明中使用。我能想到的最好的比喻就是C++中的模板。
如果你想创建一个可以同时处理float32和float64的函数,可以这样做:
from cython import floating, numeric
cimport cython
def func_float(floating a):
print cython.typeof(a)
不过,你只能在已经出现在函数声明中的函数里使用融合类型
。
其实这件事用融合类型的支持来做非常简单:
这段代码放在你的代码里。
from cython cimport floating
def cythonfloating_memoryview(floating[:, :] array):
cdef int i, j
for i in range(array.shape[0]):
for j in range(array.shape[1]):
array[i, j] += 10
当然,有很多方法可以做到这一点:
把这个文件命名为 fuzed.pyx。不需要编译或运行
cython
,这个由pyximport
来处理。不过,生产环境的代码最好不要用pyximport
,因为通常你应该只发布.c
文件。
from cython cimport floating
from numpy import float32_t, float64_t, ndarray
ctypedef fused myfloating:
float32_t
float64_t
def cythonfloating_memoryview(floating[:, :] array):
# ...
def cythonfloating_buffer(ndarray[floating, ndim=2] array):
# ...
def myfloating_memoryview(myfloating[:, :] array):
# ...
def myfloating_buffer(ndarray[myfloating, ndim=2] array):
# ...
这里有一个小测试脚本:
把这个文件命名为 test.py,然后像普通的 Python 脚本一样运行它:
import pyximport
pyximport.install()
import fuzed
import numpy
functions = [
fuzed.cythonfloating_memoryview,
fuzed.cythonfloating_buffer,
fuzed.myfloating_memoryview,
fuzed.myfloating_buffer,
]
for function in functions:
floats_32 = numpy.zeros((100, 100), "float32")
doubles_32 = numpy.zeros((100, 100), "float64")
function(floats_32)
function(doubles_32)
print(repr(floats_32))
print(repr(doubles_32))
值得注意的是,融合类型在编译时是专门处理的,并且在特定的函数调用中是固定不变的。你创建的空 Numpy 数组总是 double
类型,但你可以把它赋值为 32 位浮点数或 64 位浮点数。你应该这样做:
from cython cimport floating
import numpy
def do_some_things(floating[:] input):
cdef floating[:] output
if floating is float:
output = numpy.empty(10, dtype="float32")
elif floating is double:
output = numpy.empty(10, dtype="float64")
else:
raise ValueError("Unknown floating type.")
return output
还有一些测试来证明这一点:
import pyximport
pyximport.install()
#>>> (None, None)
import floatingtest
import numpy
floatingtest.do_some_things(numpy.empty(10, dtype="float32"))
#>>> <MemoryView of 'ndarray' at 0x7f0a514b3d88>
floatingtest.do_some_things(numpy.empty(10, dtype="float32")).itemsize
#>>> 4
floatingtest.do_some_things(numpy.empty(10, dtype="float64"))
#>>> <MemoryView of 'ndarray' at 0x7f0a514b3d88>
floatingtest.do_some_things(numpy.empty(10, dtype="float64")).itemsize
#>>> 8