跟踪joblib.Parallel执行进度

65 投票
11 回答
40349 浏览
提问于 2025-04-18 14:57

有没有简单的方法来跟踪 joblib.Parallel 执行的整体进度呢?

我有一个需要很长时间才能完成的任务,这个任务由成千上万个小任务组成,我想跟踪这些任务的进度并记录到数据库里。不过,为了做到这一点,每当 Parallel 完成一个小任务时,我需要它执行一个回调,告诉我还有多少个小任务没完成。

我之前用 Python 的标准库 multiprocessing.Pool 做过类似的事情,方法是启动一个线程来记录 Pool 中待处理任务的数量。

看代码的时候发现,Parallel 是从 Pool 继承过来的,所以我以为可以用同样的方法来实现,但它似乎没有使用那个任务列表,我也没找到其他方法来“读取”它的内部状态。

11 个回答

10

简短解决方案:

这个方法在使用 Python 3.5 的 joblib 0.14.0 和 tqdm 4.46.0 时有效。感谢 frenzykryger 提供的 contextlib 建议,以及 dano 和 Connor 的猴子补丁思路。

import contextlib
import joblib
from tqdm import tqdm
from joblib import Parallel, delayed

@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
    """Context manager to patch joblib to report into tqdm progress bar given as argument"""

    def tqdm_print_progress(self):
        if self.n_completed_tasks > tqdm_object.n:
            n_completed = self.n_completed_tasks - tqdm_object.n
            tqdm_object.update(n=n_completed)

    original_print_progress = joblib.parallel.Parallel.print_progress
    joblib.parallel.Parallel.print_progress = tqdm_print_progress

    try:
        yield tqdm_object
    finally:
        joblib.parallel.Parallel.print_progress = original_print_progress
        tqdm_object.close()

你可以按照 frenzykryger 描述的方式使用这个方法。

import time
def some_method(wait_time):
    time.sleep(wait_time)

with tqdm_joblib(tqdm(desc="My method", total=10)) as progress_bar:
    Parallel(n_jobs=2)(delayed(some_method)(0.2) for i in range(10))

详细解释:

Jon 提出的解决方案简单易行,但它只测量已分配的任务。如果某个任务耗时较长,进度条会在等待最后一个已分配任务完成时停留在 100%。

frenzykryger 提出的上下文管理器方法,经过 dano 和 Connor 的改进,更好一些,但 BatchCompletionCallBack 也可能在任务完成之前就被 ImmediateResult 调用(具体可以参考 joblib 的中间结果)。这样会导致我们得到的计数超过 100%。

与其对 BatchCompletionCallBack 进行猴子补丁,不如直接对 Parallel 中的 print_progress 函数进行补丁。因为 BatchCompletionCallBack 本身就会调用这个 print_progress。如果设置了详细输出(比如 Parallel(n_jobs=2, verbose=100)),print_progress 会打印出已完成的任务,虽然不如 tqdm 那样美观。从代码来看,print_progress 是一个类方法,所以它已经有 self.n_completed_tasks 来记录我们想要的数量。我们只需要将这个数量与 joblib 当前的进度状态进行比较,只有在有差异时才进行更新。

这个方法在使用 Python 3.5 的 joblib 0.14.0 和 tqdm 4.46.0 中进行了测试。

11

这是对dano回答的进一步说明,主要是针对最新版本的joblib库。内部实现上有一些变化。

from joblib import Parallel, delayed
from collections import defaultdict

# patch joblib progress callback
class BatchCompletionCallBack(object):
  completed = defaultdict(int)

  def __init__(self, time, index, parallel):
    self.index = index
    self.parallel = parallel

  def __call__(self, index):
    BatchCompletionCallBack.completed[self.parallel] += 1
    print("done with {}".format(BatchCompletionCallBack.completed[self.parallel]))
    if self.parallel._original_iterator is not None:
      self.parallel.dispatch_next()

import joblib.parallel
joblib.parallel.BatchCompletionCallBack = BatchCompletionCallBack
22

你链接的文档提到,Parallel有一个可选的进度显示功能。这个功能是通过使用multiprocessing.Pool.apply_async提供的callback关键字参数来实现的:

