numpy dot在使用scipy.linalg.blas.sgemm时对大数组返回无效值
我正在尝试计算 A • AT:
# These are my dummy values for testing
A = np.ones((150000,265),dtype=np.float32, order='F')
A_T = np.ones((265, 150000),dtype=np.float32, order='F')
out = scipy.linalg.blas.sgemm(alpha=1.0, a=A, b=A_T)
两分钟后:
In [7]: out
Out[7]:
array([[ 265., 265., 265., ..., 0., 0., 0.],
[ 265., 265., 265., ..., 0., 0., 0.],
[ 265., 265., 265., ..., 0., 0., 0.],
...,
[ 265., 265., 265., ..., 0., 0., 0.],
[ 265., 265., 265., ..., 0., 0., 0.],
[ 265., 265., 265., ..., 0., 0., 0.]])
In [10]: out.shape
Out[10]: (150000, 150000)
注意到那些零了吗?我搞不懂了……我试着用64位浮点数,结果还是一样。
从35468开始,数组全是零。
In [39]: out[0,35468]
Out[39]: 0.0
In [9]: scipy.__version__
Out[9]: '0.12.1'
更新/编辑:
我相当确定,np.dot 是在调用 *gemm 方法。
In [1]: A = np.ones((150000,265), dtype=np.float32, order='F')
In [2]: A_T = np.ones((265, 150000),dtype=np.float32, order='F')
In [3]: out = A.dot(A_T)
In [4]: out.shape
Out[4]: (150000, 150000)
In [5]: out
Out[5]:
array([[ 265., 265., 265., ..., 265., 265., 265.],
[ 265., 265., 265., ..., 265., 265., 265.],
[ 265., 265., 265., ..., 265., 265., 265.],
...,
[ 0., 0., 0., ..., 0., 0., 0.],
[ 0., 0., 0., ..., 0., 0., 0.],
[ 0., 0., 0., ..., 0., 0., 0.]], dtype=float32)
1 个回答
0
在我的电脑上,你的例子会出现内存错误
,np.dot(A, A.T)
也是一样。
当我减小矩阵的大小时,它就能正确运行了。
为了进一步调试,可以尝试直接使用Fortran的BLAS调用,不用Python。如果这样能通过,就考虑在scipy的跟踪系统上报告一个bug。
顺便提一下,从scipy 0.13版本开始,你可以使用syrk
。