Searching for a number in Python using multiprocessing.pool
I'm looking for a number below 100,000 that satisfies a particular condition f(x).
So far I have written the following code:

#!/usr/bin/env python
import itertools
import multiprocessing.pool

paramlist = itertools.product("0123456789", repeat=5)

def function(word):
    number = int(''.join(word))
    # some code

with multiprocessing.pool.ThreadPool(processes=8) as pool:
    pool.imap_unordered(function, paramlist)
    pool.close()
    pool.join()

Is there any way to make this code run faster?
1 Answer
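jasonharper's comment is right. Unless your function does a lot of work per call, the time saved by running it in parallel probably won't make up for the extra time spent creating the child processes.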
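If you do use multiprocessing, I suggest dividing the interval [0, 100_000) into N smaller, non-overlapping intervals, where N is the number of CPU cores on your machine. I chose a function f(x) (x ** 2) and a target value of f(x) (9801198001) whose solution x lies near the end of the range, so a plain serial search has to check almost every candidate x before it finds the result (99001). Even so, for a function this simple, multiprocessing is roughly 8 times slower than the serial search (about 0.24 s versus 0.03 s in the output below).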
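The partitioning described above can be sketched as a small helper. The name `split_range` is hypothetical and not part of any library. Note that this sketch spreads the remainder evenly across the first few chunks, whereas the answer's own code puts the whole remainder into the final interval; either choice yields non-overlapping intervals that cover the full range.

```python
def split_range(start, stop, parts):
    """Split [start, stop) into `parts` contiguous, non-overlapping ranges.

    Hypothetical helper for illustration only.
    """
    if parts < 1:
        raise ValueError("parts must be at least 1")
    size, extra = divmod(stop - start, parts)
    ranges = []
    lower = start
    for i in range(parts):
        # The first `extra` chunks get one additional element each,
        # so chunk sizes differ by at most one.
        upper = lower + size + (1 if i < extra else 0)
        ranges.append(range(lower, upper))
        lower = upper
    return ranges


# Example: 8 chunks of 12,500 candidates each.
chunks = split_range(0, 100_000, 8)
```

Each resulting `range` could then be passed as one task to `pool.imap_unordered`, so every worker scans one interval.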
If f(x) is monotonically increasing or decreasing, the serial search can be sped up further with a binary search, which I have also included:

from multiprocessing import Pool, cpu_count

def f(x):
    return x ** 2

def search(r):
    for x in r:
        if f(x) == 9801198001:
            return x
    return None

def main():
    import time

    # Parallel processing:
    pool_size = cpu_count()
    interval_size = 100_000 // pool_size
    lower_bound = 0
    args = []
    for _ in range(pool_size - 1):
        args.append(range(lower_bound, lower_bound + interval_size))
        lower_bound += interval_size
    # Final interval:
    args.append(range(lower_bound, 100_000))
    t = time.time()
    with Pool(pool_size) as pool:
        for result in pool.imap_unordered(search, args):
            if result is not None:
                break
    # An implicit call to pool.terminate() will be called
    # to terminate any remaining submitted tasks
    elapsed = time.time() - t
    print(f'result = {result}, parallel elapsed time = {elapsed}')

    # Serial processing:
    t = time.time()
    result = search(range(100_000))
    elapsed = time.time() - t
    print(f'result = {result}, serial elapsed time = {elapsed}')

    # Serial processing using a binary search
    # for monotonically increasing function:
    t = time.time()
    lower = 0
    upper = 100_000
    result = None
    while lower < upper:
        x = lower + (upper - lower) // 2
        sq = f(x)
        if sq == 9801198001:
            result = x
            break
        if sq < 9801198001:
            lower = x + 1
        else:
            upper = x
    elapsed = time.time() - t
    print(f'result = {result}, serial binary search elapsed time = {elapsed}')

if __name__ == '__main__':
    main()

Output:
result = 99001, parallel elapsed time = 0.24248361587524414
result = 99001, serial elapsed time = 0.029256343841552734
result = 99001, serial binary search elapsed time = 0.0
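
The binary search in the code above handles only an increasing f(x). A minimal sketch that also covers the decreasing case follows, assuming f is strictly monotonic on the interval; the name `bisect_monotonic` is hypothetical. The usage example times the call with `time.perf_counter()`, since the 0.0 reported above suggests that `time.time()` was too coarse on that machine to register such a short run.

```python
import time


def bisect_monotonic(f, target, lower, upper):
    """Return an x in [lower, upper) with f(x) == target, or None.

    Hypothetical helper; assumes f is strictly monotonic
    (increasing or decreasing) on the interval.
    """
    if upper <= lower:
        return None
    # Compare the two endpoints to detect the direction of f.
    increasing = f(upper - 1) >= f(lower)
    while lower < upper:
        x = lower + (upper - lower) // 2
        value = f(x)
        if value == target:
            return x
        # Move right when the target lies beyond x in the direction
        # in which f changes; otherwise move left.
        if (value < target) == increasing:
            lower = x + 1
        else:
            upper = x
    return None


start = time.perf_counter()
result = bisect_monotonic(lambda x: x ** 2, 9801198001, 0, 100_000)
elapsed = time.perf_counter() - start
print(f'result = {result}, binary search elapsed time = {elapsed:.9f}')
```

The search needs at most about 17 evaluations of f for 100,000 candidates, which is why it finishes far faster than either the serial scan or the pool.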