Python代码优化(比C慢20倍)

4 投票
6 回答
4778 浏览
提问于 2025-04-15 19:41

我写了一段优化得很差的C代码,做了一个简单的数学计算:

#include <stdio.h>
#include <math.h>
#include <stdlib.h>
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
#define MAX(a, b) (((a) > (b)) ? (a) : (b))


unsigned long long int p(int);
float fullCheck(int);

int main(int argc, char **argv){
  int i, g, maxNumber;
  unsigned long long int diff = 1000;

  if(argc < 2){
    fprintf(stderr, "Usage: %s maxNumber\n", argv[0]);
    return 0;
  }
  maxNumber = atoi(argv[1]);

  for(i = 1; i < maxNumber; i++){
    for(g = 1; g < maxNumber; g++){
      if(i == g)
        continue;
      if(p(MAX(i,g)) - p(MIN(i,g)) < diff &&  fullCheck(p(MAX(i,g)) - p(MIN(i,g))) && fullCheck(p(i) + p(g))){
          diff = p(MAX(i,g)) - p(MIN(i,g));
          printf("We have a couple %llu %llu with diff %llu\n", p(i), p(g), diff);
      }
    }
  }

  return 0;
}

float fullCheck(int number){
  float check = (-1 + sqrt(1 + 24 * number))/-6;
  float check2 = (-1 - sqrt(1 + 24 * number))/-6;
  if(check/1.00 == (int)check)
    return check;
  if(check2/1.00 == (int)check2)
    return check2;
  return 0;
}

unsigned long long int p(int n){
  return n * (3 * n - 1 ) / 2;
}

然后我尝试把它移植到Python上,看看会有什么反应,纯粹是为了好玩。我的第一个版本几乎是逐行转换,但运行得非常慢(在Python中超过120秒,而在C中不到1秒)。

我做了一些优化,得到了这个结果:

#!/usr/bin/env/python
from cmath import sqrt
import cProfile
from pstats import Stats

def quickCheck(n):
        partial_c = (sqrt(1 + 24 * (n)))/-6 
        c = 1/6 + partial_c
        if int(c.real) == c.real:
                return True
        c = c - 2*partial_c
        if int(c.real) == c.real:
                return True
        return False

def main():        
        maxNumber = 5000
        diff = 1000
        for i in range(1, maxNumber):
                p_i = i * (3 * i - 1 ) / 2
                for g in range(i, maxNumber):
                        if i == g:
                                continue
                        p_g = g * (3 * g - 1 ) / 2
                        if p_i > p_g:
                                ma = p_i
                                mi = p_g
                        else:
                                ma = p_g
                                mi = p_i

                        if ma - mi < diff and quickCheck(ma - mi):
                                if quickCheck(ma + mi):
                                        print ('New couple ', ma, mi)
                                        diff = ma - mi


cProfile.run('main()','script_perf')
perf = Stats('script_perf').sort_stats('time', 'calls').print_stats(10)

这个版本运行大约需要16秒,虽然比之前好,但还是比C慢了将近20倍。

我知道C在这类计算上比Python要好,但我想知道我是否遗漏了什么(在Python中,比如某个特别慢的函数之类的),能让这个函数运行得更快。

请注意,我使用的是Python 3.1.1,不知道这是否会有影响。

6 个回答

5

因为函数 p() 是单调递增的,所以你可以避免比较值,比如说如果 g > i,那么就可以推断出 p(g) > p(i)。另外,内层循环可以提前结束,因为 p(g) - p(i) >= diff 意味着 p(g+1) - p(i) >= diff 也是成立的。

为了确保正确性,我在 quickCheck 中把相等的比较改成了与一个小数值(epsilon)比较,因为直接比较浮点数的准确性是比较脆弱的。

在我的机器上,这样做把运行时间减少到了 7.8 毫秒,使用 Python 2.6 的时候。用 JIT 的 PyPy 运行时则减少到了 0.77 毫秒。

这说明在进行微优化之前,先寻找算法上的优化是很有价值的。微优化会让发现算法上的变化变得更加困难,而得到的收益却相对较小。

EPS = 0.00000001
def quickCheck(n):
    partial_c = sqrt(1 + 24*n) / -6
    c = 1/6 + partial_c
    if abs(int(c) - c) < EPS:
        return True
    c = 1/6 - partial_c
    if abs(int(c) - c) < EPS:
        return True
    return False

def p(i):
    return i * (3 * i - 1 ) / 2

def main(maxNumber):
    diff = 1000

    for i in range(1, maxNumber):
        for g in range(i+1, maxNumber):
            if p(g) - p(i) >= diff:
                break 
            if quickCheck(p(g) - p(i)) and quickCheck(p(g) + p(i)):
                print('New couple ', p(g), p(i), p(g) - p(i))
                diff = p(g) - p(i)
17

因为quickCheck这个函数被调用的次数接近2500万次,所以你可能想用缓存技术来存储答案,这样可以提高效率。

你可以在C语言和Python中都使用缓存技术。用C语言会快很多。

在每次执行quickCheck时,你都在计算1/6。我不确定Python是否会自动优化这个计算,但如果能避免重复计算一些常量值,整体速度会更快。C语言的编译器会帮你处理这些。

if condition: return True; else: return False这样的写法其实是多余的,而且还浪费时间。你可以直接写return condition,这样更简洁。

在Python 3.x中,/2会生成浮点数,而你似乎需要的是整数。你应该使用//2进行整除,这样做的效果更接近C语言的写法,不过我觉得速度上差别不大。

最后,Python一般是解释型语言,解释器的运行速度总是比C语言慢很多。

10

我把我的程序从大约7秒的运行时间缩短到了大约3秒:

  • 我提前计算了 i * (3 * i - 1 ) / 2 的值,这样在你的代码中,这个计算被重复了很多次。
  • 我缓存了对 quickCheck 函数的调用,这样可以避免重复计算。
  • 我通过把范围加1,去掉了 if i == g 的判断。
  • 我去掉了 if p_i > p_g 的判断,因为 p_i 总是小于 p_g。

我还把 quickCheck 函数放到了 main 函数里面,这样所有的变量都是局部的(局部变量查找速度比全局变量快)。我相信还有更多的小优化可以做。

def main():
        maxNumber = 5000
        diff = 1000

        p = {}
        quickCache = {}

        for i in range(maxNumber):
            p[i] = i * (3 * i - 1 ) / 2

        def quickCheck(n):
            if n in quickCache: return quickCache[n]
            partial_c = (sqrt(1 + 24 * (n)))/-6 
            c = 1/6 + partial_c
            if int(c.real) == c.real:
                    quickCache[n] = True
                    return True
            c = c - 2*partial_c
            if int(c.real) == c.real:
                    quickCache[n] = True
                    return True
            quickCache[n] = False
            return False

        for i in range(1, maxNumber):
                mi = p[i]
                for g in range(i+1, maxNumber):
                        ma = p[g]
                        if ma - mi < diff and quickCheck(ma - mi) and quickCheck(ma + mi):
                                print('New couple ', ma, mi)
                                diff = ma - mi

撰写回答