Python中的并行递归函数

10 投票
3 回答
10367 浏览
提问于 2025-04-17 00:29

我该如何在Python中实现递归函数的并行处理呢?

我的函数是这样的:

def f(x, depth):
    if x==0:
        return ...
    else :
        return [x] + map(lambda x:f(x, depth-1), list_of_values(x))

def list_of_values(x):
    # Heavy compute, pure function

当我尝试用 multiprocessing.Pool.map 来进行并行处理时,Windows会打开无数个进程,导致程序卡住。

有没有什么好的(最好是简单的)方法可以在单台多核机器上实现并行处理呢?

以下是导致程序卡住的代码:

from multiprocessing import Pool
pool = pool(processes=4)
def f(x, depth):
    if x==0:
        return ...
    else :
        return [x] + pool.map(lambda x:f(x, depth-1), list_of_values(x))

def list_of_values(x):
    # Heavy compute, pure function

3 个回答

1

我一开始会保存主进程的ID,然后把它传递给子程序。

当我需要启动一个多进程的工作时,我会先检查主进程的子进程数量。如果子进程的数量小于或等于我电脑CPU核心数量的一半,我就会让它们并行运行。如果子进程的数量大于CPU核心数量的一半,我就会让它们顺序运行。这样做可以避免瓶颈,并有效利用CPU核心。你可以根据自己的情况调整核心数量。例如,你可以设置为CPU核心的确切数量,但不要超过这个数量。

def subProgramhWrapper(func, args):
    func(*args)

parent = psutil.Process(main_process_id)
children = parent.children(recursive=True)
num_cores = int(multiprocessing.cpu_count()/2)

if num_cores >= len(children):
    #parallel run
    pool = MyPool(num_cores)
    results = pool.starmap(subProgram, input_params)
    pool.close()
    pool.join()
else:
    #serial run
    for input_param in input_params:
        subProgramhWrapper(subProgram, input_param)
3

经过一番思考,我找到了一个简单的答案,虽然不算完整,但已经足够好了:

# A partially parallel solution. Just do the first level of recursion in parallel. It might be enough work to fill all cores.
import multiprocessing

def f_helper(data):
     return f(x=data['x'],depth=data['depth'], recursion_depth=data['recursion_depth'])

def f(x, depth, recursion_depth):
    if depth==0:
        return ...
    else :
        if recursion_depth == 0:
            pool = multiprocessing.Pool(processes=4)
            result = [x] + pool.map(f_helper, [{'x':_x, 'depth':depth-1,  'recursion_depth':recursion_depth+1 } _x in list_of_values(x)])
            pool.close()
        else:
            result = [x] + map(f_helper, [{'x':_x, 'depth':depth-1, 'recursion_depth':recursion_depth+1 } _x in list_of_values(x)])


        return result

def list_of_values(x):
    # Heavy compute, pure function
8

好的,抱歉给你带来了麻烦。

我将回答一个稍微不同的问题,假设f()这个函数返回列表中所有值的总和。这样做是因为从你的例子中,我不太清楚f()的返回类型是什么,而用整数来做会让代码更容易理解。

这个问题比较复杂,因为有两个不同的事情同时在进行:

  1. 在池中计算耗时的函数
  2. 递归地展开f()

我非常小心,只使用池来计算耗时的函数。这样我们就不会出现“进程爆炸”的情况,但由于这是异步的,我们需要将很多工作推迟到回调中,也就是当工作完成后,工人会调用这个回调。

更重要的是,我们需要使用一个倒计时锁,以便知道所有对f()的单独调用何时完成。

可能有更简单的方法(我很确定有,但我还有其他事情要做),不过也许这能给你一个关于可能性的想法:

from multiprocessing import Pool, Value, RawArray, RLock
from time import sleep

class Latch:

    '''A countdown latch that lets us wait for a job of "n" parts'''

    def __init__(self, n):
        self.__counter = Value('i', n)
        self.__lock = RLock()

    def decrement(self):
        with self.__lock:
            self.__counter.value -= 1
            print('dec', self.read())
        return self.read() == 0

    def read(self):
        with self.__lock:
            return self.__counter.value

    def join(self):
        while self.read():
            sleep(1)


def list_of_values(x):
    '''An expensive function'''
    print(x, ': thinking...')
    sleep(1)
    print(x, ': thought')
    return list(range(x))


pool = Pool()


def async_f(x, on_complete=None):
    '''Return the sum of the values in the expensive list'''
    if x == 0:
        on_complete(0) # no list, return 0
    else:
        n = x # need to know size of result beforehand
        latch = Latch(n) # wait for n entires to be calculated
        result = RawArray('i', n+1) # where we will assemble the map
        def delayed_map(values):
            '''This is the callback for the pool async process - it runs
               in a separate thread within this process once the
               expensive list has been calculated and orchestrates the
               mapping of f over the result.'''
            result[0] = x # first value in list is x
            for (v, i) in enumerate(values):
                def callback(fx, i=i):
                    '''This is the callback passed to f() and is called when
                       the function completes.  If it is the last of all the
                       calls in the map then it calls on_complete() (ie another
                       instance of this function) for the calling f().'''
                    result[i+1] = fx
                    if latch.decrement(): # have completed list
                        # at this point result contains [x]+map(f, ...)
                        on_complete(sum(result)) # so return sum
                async_f(v, callback)
        # Ask worker to generate list then call delayed_map
        pool.apply_async(list_of_values, [x], callback=delayed_map)


def run():
    '''Tie into the same mechanism as above, for the final value.'''
    result = Value('i')
    latch = Latch(1)
    def final_callback(value):
        result.value = value
        latch.decrement()
    async_f(6, final_callback)
    latch.join() # wait for everything to complete
    return result.value


print(run())

顺便说一下,我使用的是Python 3.2,上面的代码看起来有点复杂,因为我们在稍后才计算最终结果(向上回溯)。可能像生成器或未来对象这样的东西可以简化这个过程。

另外,我怀疑你需要一个缓存,以避免在用相同参数调用耗时函数时不必要地重新计算。

还可以看看 yaniv的回答 - 这似乎是通过明确深度来改变评估顺序的另一种方法。

撰写回答