Numpy数组与ctypes函数的接口
我正在尝试将一个共享的C库与一些Python代码连接起来。这个库的接口大概是这样的:
typedef struct{
int v1;
double* v2} input;
还有两种其他类型,分别用于配置和输出。
我在Python中使用ctypes Structure
来设置这些结构,像这样:
class input(Structure):
_fields_ = [("v1",c_int),("v2",POINTER(c_double)]
C代码中有一些函数,它们接收指向这个结构的指针,参数类型定义如下:
fun.argtypes = [constraints,input,POINTER(input)]
constraints
是另一个结构,里面有一些int
类型的字段用于配置。
首先,我更新输入结构中的v2字段:
input.v2 = generated_array.ctypes.data_as(POINTER(c_double))
然后我调用这个函数:
fun(constraints,input,byref(output))
这个函数的原型要求传入结构和指向结构的指针(输出结构的类型假设与输入结构的类型相同)。
接着,我想访问输出中v2字段的结果。但是我得到了意想不到的结果。有没有更好或正确的方法来做这个?
我在这里搜索了很多,读了文档,但找不到问题所在。我没有任何错误信息,但从共享库收到的警告似乎表明这些接口有问题。
我想我找到了问题所在:
当我调用这个方法时,会调用一个复数的numpy数组。然后我创建了4个向量:
out_real = ascontiguousarray(zeros(din.size,dtype=c_double))
out_imag = ascontiguousarray(zeros(din.size,dtype=c_double))
in_real = ascontiguousarray(din.real,dtype = c_double)
in_imag = ascontiguousarray(din.imag,dtype = c_double)
其中din是输入向量。我是这样测试这个方法的:
print in_real.ctypes.data_as(POINTER(c_double))
print in_imag.ctypes.data_as(POINTER(c_double))
print out_real.ctypes.data_as(POINTER(c_double))
print out_imag.ctypes.data_as(POINTER(c_double))
结果是:
<model.LP_c_double object at 0x1d81f80>
<model.LP_c_double object at 0x1d81f80>
<model.LP_c_double object at 0x1d81f80>
<model.LP_c_double object at 0x1d81f80>
看起来它们都指向同一个地方。
经过一些修改,它按预期工作了……
经过多次测试,我发现第一次的代码几乎是正确的。我只创建了一个结构实例并更新它的字段。我改为在每次调用fun
时创建一个新的实例。我还把所有数组类型改为等效的ctypes类型;这似乎让函数按预期工作。
打印的行为仍然和上面的测试一样,但这个函数似乎即使在这种奇怪的行为下也能正常工作。这正如@ericsun在下面评论所指出的那样是正确的。
1 个回答
这个 struct
里面有一个 int
类型的字段,可能是用来表示数组的长度,不过我只是猜的,因为没有完整的函数原型。如果真是这样的话,下面有个例子可能会对你有帮助。
首先,我需要在一个共享库中编译一个测试函数。我会简单地把输入数组的每个元素都乘以 2:
import os
import numpy as np
from ctypes import *
open('tmp.c', 'w').write('''\
typedef struct {
int v1; double *v2;
} darray;
int test(darray *input, darray *output) {
int i;
/* note: this should first test for compatible size */
for (i=0; i < input->v1; i++)
*(output->v2 + i) = *(input->v2 + i) * 2;
return 0;
}
''')
os.system('gcc -shared -o tmp.so tmp.c')
接下来,创建 ctypes 的定义。我添加了一个 classmethod
,用来从 numpy.ndarray
创建一个 darray
:
c_double_p = POINTER(c_double)
class darray(Structure):
_fields_ = [
('v1', c_int),
('v2', c_double_p),
]
@classmethod
def fromnp(cls, a):
return cls(len(a), a.ctypes.data_as(c_double_p))
lib = CDLL('./tmp.so')
lib.test.argtypes = POINTER(darray), POINTER(darray)
测试:
a1 = np.arange(3) + 1.0
a2 = np.zeros(3)
print 'before:', '\na1 =', a1, '\na2 =', a2
lib.test(darray.fromnp(a1), darray.fromnp(a2))
print 'after:', '\na1 =', a1, '\na2 =', a2
输出:
before:
a1 = [ 1. 2. 3.]
a2 = [ 0. 0. 0.]
after:
a1 = [ 1. 2. 3.]
a2 = [ 2. 4. 6.]