加速numpy.dot
我有一个使用 numpy
的脚本,运行时间大约有50%都花在了以下这段代码上:
s = numpy.dot(v1, v1)
其中
v1 = v[1:]
而 v
是一个包含4000个元素的1维 ndarray
,里面存的是 float64
类型的数据,这些数据在内存中是连续存放的(v.strides
的值是 (8,)
)。
有没有什么建议可以让这个过程更快呢?
编辑 这段代码是在英特尔的硬件上运行的。以下是我运行 numpy.show_config()
的输出:
atlas_threads_info:
libraries = ['lapack', 'ptf77blas', 'ptcblas', 'atlas']
library_dirs = ['/usr/local/atlas-3.9.16/lib']
language = f77
include_dirs = ['/usr/local/atlas-3.9.16/include']
blas_opt_info:
libraries = ['ptf77blas', 'ptcblas', 'atlas']
library_dirs = ['/usr/local/atlas-3.9.16/lib']
define_macros = [('ATLAS_INFO', '"\\"3.9.16\\""')]
language = c
include_dirs = ['/usr/local/atlas-3.9.16/include']
atlas_blas_threads_info:
libraries = ['ptf77blas', 'ptcblas', 'atlas']
library_dirs = ['/usr/local/atlas-3.9.16/lib']
language = c
include_dirs = ['/usr/local/atlas-3.9.16/include']
lapack_opt_info:
libraries = ['lapack', 'ptf77blas', 'ptcblas', 'atlas']
library_dirs = ['/usr/local/atlas-3.9.16/lib']
define_macros = [('ATLAS_INFO', '"\\"3.9.16\\""')]
language = f77
include_dirs = ['/usr/local/atlas-3.9.16/include']
lapack_mkl_info:
NOT AVAILABLE
blas_mkl_info:
NOT AVAILABLE
mkl_info:
NOT AVAILABLE
4 个回答
我能想到的加速这个操作的方法就是确保你的NumPy安装是和一个优化过的BLAS库(比如ATLAS)一起编译的。numpy.dot()
是少数几个使用BLAS的NumPy函数之一。
可能问题出在传给 dot 的数组复制上。
正如Sven所说,dot 的乘法运算依赖于BLAS操作。这些操作需要数组以连续的C语言顺序存储。如果传给 dot 的两个数组都是C_CONTIGUOUS格式,你应该能看到更好的性能。
当然,如果你传给 dot 的两个数组确实是1D格式(8,),那么你应该看到 both C_CONTIGUOUS 和 F_CONTIGUOUS 的标志都被设置为True;但如果它们是(1, 8),那么你可能会看到混合的顺序。
>>> w = NP.random.randint(0, 10, 100).reshape(100, 1)
>>> w.flags
C_CONTIGUOUS : True
F_CONTIGUOUS : False
OWNDATA : False
WRITEABLE : True
ALIGNED : True
UPDATEIFCOPY : False
另一种选择是使用来自BLAS的_GEMM,这可以通过模块scipy.linalg.fblas来使用。(这两个数组A和B显然是以Fortran顺序存储的,因为使用了fblas。)
from scipy.linalg import fblas as FB
X = FB.dgemm(alpha=1., a=A, b=B, trans_b=True)
你的数组并不大,所以ATLAS可能没发挥太大作用。你可以试试下面这个Fortran程序的运行时间。假设ATLAS的影响不大,这样可以让你了解如果没有Python的额外开销,dot()函数的速度会有多快。使用gfortran -O3编译时,我测得的速度大约是5微秒,误差在0.5微秒左右。
program test
real*8 :: x(4000), start, finish, s
integer :: i, j
integer,parameter :: jmax = 100000
x(:) = 4.65
s = 0.
call cpu_time(start)
do j=1,jmax
s = s + dot_product(x, x)
enddo
call cpu_time(finish)
print *, (finish-start)/jmax * 1.e6, s
end program test