Searching for a number in Python using multiprocessing.pool

Asked 2025-04-14 16:19

I am looking for a number below 100,000 that satisfies a particular condition f(x).
So far I have written the following code:

#!/usr/bin/env python

import itertools
import multiprocessing.pool

paramlist = itertools.product("0123456789", repeat=5)

def function(word):
    number = int(''.join(word))
    # some code

with multiprocessing.pool.ThreadPool(processes=8) as pool:
    pool.imap_unordered(function, paramlist)
    pool.close()
    pool.join()

Is there any way to make this code run faster?

1 Answer


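jasonharper's comment is right. Unless your function does a substantial amount of work per call, the time you save by running it in parallel may not even cover the extra time spent creating the child processes.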

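If you do use multiprocessing, I suggest splitting the interval [0, 100_000) into N smaller, non-overlapping intervals, where N is the number of CPU cores on your machine. For f(x) I chose x ** 2, and the target value f(x) = 9801198001 is deliberately skewed: its solution, x = 99001, lies near the top of the range, so a plain serial search has to check almost every candidate x before finding it. Even so, for a function this trivial, the multiprocessing version runs close to 10 times slower than the serial one.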
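The partitioning described above can be sketched as a small helper. The name `split_range` and the choice to spread the remainder over the first few chunks are illustrative assumptions, not part of the original answer, which instead puts the whole remainder into the final interval.

```python
def split_range(start, stop, parts):
    """Split [start, stop) into `parts` contiguous, non-overlapping ranges.

    Hypothetical helper for illustration only.
    """
    size, extra = divmod(stop - start, parts)
    ranges = []
    lower = start
    for i in range(parts):
        # The first `extra` ranges get one additional element each,
        # so chunk lengths never differ by more than one.
        upper = lower + size + (1 if i < extra else 0)
        ranges.append(range(lower, upper))
        lower = upper
    return ranges


# Example with 8 workers; in practice `parts` would be cpu_count().
chunks = split_range(0, 100_000, 8)
print([(r.start, r.stop) for r in chunks])
```

Each `range` object is cheap to pickle and send to a worker, because it stores only its start, stop, and step rather than the individual numbers.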

If f(x) is monotonically increasing or decreasing, the serial approach can be sped up further with a binary search, which I have also included:

from multiprocessing import Pool, cpu_count

def f(x):
    return x ** 2

def search(r):
    for x in r:
        if f(x) == 9801198001:
            return x
    return None

def main():
    import time

    # Parallel processing:
    pool_size = cpu_count()
    interval_size = 100_000 // pool_size
    lower_bound = 0
    args = []
    for _ in range(pool_size - 1):
        args.append(range(lower_bound, lower_bound + interval_size))
        lower_bound += interval_size
    # Final interval:
    args.append(range(lower_bound, 100_000))

    t = time.time()
    with Pool(pool_size) as pool:
        for result in pool.imap_unordered(search, args):
            if result is not None:
                break
    # Leaving the with block implicitly calls pool.terminate(),
    # which stops any tasks that are still running or queued.
    elapsed = time.time() - t
    print(f'result = {result}, parallel elapsed time = {elapsed}')

    # Serial processing:
    t = time.time()
    result = search(range(100_000))
    elapsed = time.time() - t
    print(f'result = {result}, serial elapsed time = {elapsed}')

    # Serial processing using a binary search
    # for monotonically increasing function:
    t = time.time()
    lower = 0
    upper = 100_000
    result = None
    while lower < upper:
        x = lower + (upper - lower) // 2
        sq = f(x)
        if sq == 9801198001:
            result = x
            break
        if sq < 9801198001:
            lower = x + 1
        else:
            upper = x

    elapsed = time.time() - t
    print(f'result = {result}, serial binary search elapsed time = {elapsed}')

if __name__ == '__main__':
    main()

Output:

result = 99001, parallel elapsed time = 0.24248361587524414
result = 99001, serial elapsed time = 0.029256343841552734
result = 99001, serial binary search elapsed time = 0.0
