使用Cython封装LAPACKE函数
我正在尝试用 Cython 封装 LAPACK 函数 dgtsv
,这个函数是用来解决三对角方程组的。
我看到过 之前的一个回答,但是因为 dgtsv
不是在 scipy.linalg
中封装的 LAPACK 函数,所以我觉得不能用那种方法。于是我开始尝试参考 这个例子。
这是我 lapacke.pxd
文件的内容:
ctypedef int lapack_int
cdef extern from "lapacke.h" nogil:
int LAPACK_ROW_MAJOR
int LAPACK_COL_MAJOR
lapack_int LAPACKE_dgtsv(int matrix_order,
lapack_int n,
lapack_int nrhs,
double * dl,
double * d,
double * du,
double * b,
lapack_int ldb)
...这是我在 _solvers.pyx
中写的简单 Cython 封装:
#!python
cimport cython
from lapacke cimport *
cpdef TDMA_lapacke(double[::1] DL, double[::1] D, double[::1] DU,
double[:, ::1] B):
cdef:
lapack_int n = D.shape[0]
lapack_int nrhs = B.shape[1]
lapack_int ldb = B.shape[0]
double * dl = &DL[0]
double * d = &D[0]
double * du = &DU[0]
double * b = &B[0, 0]
lapack_int info
info = LAPACKE_dgtsv(LAPACK_ROW_MAJOR, n, nrhs, dl, d, du, b, ldb)
return info
...还有这是一个 Python 封装和测试脚本:
import numpy as np
from scipy import sparse
from cymodules import _solvers
def trisolve_lapacke(dl, d, du, b, inplace=False):
if (dl.shape[0] != du.shape[0] or dl.shape[0] != d.shape[0] - 1
or b.shape != d.shape):
raise ValueError('Invalid diagonal shapes')
if b.ndim == 1:
# b is (LDB, NRHS)
b = b[:, None]
# be sure to force a copy of d and b if we're not solving in place
if not inplace:
d = d.copy()
b = b.copy()
# this may also force copies if arrays are improperly typed/noncontiguous
dl, d, du, b = (np.ascontiguousarray(v, dtype=np.float64)
for v in (dl, d, du, b))
# b will now be modified in place to contain the solution
info = _solvers.TDMA_lapacke(dl, d, du, b)
print info
return b.ravel()
def test_trisolve(n=20000):
dl = np.random.randn(n - 1)
d = np.random.randn(n)
du = np.random.randn(n - 1)
M = sparse.diags((dl, d, du), (-1, 0, 1), format='csc')
x = np.random.randn(n)
b = M.dot(x)
x_hat = trisolve_lapacke(dl, d, du, b)
print "||x - x_hat|| = ", np.linalg.norm(x - x_hat)
不幸的是,调用 _solvers.TDMA_lapacke
时,test_trisolve
就出现了段错误。
我很确定我的 setup.py
是正确的 - 用 ldd _solvers.so
查看,_solvers.so
在运行时链接到了正确的共享库。
我现在不太确定该怎么继续 - 有什么想法吗?
简要更新:
对于较小的 n
值,我通常不会立即出现段错误,但会得到一些无意义的结果(||x - x_hat|| 应该非常接近 0):
In [28]: test_trisolve2.test_trisolve(10)
0
||x - x_hat|| = 6.23202576396
In [29]: test_trisolve2.test_trisolve(10)
-7
||x - x_hat|| = 3.88623414288
In [30]: test_trisolve2.test_trisolve(10)
0
||x - x_hat|| = 2.60190676562
In [31]: test_trisolve2.test_trisolve(10)
0
||x - x_hat|| = 3.86631743386
In [32]: test_trisolve2.test_trisolve(10)
Segmentation fault
通常 LAPACKE_dgtsv
返回的代码是 0
(这表示成功),但偶尔我会得到 -7
,这意味着第七个参数(b
)的值不合法。发生的情况是,实际上只有 b
的第一个值被修改了。如果我继续调用 test_trisolve
,即使 n
很小,最终也会出现段错误。
2 个回答
虽然这个问题有点老,但似乎仍然很重要。观察到的行为是因为对参数 LDB 的误解:
- Fortran 的数组是按列存储的,数组 B 的主维度对应的是 N。因此,LDB 必须大于等于 max(1,N)。
- 而如果是按行存储,LDB 对应的是 NRHS,所以需要满足条件 LDB 必须大于等于 max(1,NRHS)。
评论 # b 的维度是 (LDB, NRHS) 是不正确的,因为 b 的维度应该是 (LDB,N),在这种情况下 LDB 应该是 1。
如果将 LAPACK_ROW_MAJOR 切换到 LAPACK_COL_MAJOR,问题就能解决,只要 NRHS 等于 1。因为按列存储的 (N,1) 和按行存储的 (1,N) 在内存中的布局是一样的。不过,如果 NRHS 大于 1,就会出错。
好的,我最终搞明白了——原来我之前对行优先和列优先的理解有误。
因为C语言的数组是按行优先的顺序存储的,所以我以为在调用LAPACKE_dgtsv
时,应该把LAPACK_ROW_MAJOR
作为第一个参数。
实际上,如果我把
info = LAPACKE_dgtsv(LAPACK_ROW_MAJOR, ...)
改成
info = LAPACKE_dgtsv(LAPACK_COL_MAJOR, ...)
那么我的函数就能正常工作了:
test_trisolve2.test_trisolve()
0
||x - x_hat|| = 6.67064747632e-12
这让我觉得有点反直觉——有没有人能解释一下这是为什么呢?