使用Cython封装LAPACKE函数

6 投票
2 回答
765 浏览
提问于 2025-04-18 03:37

我正在尝试用 Cython 封装 LAPACK 函数 dgtsv,这个函数是用来解决三对角方程组的。

我看到过 之前的一个回答,但是因为 dgtsv 不是在 scipy.linalg 中封装的 LAPACK 函数,所以我觉得不能用那种方法。于是我开始尝试参考 这个例子

这是我 lapacke.pxd 文件的内容:

ctypedef int lapack_int

cdef extern from "lapacke.h" nogil:

    int LAPACK_ROW_MAJOR
    int LAPACK_COL_MAJOR

    lapack_int LAPACKE_dgtsv(int matrix_order,
                             lapack_int n,
                             lapack_int nrhs,
                             double * dl,
                             double * d,
                             double * du,
                             double * b,
                             lapack_int ldb)

...这是我在 _solvers.pyx 中写的简单 Cython 封装:

#!python

cimport cython
from lapacke cimport *

cpdef TDMA_lapacke(double[::1] DL, double[::1] D, double[::1] DU,
                   double[:, ::1] B):

    cdef:
        lapack_int n = D.shape[0]
        lapack_int nrhs = B.shape[1]
        lapack_int ldb = B.shape[0]
        double * dl = &DL[0]
        double * d = &D[0]
        double * du = &DU[0]
        double * b = &B[0, 0]
        lapack_int info

    info = LAPACKE_dgtsv(LAPACK_ROW_MAJOR, n, nrhs, dl, d, du, b, ldb)

    return info

...还有这是一个 Python 封装和测试脚本:

import numpy as np
from scipy import sparse
from cymodules import _solvers


def trisolve_lapacke(dl, d, du, b, inplace=False):

    if (dl.shape[0] != du.shape[0] or dl.shape[0] != d.shape[0] - 1
            or b.shape != d.shape):
        raise ValueError('Invalid diagonal shapes')

    if b.ndim == 1:
        # b is (LDB, NRHS)
        b = b[:, None]

    # be sure to force a copy of d and b if we're not solving in place
    if not inplace:
        d = d.copy()
        b = b.copy()

    # this may also force copies if arrays are improperly typed/noncontiguous
    dl, d, du, b = (np.ascontiguousarray(v, dtype=np.float64)
                    for v in (dl, d, du, b))

    # b will now be modified in place to contain the solution
    info = _solvers.TDMA_lapacke(dl, d, du, b)
    print info

    return b.ravel()


def test_trisolve(n=20000):

    dl = np.random.randn(n - 1)
    d = np.random.randn(n)
    du = np.random.randn(n - 1)

    M = sparse.diags((dl, d, du), (-1, 0, 1), format='csc')
    x = np.random.randn(n)
    b = M.dot(x)

    x_hat = trisolve_lapacke(dl, d, du, b)

    print "||x - x_hat|| = ", np.linalg.norm(x - x_hat)

不幸的是,调用 _solvers.TDMA_lapacke 时,test_trisolve 就出现了段错误。 我很确定我的 setup.py 是正确的 - 用 ldd _solvers.so 查看,_solvers.so 在运行时链接到了正确的共享库。

我现在不太确定该怎么继续 - 有什么想法吗?


简要更新

对于较小的 n 值,我通常不会立即出现段错误,但会得到一些无意义的结果(||x - x_hat|| 应该非常接近 0):

In [28]: test_trisolve2.test_trisolve(10)
0
||x - x_hat|| =  6.23202576396

In [29]: test_trisolve2.test_trisolve(10)
-7
||x - x_hat|| =  3.88623414288

In [30]: test_trisolve2.test_trisolve(10)
0
||x - x_hat|| =  2.60190676562

In [31]: test_trisolve2.test_trisolve(10)
0
||x - x_hat|| =  3.86631743386

In [32]: test_trisolve2.test_trisolve(10)
Segmentation fault

通常 LAPACKE_dgtsv 返回的代码是 0(这表示成功),但偶尔我会得到 -7,这意味着第七个参数(b)的值不合法。发生的情况是,实际上只有 b 的第一个值被修改了。如果我继续调用 test_trisolve,即使 n 很小,最终也会出现段错误。

2 个回答

1

虽然这个问题有点老,但似乎仍然很重要。观察到的行为是因为对参数 LDB 的误解:

  • Fortran 的数组是按列存储的,数组 B 的主维度对应的是 N。因此,LDB 必须大于等于 max(1,N)。
  • 而如果是按行存储,LDB 对应的是 NRHS,所以需要满足条件 LDB 必须大于等于 max(1,NRHS)。

评论 # b 的维度是 (LDB, NRHS) 是不正确的,因为 b 的维度应该是 (LDB,N),在这种情况下 LDB 应该是 1。

如果将 LAPACK_ROW_MAJOR 切换到 LAPACK_COL_MAJOR,问题就能解决,只要 NRHS 等于 1。因为按列存储的 (N,1) 和按行存储的 (1,N) 在内存中的布局是一样的。不过,如果 NRHS 大于 1,就会出错。

4

好的,我最终搞明白了——原来我之前对行优先和列优先的理解有误。

因为C语言的数组是按行优先的顺序存储的,所以我以为在调用LAPACKE_dgtsv时,应该把LAPACK_ROW_MAJOR作为第一个参数。

实际上,如果我把

info = LAPACKE_dgtsv(LAPACK_ROW_MAJOR, ...)

改成

info = LAPACKE_dgtsv(LAPACK_COL_MAJOR, ...)

那么我的函数就能正常工作了:

test_trisolve2.test_trisolve()
0
||x - x_hat|| =  6.67064747632e-12

这让我觉得有点反直觉——有没有人能解释一下这是为什么呢?

撰写回答