为什么Python在简单的for循环中如此慢?

59 投票
11 回答
87242 浏览
提问于 2025-04-17 06:08

我们正在用Python实现一些kNNSVD的功能。其他人选择了Java。我们的执行时间差别很大。我使用了cProfile来查看哪里出错,但实际上一切都还不错,没问题。是的,我也在用numpy。不过我想问一个简单的问题。

total = 0.0
for i in range(9999): # xrange is slower according 
    for j in range(1, 9999):            #to my test but more memory-friendly.
        total += (i / j)
print total

这段代码在我的电脑上运行需要31.40秒。

而同样的代码在Java上运行只需要1秒钟或者更少。我想这段代码的主要问题在于类型检查。但我在我的项目中需要进行很多这样的操作,我觉得9999乘以9999并不是一个很大的数字。

我觉得我可能犯了错误,因为我知道Python被很多科学项目使用。但是为什么这段代码这么慢,我该如何处理比这更大的问题呢?

我应该使用像Psyco这样的JIT编译器吗?

编辑

我还想说,这个循环问题只是一个例子。代码并没有这么简单,可能很难把你的改进或代码示例实际应用到我的代码中。

另一个问题是,如果我正确使用numpyscipy,我能否实现很多数据挖掘和机器学习算法?

11 个回答

10

我觉得在循环操作上,NumPy可能比CPython快(我没有在PyPy上测试过)。

我想从Joe Kington的代码开始,因为这个答案用了NumPy。

%timeit f3(9999)
704 ms ± 2.33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

我自己:

def f4(num):
    x=np.ones(num-1)
    y=np.arange(1,num)
    return np.sum(np.true_divide(x,y))*np.sum(y)

155 µs ± 284 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

另外,高中数学可以让问题变得更简单,方便计算机处理。

Problem= (1+2+...+(num-1)) * (1/1+1/2+...+1/(num-1))
1+2+...+(num-1)=np.sum(np.arange(1,num))=num*(num-1)/2
1/1+1/2+...+1/(num-1)=np.true_divide (1,y)=np.reciprocal(y.astype(np.float64))

所以,

def f5(num):
    return np.sum(np.reciprocal(np.arange(1, num).astype(np.float64))) * num*(num-1)/2
%timeit f5(9999)
106 µs ± 615 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

此外,大学数学能让计算机处理问题时更简单。

1/1+1/2+...+1/(num-1)=np.log(num-1)+1/(2*num-2)+np.euler_gamma
(n>2)

np.euler_gamma:欧拉-马歇罗尼常数(0.57721566...)

因为NumPy中的欧拉-马歇罗尼常数不够准确,你可能会失去一些精度,比如 489223499.9991845 变成了 489223500.0408554。 如果你能忽略0.0000000085%的不准确性,你可以节省更多时间。

def f6(num):
    return (np.log(num-1)+1/(2*num-2)+np.euler_gamma)* num*(num-1)/2
%timeit f6(9999)
4.82 µs ± 29.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

随着输入数据的增大,NumPy的优势会更明显。

%timeit f3(99999)
56.7 s ± 590 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit f5(99999)
534 µs ± 86.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit f5(99999999)
1.42 s ± 15.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
9.498947911958**416**e+16
%timeit f6(99999999)
4.88 µs ± 26.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
9.498947911958**506**e+16
%timeit f6(9999999999999999999)
17.9 µs ± 921 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

在特殊情况下,你可以使用numba(但不总是适用)。

from numba import jit
@jit
def f7(num):
    return (np.log(num-1)+1/(2*num-2)+np.euler_gamma)* num*(num-1)/2
# same code with f6(num)

%timeit f6(999999999999999)
5.63 µs ± 29.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
f7(123) # compile f7(num)
%timeit f7(999999999999999)
331 ns ± 1.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit f7(9999)
286 ns ± 3.09 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

所以,我建议把NumPy、数学和numba结合起来使用。

18

因为你提到了科学计算的代码,可以看看 numpy。你现在做的事情可能已经有人做过了(或者说,它使用了 LAPACK 来处理像 SVD 这样的事情)。当人们提到用 Python 做科学计算时,可能不是指你在例子中那样使用。

