如何在Cython中声明具有通用浮点类型的ndarray

9 投票
2 回答
6592 浏览
提问于 2025-04-18 04:39

在Cython中,声明一个可以处理浮点数和双精度浮点数的numpy数组,最好的方法是什么呢?

我觉得用内存视图可能不行,因为数据类型很重要。但是对于ndarray,有没有办法给它一个通用的浮点类型,这样还能享受到Cython的快速呢?

通常我会这样做:

def F( np.ndarray A):
    A += 10

我还看到有:

def F( np.ndarray[np.float32_t, ndim=2] A):
    A += 10

但这样又会给类型指定一个位数。我也考虑过根据位数(32位或64位)在函数内部创建一个内存视图。

任何想法都非常感谢。


非常感谢你关于floating类型的建议。我试了这样:

import numpy as np
cimport numpy as np
import cython
cimport cython
from libc.math cimport sqrt, abs
from cython cimport floating

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def Rot_Matrix(np.ndarray[floating, ndim=3] Fit_X,
               np.ndarray[floating, ndim=3] Ref_X,
               weight = None):
    cdef:
        unsigned int t, T = Fit_X.shape[0]
        unsigned int n, N = Fit_X.shape[1]
        np.ndarray[floating, ndim=3] Rot = np.empty((T,3,3))

    return Rot

当我用两个np.float32的数组调用这个函数时,出现了错误:

ValueError: Buffer dtype mismatch, expected 'float' but got 'double'

如果我不在Rot的类型定义中使用类型定义,让它读取np.ndarray[floating, ndim=3] Rot = np.empty((T,3,3)),那么我就能得到ndarray,效果很好。你能告诉我我哪里做错了吗?

2 个回答

2

融合类型只能在函数声明中使用。我能想到的最好的比喻就是C++中的模板。

如果你想创建一个可以同时处理float32和float64的函数,可以这样做:

from cython import floating, numeric
cimport cython

def func_float(floating a):
    print cython.typeof(a)

不过,你只能在已经出现在函数声明中的函数里使用融合类型

13

其实这件事用融合类型的支持来做非常简单:

这段代码放在你的代码里。

from cython cimport floating

def cythonfloating_memoryview(floating[:, :] array):
    cdef int i, j

    for i in range(array.shape[0]):
        for j in range(array.shape[1]):
            array[i, j] += 10

当然,有很多方法可以做到这一点:

把这个文件命名为 fuzed.pyx。不需要编译或运行 cython,这个由 pyximport 来处理。不过,生产环境的代码最好不要用 pyximport,因为通常你应该只发布 .c 文件。

from cython cimport floating
from numpy import float32_t, float64_t, ndarray

ctypedef fused myfloating:
    float32_t
    float64_t

def cythonfloating_memoryview(floating[:, :] array):
    # ...

def cythonfloating_buffer(ndarray[floating, ndim=2] array):
    # ...

def myfloating_memoryview(myfloating[:, :] array):
    # ...

def myfloating_buffer(ndarray[myfloating, ndim=2] array):
    # ...

这里有一个小测试脚本:

把这个文件命名为 test.py,然后像普通的 Python 脚本一样运行它:

import pyximport
pyximport.install()

import fuzed
import numpy

functions = [
    fuzed.cythonfloating_memoryview,
    fuzed.cythonfloating_buffer,
    fuzed.myfloating_memoryview,
    fuzed.myfloating_buffer,
]

for function in functions:
    floats_32 = numpy.zeros((100, 100), "float32")
    doubles_32 = numpy.zeros((100, 100), "float64")

    function(floats_32)
    function(doubles_32)

    print(repr(floats_32))
    print(repr(doubles_32))

值得注意的是,融合类型在编译时是专门处理的,并且在特定的函数调用中是固定不变的。你创建的空 Numpy 数组总是 double 类型,但你可以把它赋值为 32 位浮点数或 64 位浮点数。你应该这样做:

from cython cimport floating
import numpy

def do_some_things(floating[:] input):
    cdef floating[:] output

    if floating is float:
        output = numpy.empty(10, dtype="float32")
    elif floating is double:
        output = numpy.empty(10, dtype="float64")
    else:
        raise ValueError("Unknown floating type.")

    return output

还有一些测试来证明这一点:

import pyximport
pyximport.install()
#>>> (None, None)

import floatingtest
import numpy

floatingtest.do_some_things(numpy.empty(10, dtype="float32"))
#>>> <MemoryView of 'ndarray' at 0x7f0a514b3d88>
floatingtest.do_some_things(numpy.empty(10, dtype="float32")).itemsize
#>>> 4

floatingtest.do_some_things(numpy.empty(10, dtype="float64"))
#>>> <MemoryView of 'ndarray' at 0x7f0a514b3d88>
floatingtest.do_some_things(numpy.empty(10, dtype="float64")).itemsize
#>>> 8

撰写回答