# This is inside a dispatch function
self._lock.acquire()
job = self._pool.apply_async(SafeFunction(func), args,
            kwargs, callback=CallBack(self.n_dispatched, self))
self._jobs.append(job)
self.n_dispatched += 1

...

class CallBack(object):
    """ Callback used by parallel: it is used for progress reporting, and
        to add data to be processed
    """
    def __init__(self, index, parallel):
        self.parallel = parallel
        self.index = index

    def __call__(self, out):
        self.parallel.print_progress(self.index)
        if self.parallel._original_iterable:
            self.parallel.dispatch_next()

接下来是print_progress的内容:

def print_progress(self, index):
    elapsed_time = time.time() - self._start_time

    # This is heuristic code to print only 'verbose' times a messages
    # The challenge is that we may not know the queue length
    if self._original_iterable:
        if _verbosity_filter(index, self.verbose):
            return
        self._print('Done %3i jobs       | elapsed: %s',
                    (index + 1,
                     short_format_time(elapsed_time),
                    ))
    else:
        # We are finished dispatching
        queue_length = self.n_dispatched
        # We always display the first loop
        if not index == 0:
            # Display depending on the number of remaining items
            # A message as soon as we finish dispatching, cursor is 0
            cursor = (queue_length - index + 1
                      - self._pre_dispatch_amount)
            frequency = (queue_length // self.verbose) + 1
            is_last_item = (index + 1 == queue_length)
            if (is_last_item or cursor % frequency):
                return
        remaining_time = (elapsed_time / (index + 1) *
                    (self.n_dispatched - index - 1.))
        self._print('Done %3i out of %3i | elapsed: %s remaining: %s',
                    (index + 1,
                     queue_length,
                     short_format_time(elapsed_time),
                     short_format_time(remaining_time),
                    ))

老实说,他们实现这个功能的方式有点奇怪——看起来是默认任务会按照启动的顺序完成。传给print_progressindex变量其实就是任务开始时的self.n_dispatched变量。所以第一个启动的任务总是会以index为0结束,即使第三个任务先完成。这也意味着他们并没有真正跟踪已完成任务的数量。所以你没有可以监控的实例变量。

我觉得你最好的办法是自己创建一个回调类,并修改Parallel的行为:

from math import sqrt
from collections import defaultdict
from joblib import Parallel, delayed

class CallBack(object):
    completed = defaultdict(int)

    def __init__(self, index, parallel):
        self.index = index
        self.parallel = parallel

    def __call__(self, index):
        CallBack.completed[self.parallel] += 1
        print("done with {}".format(CallBack.completed[self.parallel]))
        if self.parallel._original_iterable:
            self.parallel.dispatch_next()

import joblib.parallel
joblib.parallel.CallBack = CallBack

if __name__ == "__main__":
    print(Parallel(n_jobs=2)(delayed(sqrt)(i**2) for i in range(10)))

输出:

done with 1
done with 2
done with 3
done with 4
done with 5
done with 6
done with 7
done with 8
done with 9
done with 10
[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]

这样一来,每当一个任务完成时,你的回调就会被调用,而不是默认的那个。

90

再进一步,可以把整个过程封装成一个上下文管理器:

import contextlib
import joblib
from tqdm import tqdm

@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
    """Context manager to patch joblib to report into tqdm progress bar given as argument"""
    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
        def __call__(self, *args, **kwargs):
            tqdm_object.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    old_batch_callback = joblib.parallel.BatchCompletionCallBack
    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        yield tqdm_object
    finally:
        joblib.parallel.BatchCompletionCallBack = old_batch_callback
        tqdm_object.close()

这样你就可以像这样使用它,完成后不会留下修改过的代码:

from math import sqrt
from joblib import Parallel, delayed

with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
    Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))

我觉得这非常棒,而且看起来和tqdm与pandas的结合很相似。

28

为什么不能直接用 tqdm 呢?下面这个对我来说是有效的

from joblib import Parallel, delayed
from datetime import datetime
from tqdm import tqdm

def myfun(x):
    return x**2

results = Parallel(n_jobs=8)(delayed(myfun)(i) for i in tqdm(range(1000))
100%|██████████| 1000/1000 [00:00<00:00, 10563.37it/s]

撰写回答