如何使我的pyspark代码在大量数字范围内可扩展?

1 投票
1 回答
60 浏览
提问于 2025-04-12 13:33

我正在尝试写一个pyspark脚本,用来生成所有小于等于某个给定数字的质数。例如,生成所有小于等于10亿的质数。

我写了以下代码。对于小数字来说效果不错,但一旦数字达到1亿,脚本的表现就很差了。

from pyspark.sql import SparkSession
from math import isqrt

def is_perfect_square(num):
  root = isqrt(num)
  return root*root == num


def sieve_of_eratosthenes_partition(iterator):
    upper_limit = max(iterator) # Upper limit for prime number generation
    prime_flag = [True] * len(iterator) # Initialize boolean array for primes
    result = []
    cur_prime = 2

    while cur_prime * cur_prime <= upper_limit:
        i = 0
        if ((cur_prime % 2 == 0 and cur_prime != 2) or is_perfect_square(cur_prime)):
          cur_prime += 1
        else:
          for num in iterator:
            if(num % cur_prime == 0 and num != cur_prime):
              prime_flag[i] = False
            i += 1  
          cur_prime += 1

    for num, is_prime in zip(iterator, prime_flag):
        if is_prime and num > 1:
            result.append(num)
    return result

spark = SparkSession.builder.appName("SievePrimesMapPartitions").getOrCreate()
n = 10**7 # End range 

numbers = spark.sparkContext.parallelize(range(1, n+1), 1000)
result_rdd = numbers.mapPartitions(sieve_of_eratosthenes_partition)

# result_rdd.map(str).saveAsTextFile("primes")

primes = result_rdd.collect()
print(primes)
print(len(primes))

我把我的范围(1到1000万)分成了1000个部分,比如(1.....10,000),(10,001....20,000)等等。对于每个部分,我都在应用筛选函数。这个筛选函数会逐步过滤掉给定范围内所有质数的倍数。

我能看到我脚本的瓶颈。在最小数字的部分,比如(1...10,000),筛选函数最多只会迭代到100。而在最大数字的部分,也就是最后一个部分(9,990,000....10,000,000),筛选函数会迭代到10百万的平方根。实际上,我脚本的性能是由处理最大数字部分所花的时间决定的。

我该如何改进这个呢?有没有其他方法可以划分我的数据集?我想到的另一个办法是先做一个筛选,直到给定数字的平方根。然后把这个筛选分发到各个节点。在每个节点上,过滤掉所有对应筛选的倍数,最后把结果合并,得到质数列表。这样做会有改善吗?

1 个回答

1

正如你所想的,提升埃拉托斯特尼筛法的效率一个好方法是,先创建一个小的筛子,范围到你要处理的数字的平方根,然后把这个筛子分发到你的计算集群中。

下面是你可以实现这个方法的步骤:

    def soe_sqrt(n):
        limit = isqrt(n)
        prime_flag = [True] * (limit + 1)
        primes = []
        for p in range(2, limit + 1):
            if prime_flag[p]:
                primes.append(p)
                for i in range(p * p, limit + 1, p):
                    prime_flag[i] = False
        return primes

然后你把它广播到你的Spark集群上:

    spark = SparkSession.builder.appName("SievePrimesMapPartitions").getOrCreate()
    n = 10**9  # 1 billion
    primes_upto_sqrt_n = soe_sqrt(n)
    primes_broadcast = spark.sparkContext.broadcast(primes_upto_sqrt_n)

现在你只需要让你的集群使用这个广播的质数列表,来过滤掉它们的倍数。

    def soe_partition(primes_broadcast):
        def filter_primes(iterator):
            primes = primes_broadcast.value
            numbers = set(iterator)
            for prime in primes:
                multiples = {prime * i for i in range(2, (max(numbers) // prime) + 1)}
                numbers -= multiples
            return list(numbers)

        return filter_primes


    numbers = spark.sparkContext.parallelize(range(2, n + 1), 1000)
    result_rdd = numbers.mapPartitions(soe_partition(primes_broadcast))

最后,你只需要像之前那样打印出结果,哇哦,成千上万的质数就到你手边了 :)

撰写回答