举个简单的例子:

(如果你在用 Python 3,你的例子会使用浮点数除法。我的例子假设你在用 Python 2.x,所以用的是整数除法。如果不是,记得指定 i = np.arange(9999, dtype=np.float) 等等)

import numpy as np
i = np.arange(9999)
j = np.arange(1, 9999)
print np.divide.outer(i,j).sum()

为了给你一个时间上的概念……(这里我会用浮点数除法,而不是你例子中的整数除法):

import numpy as np

def f1(num):
    total = 0.0
    for i in range(num): 
        for j in range(1, num):
            total += (float(i) / j)
    return total

def f2(num):
    i = np.arange(num, dtype=np.float)
    j = np.arange(1, num, dtype=np.float)
    return np.divide.outer(i, j).sum()

def f3(num):
    """Less memory-hungry (and faster) version of f2."""
    total = 0.0
    j = np.arange(1, num, dtype=np.float)
    for i in xrange(num):
        total += (i / j).sum()
    return total

如果我们比较一下时间:

In [30]: %timeit f1(9999)
1 loops, best of 3: 27.2 s per loop

In [31]: %timeit f2(9999)
1 loops, best of 3: 1.46 s per loop

In [32]: %timeit f3(9999)
1 loops, best of 3: 915 ms per loop
48

为什么在这个循环的例子中,Java比Python快?

简单解释:可以把程序想象成一列货运火车,它在前进的同时自己铺轨道。轨道必须先铺好,火车才能开动。Java的货运火车可以派出成千上万的轨道铺设工人,大家一起合作,提前铺好很多公里的轨道。而Python只能一次派出一个工人,只能在火车前面铺10英尺的轨道。

Java有强类型,这让编译器可以使用即时编译(JIT)功能:(https://en.wikipedia.org/wiki/Just-in-time_compilation),这使得CPU可以提前并行获取内存和执行指令,而不是等到指令需要时再去做。Java可以“某种程度上”让你的for循环中的指令并行执行。而Python没有具体的类型,所以每条指令的工作内容都必须在执行时决定。这就导致你的电脑必须停下来,重新检查所有变量的内存。这意味着Python中的循环时间复杂度是多项式的O(n^2),而Java的循环可以是线性的O(n),因为有强类型的支持。

我觉得我可能搞错了,因为我知道Python被很多科学项目使用。

这些项目大量使用SciPy(NumPy是最重要的部分,但我听说围绕NumPy的API发展起来的生态系统更为重要),这大大加速了这些项目所需的各种操作。你做错的地方在于:你没有把你的关键代码用C语言编写。Python在开发上非常不错,但合理使用扩展模块是优化的关键(至少在处理数字时是这样)。Python在实现紧凑的内循环时表现得很糟糕。

默认的(目前最流行和广泛支持的)实现是一个简单的字节码解释器。即使是最简单的操作,比如整数除法,也可能需要数百个CPU周期,多个内存访问(类型检查就是一个常见的例子),还有几个C函数调用等等,而不是只需要几条(甚至在整数除法的情况下只需一条)指令。此外,语言设计中有很多抽象,这会增加额外的开销。如果你使用xrange,你的循环会在堆上分配9999个对象;如果你使用range,那就更多了(99999999个整数减去大约256256个小整数是缓存的)。而且,xrange版本在每次迭代时都会调用一个方法来推进——如果没有专门优化过的序列迭代,range版本也是如此。不过,它仍然需要进行一次复杂的字节码调度(当然,相比于整数除法来说)。

如果能看到JIT的效果就有意思了(我推荐PyPy而不是Psyco,后者已经不再积极开发,并且功能非常有限——不过在这个简单的例子中可能效果不错)。经过少量的迭代,它应该能生成一个几乎最优的机器代码循环,并加上一些保护措施——简单的整数比较,如果失败就跳转——以确保在列表中出现字符串时的正确性。Java可以做同样的事情,只是速度更快(它不需要先进行追踪),而且保护措施更少(至少在使用int时)。这就是它为什么快得多的原因。

撰写回答