Java与Python特定代码片段性能提升

1 投票
3 回答
504 浏览
提问于 2025-04-16 15:58

我有一个关于Java和Python中某段代码性能的问题。

算法:
我正在生成随机的N维点,然后对每一对在某个距离阈值内的点进行一些处理。处理的具体内容在这里不重要,因为它不会影响总的执行时间。生成这些点的过程在两种情况下都只需要几分之一秒,所以我只关心比较的部分。

执行时间:
在固定输入为3000个点和2个维度的情况下,Java的执行时间是2到4秒,而Python则需要15到200秒不等。

我对Python的执行时间有点怀疑。我的Python代码中有没有什么我遗漏的地方?有没有什么算法上的改进建议(比如预分配/重用内存,降低时间复杂度等)?


Java

double random_points[][] = new double[number_of_points][dimensions];
for(i = 0; i < number_of_points; i++)
  for(d = 0; d < dimensions; d++)
    random_points[i][d] = Math.random();

double p1[], p2[];
for(i = 0; i < number_of_points; i++)
{
  p1 = random_points[i];
  for(j = i + 1; j < number_of_points; j++)
  {
    p2 = random_points[j];

    double sum_of_squares = 0;
    for(d = 0; d < DIM_; d++)
      sum_of_squares += (p2[d] - p1[d]) * (p2[d] - p1[d]);

    double distance = Math.sqrt(ss);
    if(distance > SOME_THRESHOLD) continue;

    //...else do something with p1 and p2

  }
}

Python 3.2

random_points = [[random.random() for _d in range(0,dimensions)] for _n in range(0,number_of_points)]

for i, p1 in enumerate(random_points):
  for j, p2 in enumerate(random_points[i+1:]):
    distance = math.sqrt(sum([(p1[d]-p2[d])**2 for d in range(0,dimensions)]))
    if distance > SOME_THRESHOLD: continue

    #...else do something with p1 and p2

3 个回答

1

速度真的重要吗?这里有几个明显的加速方法:

  1. 别计算平方根。只需把你的阈值平方,然后和平方后的阈值比较就行了。

  2. 在一个维度上对你的点进行排序(就是你循环中的外层维度)。当两个点 ij 在这个维度上距离阈值更远时,继续增加 j 只会得到更远的点,这样你就可以用 continue 跳过外层循环了。

可能还有其他算法上的加速方法,即使以上这些方法的复杂度还是 O(nd),也就是在某些情况下可能会比较慢。

2

如果我处理3万个数据点和5个维度,那就要做100倍的工作量。

int number_of_points = 30000;
int dimensions = 5;
double SOME_THRESHOLD = 0.1;

long start = System.nanoTime();
double random_points[][] = new double[number_of_points][dimensions];
for (int i = 0; i < number_of_points; i++)
    for (int d = 0; d < dimensions; d++)
        random_points[i][d] = Math.random();

double p1[], p2[];
Comparator<double[]> compareX = new Comparator<double[]>() {
    @Override
    public int compare(double[] o1, double[] o2) {
        return Double.compare(o1[0], o2[0]);
    }
};
Arrays.sort(random_points, compareX);

double[] key = new double[dimensions];
int count = 0;
for (int i = 0; i < number_of_points; i++) {
    p1 = random_points[i];
    key[0] = p1[0] + SOME_THRESHOLD;
    int index = Arrays.binarySearch(random_points, key, compareX);
    if (index < 0) index = ~index;
    NEXT: for (int j = i + 1; j < index; j++) {
        p2 = random_points[j];

        double sum_of_squares = 0;
        for (int d = 0; d < dimensions; d++) {
            sum_of_squares += (p2[d] - p1[d]) * (p2[d] - p1[d]);
            if (sum_of_squares > SOME_THRESHOLD * SOME_THRESHOLD) 
                continue NEXT;
        }

        //...else do something with p1 and p2
        count++;
    }
}
long time = System.nanoTime() - start;
System.out.println("Took " + time / 1000 / 1000 + " ms. count= " + count);

打印输出

Took 1549 ms. count= 20197
5

你可能想考虑使用 numpy 这个库。

我刚刚试了以下代码:

import numpy
from scipy.spatial.distance import pdist
D=2
N=3000
p=numpy.random.uniform(size=(N,D))
dist=pdist(p, 'euclidean')

最后一行计算的是距离矩阵(这就相当于在你的代码中为每一对点计算距离)。在我的电脑上,这大约需要0.07秒。

这个方法的主要缺点是,它需要 O(n^2) 的内存来存储距离矩阵。如果这对你来说是个问题,下面的方法可能是个更好的选择:

for i in xrange(1, N):
  v = p[:N-i] - p[i:]
  dist = numpy.sqrt(numpy.sum(numpy.square(v), axis=1))
  for j in numpy.nonzero(dist > 1.4)[0]:
    print j, i+j

对于N=3000,这在我的电脑上大约需要0.33秒。

撰写回答