如何使我的pyspark代码在大量数字范围内可扩展?
我正在尝试写一个pyspark脚本,用来生成所有小于等于某个给定数字的质数。例如,生成所有小于等于10亿的质数。
我写了以下代码。对于小数字来说效果不错,但一旦数字达到1亿,脚本的表现就很差了。
from pyspark.sql import SparkSession
from math import isqrt
def is_perfect_square(num):
root = isqrt(num)
return root*root == num
def sieve_of_eratosthenes_partition(iterator):
upper_limit = max(iterator) # Upper limit for prime number generation
prime_flag = [True] * len(iterator) # Initialize boolean array for primes
result = []
cur_prime = 2
while cur_prime * cur_prime <= upper_limit:
i = 0
if ((cur_prime % 2 == 0 and cur_prime != 2) or is_perfect_square(cur_prime)):
cur_prime += 1
else:
for num in iterator:
if(num % cur_prime == 0 and num != cur_prime):
prime_flag[i] = False
i += 1
cur_prime += 1
for num, is_prime in zip(iterator, prime_flag):
if is_prime and num > 1:
result.append(num)
return result
spark = SparkSession.builder.appName("SievePrimesMapPartitions").getOrCreate()
n = 10**7 # End range
numbers = spark.sparkContext.parallelize(range(1, n+1), 1000)
result_rdd = numbers.mapPartitions(sieve_of_eratosthenes_partition)
# result_rdd.map(str).saveAsTextFile("primes")
primes = result_rdd.collect()
print(primes)
print(len(primes))
我把我的范围(1到1000万)分成了1000个部分,比如(1.....10,000),(10,001....20,000)等等。对于每个部分,我都在应用筛选函数。这个筛选函数会逐步过滤掉给定范围内所有质数的倍数。
我能看到我脚本的瓶颈。在最小数字的部分,比如(1...10,000),筛选函数最多只会迭代到100。而在最大数字的部分,也就是最后一个部分(9,990,000....10,000,000),筛选函数会迭代到10百万的平方根。实际上,我脚本的性能是由处理最大数字部分所花的时间决定的。
我该如何改进这个呢?有没有其他方法可以划分我的数据集?我想到的另一个办法是先做一个筛选,直到给定数字的平方根。然后把这个筛选分发到各个节点。在每个节点上,过滤掉所有对应筛选的倍数,最后把结果合并,得到质数列表。这样做会有改善吗?
1 个回答
1
正如你所想的,提升埃拉托斯特尼筛法的效率一个好方法是,先创建一个小的筛子,范围到你要处理的数字的平方根,然后把这个筛子分发到你的计算集群中。
下面是你可以实现这个方法的步骤:
def soe_sqrt(n):
limit = isqrt(n)
prime_flag = [True] * (limit + 1)
primes = []
for p in range(2, limit + 1):
if prime_flag[p]:
primes.append(p)
for i in range(p * p, limit + 1, p):
prime_flag[i] = False
return primes
然后你把它广播到你的Spark集群上:
spark = SparkSession.builder.appName("SievePrimesMapPartitions").getOrCreate()
n = 10**9 # 1 billion
primes_upto_sqrt_n = soe_sqrt(n)
primes_broadcast = spark.sparkContext.broadcast(primes_upto_sqrt_n)
现在你只需要让你的集群使用这个广播的质数列表,来过滤掉它们的倍数。
def soe_partition(primes_broadcast):
def filter_primes(iterator):
primes = primes_broadcast.value
numbers = set(iterator)
for prime in primes:
multiples = {prime * i for i in range(2, (max(numbers) // prime) + 1)}
numbers -= multiples
return list(numbers)
return filter_primes
numbers = spark.sparkContext.parallelize(range(2, n + 1), 1000)
result_rdd = numbers.mapPartitions(soe_partition(primes_broadcast))
最后,你只需要像之前那样打印出结果,哇哦,成千上万的质数就到你手边了 :)