numpy dot在使用scipy.linalg.blas.sgemm时对大数组返回无效值

3 投票
1 回答
690 浏览
提问于 2025-04-17 22:03

我正在尝试计算 A • AT

# These are my dummy values for testing
A = np.ones((150000,265),dtype=np.float32, order='F')
A_T = np.ones((265, 150000),dtype=np.float32, order='F')

out = scipy.linalg.blas.sgemm(alpha=1.0, a=A, b=A_T)

两分钟后:

In [7]: out
Out[7]: 
array([[ 265.,  265.,  265., ...,    0.,    0.,    0.],
       [ 265.,  265.,  265., ...,    0.,    0.,    0.],
       [ 265.,  265.,  265., ...,    0.,    0.,    0.],
       ..., 
       [ 265.,  265.,  265., ...,    0.,    0.,    0.],
       [ 265.,  265.,  265., ...,    0.,    0.,    0.],
       [ 265.,  265.,  265., ...,    0.,    0.,    0.]])

In [10]: out.shape
Out[10]: (150000, 150000)

注意到那些零了吗?我搞不懂了……我试着用64位浮点数,结果还是一样。

从35468开始,数组全是零。

In [39]: out[0,35468]
Out[39]: 0.0

In [9]: scipy.__version__
Out[9]: '0.12.1'

更新/编辑:

我相当确定,np.dot 是在调用 *gemm 方法。

In [1]: A = np.ones((150000,265), dtype=np.float32, order='F')

In [2]: A_T = np.ones((265, 150000),dtype=np.float32, order='F')

In [3]: out = A.dot(A_T)

In [4]: out.shape
Out[4]: (150000, 150000)

In [5]: out
Out[5]: 
array([[ 265.,  265.,  265., ...,  265.,  265.,  265.],
   [ 265.,  265.,  265., ...,  265.,  265.,  265.],
   [ 265.,  265.,  265., ...,  265.,  265.,  265.],
   ..., 
   [   0.,    0.,    0., ...,    0.,    0.,    0.],
   [   0.,    0.,    0., ...,    0.,    0.,    0.],
   [   0.,    0.,    0., ...,    0.,    0.,    0.]], dtype=float32)

1 个回答

0

在我的电脑上,你的例子会出现内存错误np.dot(A, A.T)也是一样。

当我减小矩阵的大小时,它就能正确运行了。

为了进一步调试,可以尝试直接使用Fortran的BLAS调用,不用Python。如果这样能通过,就考虑在scipy的跟踪系统上报告一个bug

顺便提一下,从scipy 0.13版本开始,你可以使用syrk

撰写回答