用C函数扩展Numpy
我正在尝试加速我的Numpy代码,决定在C语言中实现一个特别的函数,因为我的代码大部分时间都花在这里。
其实我在C语言方面还是个新手,但我成功写出了一个函数,可以把矩阵的每一行归一化,使它们的和为1。我可以编译这个函数,并用一些数据(在C语言中)测试,结果也符合我的预期。那时候我非常自豪。
现在我想从Python中调用这个很棒的函数,它应该能接受一个二维的Numpy数组。
我尝试过的几种方法有:
SWIG
SWIG +
numpy.i
ctypes
我的函数的原型是:
void normalize_logspace_matrix(size_t nrow, size_t ncol, double mat[nrow][ncol]);
这个函数接受一个指向可变长度数组的指针,并在原地修改它。
我尝试了以下纯SWIG接口文件:
%module c_utils
%{
extern void normalize_logspace_matrix(size_t, size_t, double mat[*][*]);
%}
extern void normalize_logspace_matrix(size_t, size_t, double** mat);
然后我在Mac OS X 64位系统上执行:
> swig -python c-utils.i
> gcc -fPIC c-utils_wrap.c -o c-utils_wrap.o \
-I/Library/Frameworks/Python.framework/Versions/6.2/include/python2.6/ \
-L/Library/Frameworks/Python.framework/Versions/6.2/lib/python2.6/ -c
c-utils_wrap.c: In function ‘_wrap_normalize_logspace_matrix’:
c-utils_wrap.c:2867: warning: passing argument 3 of ‘normalize_logspace_matrix’ from incompatible pointer type
> g++ -dynamiclib c-utils.o -o _c_utils.so
在Python中导入我的模块时,我遇到了以下错误:
>>> import c_utils
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
ImportError: dynamic module does not define init function (initc_utils)
接下来我尝试了使用SWIG + numpy.i
的方法:
%module c_utils
%{
#define SWIG_FILE_WITH_INIT
#include "c-utils.h"
%}
%include "numpy.i"
%init %{
import_array();
%}
%apply ( int DIM1, int DIM2, DATA_TYPE* INPLACE_ARRAY2 )
{(size_t nrow, size_t ncol, double* mat)};
%include "c-utils.h"
但是,我没有更进一步:
> swig -python c-utils.i
c-utils.i:13: Warning 453: Can't apply (int DIM1,int DIM2,DATA_TYPE *INPLACE_ARRAY2). No typemaps are defined.
SWIG似乎找不到在numpy.i
中定义的类型映射,但我不明白为什么,因为numpy.i
在同一个目录下,SWIG也没有抱怨找不到它。
使用ctypes我也没能取得太大进展,很快就被文档搞晕了,因为我不知道怎么传递一个二维数组,然后再获取结果。
所以,有人能告诉我怎么把我的函数在Python/Numpy中使用吗?
5 个回答
首先,你确定你写的numpy代码是最快的吗?如果你说的“归一化”是指把整行数据除以它的总和,那么你可以写出一种快速的向量化代码,像这样:
matrix /= matrix.sum(axis=0)
如果这不是你想要的,而且你仍然觉得需要一个快速的C扩展,我强烈建议你用cython来写,而不是直接用C。这样可以省去很多麻烦和复杂的代码封装,让你写出来的代码看起来像Python,但在大多数情况下可以跑得和C一样快。
来回答真正的问题:SWIG并没有告诉你找不到任何类型映射。它告诉你无法应用这个类型映射 (int DIM1,int DIM2,DATA_TYPE *INPLACE_ARRAY2)
,原因是没有为 DATA_TYPE *
定义类型映射。你需要告诉它你想把这个映射应用到 double*
上:
%apply ( int DIM1, int DIM2, double* INPLACE_ARRAY2 )
{(size_t nrow, size_t ncol, double* mat)};
除非你有很好的理由,否则你应该使用cython来连接C语言和Python。我们现在开始在numpy和scipy内部使用cython,而不是直接用C语言。
你可以在我的scikits talkbox中看到一个简单的例子(因为cython自那以后有了很大的改进,我觉得现在可以写得更好)。
def cslfilter(c_np.ndarray b, c_np.ndarray a, c_np.ndarray x):
"""Fast version of slfilter for a set of frames and filter coefficients.
More precisely, given rank 2 arrays for coefficients and input, this
computes:
for i in range(x.shape[0]):
y[i] = lfilter(b[i], a[i], x[i])
This is mostly useful for processing on a set of windows with variable
filters, e.g. to compute LPC residual from a signal chopped into a set of
windows.
Parameters
----------
b: array
recursive coefficients
a: array
non-recursive coefficients
x: array
signal to filter
Note
----
This is a specialized function, and does not handle other types than
double, nor initial conditions."""
cdef int na, nb, nfr, i, nx
cdef double *raw_x, *raw_a, *raw_b, *raw_y
cdef c_np.ndarray[double, ndim=2] tb
cdef c_np.ndarray[double, ndim=2] ta
cdef c_np.ndarray[double, ndim=2] tx
cdef c_np.ndarray[double, ndim=2] ty
dt = np.common_type(a, b, x)
if not dt == np.float64:
raise ValueError("Only float64 supported for now")
if not x.ndim == 2:
raise ValueError("Only input of rank 2 support")
if not b.ndim == 2:
raise ValueError("Only b of rank 2 support")
if not a.ndim == 2:
raise ValueError("Only a of rank 2 support")
nfr = a.shape[0]
if not nfr == b.shape[0]:
raise ValueError("Number of filters should be the same")
if not nfr == x.shape[0]:
raise ValueError, \
"Number of filters and number of frames should be the same"
tx = np.ascontiguousarray(x, dtype=dt)
ty = np.ones((x.shape[0], x.shape[1]), dt)
na = a.shape[1]
nb = b.shape[1]
nx = x.shape[1]
ta = np.ascontiguousarray(np.copy(a), dtype=dt)
tb = np.ascontiguousarray(np.copy(b), dtype=dt)
raw_x = <double*>tx.data
raw_b = <double*>tb.data
raw_a = <double*>ta.data
raw_y = <double*>ty.data
for i in range(nfr):
filter_double(raw_b, nb, raw_a, na, raw_x, nx, raw_y)
raw_b += nb
raw_a += na
raw_x += nx
raw_y += nx
return ty
如你所见,除了在Python中通常会做的参数检查外,其他几乎都是一样的(filter_double是一个函数,如果你愿意,可以在一个单独的库中用纯C语言来编写)。当然,由于这是编译后的代码,如果不检查参数,可能会导致你的解释器崩溃,而不是抛出异常(不过,最近的cython提供了多种安全性和速度之间的权衡选择)。