加速numpy.dot

13 投票
4 回答
16940 浏览
提问于 2025-04-16 17:34

我有一个使用 numpy 的脚本,运行时间大约有50%都花在了以下这段代码上:

s = numpy.dot(v1, v1)

其中

v1 = v[1:]

v 是一个包含4000个元素的1维 ndarray,里面存的是 float64 类型的数据,这些数据在内存中是连续存放的(v.strides 的值是 (8,))。

有没有什么建议可以让这个过程更快呢?

编辑 这段代码是在英特尔的硬件上运行的。以下是我运行 numpy.show_config() 的输出:

atlas_threads_info:
    libraries = ['lapack', 'ptf77blas', 'ptcblas', 'atlas']
    library_dirs = ['/usr/local/atlas-3.9.16/lib']
    language = f77
    include_dirs = ['/usr/local/atlas-3.9.16/include']

blas_opt_info:
    libraries = ['ptf77blas', 'ptcblas', 'atlas']
    library_dirs = ['/usr/local/atlas-3.9.16/lib']
    define_macros = [('ATLAS_INFO', '"\\"3.9.16\\""')]
    language = c
    include_dirs = ['/usr/local/atlas-3.9.16/include']

atlas_blas_threads_info:
    libraries = ['ptf77blas', 'ptcblas', 'atlas']
    library_dirs = ['/usr/local/atlas-3.9.16/lib']
    language = c
    include_dirs = ['/usr/local/atlas-3.9.16/include']

lapack_opt_info:
    libraries = ['lapack', 'ptf77blas', 'ptcblas', 'atlas']
    library_dirs = ['/usr/local/atlas-3.9.16/lib']
    define_macros = [('ATLAS_INFO', '"\\"3.9.16\\""')]
    language = f77
    include_dirs = ['/usr/local/atlas-3.9.16/include']

lapack_mkl_info:
  NOT AVAILABLE

blas_mkl_info:
  NOT AVAILABLE

mkl_info:
  NOT AVAILABLE

4 个回答

4

我能想到的加速这个操作的方法就是确保你的NumPy安装是和一个优化过的BLAS库(比如ATLAS)一起编译的。numpy.dot()是少数几个使用BLAS的NumPy函数之一。

7

可能问题出在传给 dot 的数组复制上。

正如Sven所说,dot 的乘法运算依赖于BLAS操作。这些操作需要数组以连续的C语言顺序存储。如果传给 dot 的两个数组都是C_CONTIGUOUS格式,你应该能看到更好的性能。

当然,如果你传给 dot 的两个数组确实是1D格式(8,),那么你应该看到 both C_CONTIGUOUS 和 F_CONTIGUOUS 的标志都被设置为True;但如果它们是(1, 8),那么你可能会看到混合的顺序。

>>> w = NP.random.randint(0, 10, 100).reshape(100, 1)
>>> w.flags
   C_CONTIGUOUS : True
   F_CONTIGUOUS : False
   OWNDATA : False
   WRITEABLE : True
   ALIGNED : True
   UPDATEIFCOPY : False


另一种选择是使用来自BLAS的_GEMM,这可以通过模块scipy.linalg.fblas来使用。(这两个数组A和B显然是以Fortran顺序存储的,因为使用了fblas。)

from scipy.linalg import fblas as FB
X = FB.dgemm(alpha=1., a=A, b=B, trans_b=True)
5

你的数组并不大,所以ATLAS可能没发挥太大作用。你可以试试下面这个Fortran程序的运行时间。假设ATLAS的影响不大,这样可以让你了解如果没有Python的额外开销,dot()函数的速度会有多快。使用gfortran -O3编译时,我测得的速度大约是5微秒,误差在0.5微秒左右。

    program test

    real*8 :: x(4000), start, finish, s
    integer :: i, j
    integer,parameter :: jmax = 100000

    x(:) = 4.65
    s = 0.
    call cpu_time(start)
    do j=1,jmax
        s = s + dot_product(x, x)
    enddo
    call cpu_time(finish)
    print *, (finish-start)/jmax * 1.e6, s

    end program test

撰写回答