BLAS的sgemm/dgemm是如何实现的?

3 投票
2 回答
3320 浏览
提问于 2025-04-16 15:58

我正在尝试在Python中使用ctypes来调用BLAS库里的sgemm函数。这个函数的作用是解决C = A x B这个方程,下面的代码运行得很好:

no_trans = c_char("n")
m = c_int(number_of_rows_of_A)
n = c_int(number_of_columns_of_B)
k = c_int(number_of_columns_of_A)
one = c_float(1.0)
zero = c_float(0.0)

blaslib.sgemm_(byref(no_trans), byref(no_trans), byref(m), byref(n), byref(k),
               byref(one), A, byref(m), B, byref(k), byref(zero), C, byref(m))

现在我想解决这个方程:C = A' x A,其中A'A的转置。虽然下面的代码没有报错,但返回的结果却是错误的:

trans = c_char("t")
no_trans = c_char("n")
m = c_int(number_of_rows_of_A)
n = c_int(number_of_columns_of_A)
one = c_float(1.0)
zero = c_float(0.0)

blaslib.sgemm_(byref(trans), byref(no_trans), byref(n), byref(n), byref(m),
               byref(one), A, byref(m), A, byref(m), byref(zero), C, byref(n))

为了测试,我插入了一个矩阵A = [1 2; 3 4]。正确的结果应该是C = [10 14; 14 20],但是sgemm函数给出的结果是C = [5 11; 11 25]

根据我的理解,矩阵A不需要我自己去转置,因为算法会处理这个问题。那么在第二种情况下,我传递参数时出了什么问题呢?

任何帮助、链接、文章或建议都非常感谢!

2 个回答

1

你得到的结果表明,sgemm 计算的是 A*A',而不是你想要的 A'*A。解决这个问题很简单,只需要把这两个输入的位置调换一下就可以了。

7

Blas 通常使用列优先的矩阵(就像 Fortran 语言那样),所以 A = [1 2; 3 4] 的意思是

    |1 3|   
A = |   |
    |2 4|

而且这个结果是正确的(假设你的 Python 库也是这样处理的)。可以查看这个 说明文档

撰写回答