for循环比纯numpy函数更快

2 投票
1 回答
86 浏览
提问于 2025-04-14 16:21

[说明:在写这个问题的过程中,我发现了问题的所在,所以我会提出自己的解决方案。如果你有其他建议,我很乐意听取。]

我正在研究一种机器学习的方法,这个方法需要对小图像的多个部分进行大量的数学和图像处理操作。现在我想要做的是通过优化这些函数来加快处理速度。最开始,每个训练样本都是在一个循环中单独处理的,所以我想,简单的改进就是把二维的numpy数组(图像)合并成一个三维数组,然后一起处理,而不使用循环。我做了一些性能测试,结果让我大吃一惊:使用循环处理的速度竟然比纯粹的numpy计算还要快。

这是我所做的(完整的脚本在最后)。我使用timeit来进行测量。设置代码如下(init_script):

import numpy as np
import cv2
a = np.random.rand(30, 100, 100).astype(np.float32)
b = np.random.rand(30, 100, 100).astype(np.float32)
out = np.empty_like(a)

我比较了逐像素计算平均值的速度:

  • 在循环中("np loop"):

    for i, (a_slice, b_slice) in enumerate(zip(a, b)):
        out[i] = (a_slice+b_slice)/2
    
  • 对整个数组进行操作("pure np"):

    out[:] = (a+b)/2
    
  • 不复制("pure np (no copy)"):

    out = (a+b)/2
    

我对平方根的计算也做了同样的测试,为了完整性:

  • 循环:

    for i, a_slice in enumerate(a):
        out[i] = np.sqrt(a_slice)
    
  • numpy:

    out[:] = np.sqrt(a)
    
  • 不复制:

    out = np.sqrt(a)
    

我得到的测量结果(运行10000次):

=========mean==========
    np loop:  2939.34 ms
    pure np:  8441.93 ms
    pure np (no copy):  7417.07 ms
=========sqrt==========
    np loop:  1353.00 ms
    pure np:  4304.14 ms
    pure np (no copy):  3546.86 ms

使用更多样本时(arr_size = (100, 100, 100)):

=========mean==========
    np loop:  11125.34 ms
    pure np:  26596.88 ms
    pure np (no copy):  24165.19 ms
=========sqrt==========
    np loop:  5107.81 ms
    pure np:  13542.24 ms
    pure np (no copy):  10445.66 ms

使用更大图像时(arr_size = (30, 300, 300)):

=========mean==========
    np loop:  24655.34 ms
    pure np:  72427.00 ms
    pure np (no copy):  66605.25 ms
=========sqrt==========
    np loop:  16771.90 ms
    pure np:  38191.14 ms
    pure np (no copy):  30087.40 ms

我用很多不同的函数(maximummultipylog1p,还有一些cv2函数)进行了相同的测试,结果都是一样的。我对此感到困惑,因为最基本的Python编程原则是使用Numpy,因为它比循环快得多。有没有人知道为什么会这样?更重要的是:如果不使用循环也没有帮助,我该如何加速我的脚本呢?

这是我用于测试的完整脚本(Numpy版本:1.26.4):

import timeit

arr_size = (30, 100, 100)  # 30 images of size 100x100

init_script = f"import numpy as np;" \
              f"a = np.random.rand(*{arr_size}).astype(np.float32);" \
              f"b = np.random.rand(*{arr_size}).astype(np.float32);" \
              f"out = np.empty_like(a)"

mean_tests = {
    "np loop": "for i, (a_slice, b_slice) in enumerate(zip(a, b)): out[i] = (a_slice+b_slice)/2",
    "pure np": "out[:] = (a+b)/2",
    "pure np (no copy)": "out = (a+b)/2",
}

sqrt_tests = {
    "np loop": "for i, a_slice in enumerate(a): out[i] = np.sqrt(a_slice)",
    "pure np": "out[:] = np.sqrt(a)",
    "pure np (no copy)": "out = np.sqrt(a)",
}

tests = {"mean": mean_tests, "sqrt": sqrt_tests}

for func, test in tests.items():
    print(f"========={func}==========")
    for test_case, cmd in test.items():
        elapsed = timeit.timeit(cmd, init_script, number=10000)
        print(f"\t{test_case}: {elapsed*1000: .2f} ms")

1 个回答

1

分配“很大”的数组似乎是问题的根源。我可以通过使用numpy函数的out参数,或者在原地进行计算来加快这些计算的速度:

np.add(a, b, out=out)
out /= 2

用于计算平均值,以及

np.sqrt(a, out=out)

用于计算平方根。

我可以对循环版本做类似的处理,以提高速度:

for i, (a_slice, b_slice) in enumerate(zip(a, b)):
    np.add(a_slice, b_slice, out=out[i])
    out[i]/=2

以及

for i, a_slice in enumerate(a):
    np.sqrt(a_slice, out=out[i])

新的结果:

=========mean==========
    np loop:  3193.34 ms
    pure np:  8400.82 ms
    pure np (no copy):  7102.62 ms
    inplace np:  1165.70 ms
    np inplace loop:  2167.79 ms
=========sqrt==========
    np loop:  1428.44 ms
    pure np:  4185.10 ms
    pure np (no copy):  3459.68 ms
    inplace np:  588.06 ms
    np inplace loop:  860.07 ms

对于arr_size=(100,100,100)

=========mean==========
    np loop:  11747.72 ms
    pure np:  27096.73 ms
    pure np (no copy):  23890.63 ms
    inplace np:  5679.39 ms
    np inplace loop:  8931.12 ms
=========sqrt==========
    np loop:  5062.36 ms
    pure np:  13357.57 ms
    pure np (no copy):  10620.47 ms
    inplace np:  2488.66 ms
    np inplace loop:  3279.80 ms

对于arr_size=(30, 300, 300)

=========mean==========
    np loop:  24232.15 ms
    pure np:  78151.40 ms
    pure np (no copy):  64296.45 ms
    inplace np:  21251.48 ms
    np inplace loop:  23041.19 ms
=========sqrt==========
    np loop:  14685.71 ms
    pure np:  37801.05 ms
    pure np (no copy):  30844.35 ms
    inplace np:  9997.73 ms
    np inplace loop:  10121.31 ms

撰写回答