跟踪joblib.Parallel执行进度
有没有简单的方法来跟踪 joblib.Parallel 执行的整体进度呢?
我有一个需要很长时间才能完成的任务,这个任务由成千上万个小任务组成,我想跟踪这些任务的进度并记录到数据库里。不过,为了做到这一点,每当 Parallel 完成一个小任务时,我需要它执行一个回调,告诉我还有多少个小任务没完成。
我之前用 Python 的标准库 multiprocessing.Pool 做过类似的事情,方法是启动一个线程来记录 Pool 中待处理任务的数量。
看代码的时候发现,Parallel 是从 Pool 继承过来的,所以我以为可以用同样的方法来实现,但它似乎没有使用那个任务列表,我也没找到其他方法来“读取”它的内部状态。
11 个回答
简短解决方案:
这个方法在使用 Python 3.5 的 joblib 0.14.0 和 tqdm 4.46.0 时有效。感谢 frenzykryger 提供的 contextlib 建议,以及 dano 和 Connor 的猴子补丁思路。
import contextlib
import joblib
from tqdm import tqdm
from joblib import Parallel, delayed
@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
"""Context manager to patch joblib to report into tqdm progress bar given as argument"""
def tqdm_print_progress(self):
if self.n_completed_tasks > tqdm_object.n:
n_completed = self.n_completed_tasks - tqdm_object.n
tqdm_object.update(n=n_completed)
original_print_progress = joblib.parallel.Parallel.print_progress
joblib.parallel.Parallel.print_progress = tqdm_print_progress
try:
yield tqdm_object
finally:
joblib.parallel.Parallel.print_progress = original_print_progress
tqdm_object.close()
你可以按照 frenzykryger 描述的方式使用这个方法。
import time
def some_method(wait_time):
time.sleep(wait_time)
with tqdm_joblib(tqdm(desc="My method", total=10)) as progress_bar:
Parallel(n_jobs=2)(delayed(some_method)(0.2) for i in range(10))
详细解释:
Jon 提出的解决方案简单易行,但它只测量已分配的任务。如果某个任务耗时较长,进度条会在等待最后一个已分配任务完成时停留在 100%。
frenzykryger 提出的上下文管理器方法,经过 dano 和 Connor 的改进,更好一些,但 BatchCompletionCallBack
也可能在任务完成之前就被 ImmediateResult
调用(具体可以参考 joblib 的中间结果)。这样会导致我们得到的计数超过 100%。
与其对 BatchCompletionCallBack
进行猴子补丁,不如直接对 Parallel
中的 print_progress
函数进行补丁。因为 BatchCompletionCallBack
本身就会调用这个 print_progress
。如果设置了详细输出(比如 Parallel(n_jobs=2, verbose=100)
),print_progress
会打印出已完成的任务,虽然不如 tqdm 那样美观。从代码来看,print_progress
是一个类方法,所以它已经有 self.n_completed_tasks
来记录我们想要的数量。我们只需要将这个数量与 joblib 当前的进度状态进行比较,只有在有差异时才进行更新。
这个方法在使用 Python 3.5 的 joblib 0.14.0 和 tqdm 4.46.0 中进行了测试。
这是对dano回答的进一步说明,主要是针对最新版本的joblib库。内部实现上有一些变化。
from joblib import Parallel, delayed
from collections import defaultdict
# patch joblib progress callback
class BatchCompletionCallBack(object):
completed = defaultdict(int)
def __init__(self, time, index, parallel):
self.index = index
self.parallel = parallel
def __call__(self, index):
BatchCompletionCallBack.completed[self.parallel] += 1
print("done with {}".format(BatchCompletionCallBack.completed[self.parallel]))
if self.parallel._original_iterator is not None:
self.parallel.dispatch_next()
import joblib.parallel
joblib.parallel.BatchCompletionCallBack = BatchCompletionCallBack
你链接的文档提到,Parallel
有一个可选的进度显示功能。这个功能是通过使用multiprocessing.Pool.apply_async
提供的callback
关键字参数来实现的:
# This is inside a dispatch function
self._lock.acquire()
job = self._pool.apply_async(SafeFunction(func), args,
kwargs, callback=CallBack(self.n_dispatched, self))
self._jobs.append(job)
self.n_dispatched += 1
...
class CallBack(object):
""" Callback used by parallel: it is used for progress reporting, and
to add data to be processed
"""
def __init__(self, index, parallel):
self.parallel = parallel
self.index = index
def __call__(self, out):
self.parallel.print_progress(self.index)
if self.parallel._original_iterable:
self.parallel.dispatch_next()
接下来是print_progress
的内容:
def print_progress(self, index):
elapsed_time = time.time() - self._start_time
# This is heuristic code to print only 'verbose' times a messages
# The challenge is that we may not know the queue length
if self._original_iterable:
if _verbosity_filter(index, self.verbose):
return
self._print('Done %3i jobs | elapsed: %s',
(index + 1,
short_format_time(elapsed_time),
))
else:
# We are finished dispatching
queue_length = self.n_dispatched
# We always display the first loop
if not index == 0:
# Display depending on the number of remaining items
# A message as soon as we finish dispatching, cursor is 0
cursor = (queue_length - index + 1
- self._pre_dispatch_amount)
frequency = (queue_length // self.verbose) + 1
is_last_item = (index + 1 == queue_length)
if (is_last_item or cursor % frequency):
return
remaining_time = (elapsed_time / (index + 1) *
(self.n_dispatched - index - 1.))
self._print('Done %3i out of %3i | elapsed: %s remaining: %s',
(index + 1,
queue_length,
short_format_time(elapsed_time),
short_format_time(remaining_time),
))
老实说,他们实现这个功能的方式有点奇怪——看起来是默认任务会按照启动的顺序完成。传给print_progress
的index
变量其实就是任务开始时的self.n_dispatched
变量。所以第一个启动的任务总是会以index
为0结束,即使第三个任务先完成。这也意味着他们并没有真正跟踪已完成任务的数量。所以你没有可以监控的实例变量。
我觉得你最好的办法是自己创建一个回调类,并修改Parallel
的行为:
from math import sqrt
from collections import defaultdict
from joblib import Parallel, delayed
class CallBack(object):
completed = defaultdict(int)
def __init__(self, index, parallel):
self.index = index
self.parallel = parallel
def __call__(self, index):
CallBack.completed[self.parallel] += 1
print("done with {}".format(CallBack.completed[self.parallel]))
if self.parallel._original_iterable:
self.parallel.dispatch_next()
import joblib.parallel
joblib.parallel.CallBack = CallBack
if __name__ == "__main__":
print(Parallel(n_jobs=2)(delayed(sqrt)(i**2) for i in range(10)))
输出:
done with 1
done with 2
done with 3
done with 4
done with 5
done with 6
done with 7
done with 8
done with 9
done with 10
[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
这样一来,每当一个任务完成时,你的回调就会被调用,而不是默认的那个。
再进一步,可以把整个过程封装成一个上下文管理器:
import contextlib
import joblib
from tqdm import tqdm
@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
"""Context manager to patch joblib to report into tqdm progress bar given as argument"""
class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
def __call__(self, *args, **kwargs):
tqdm_object.update(n=self.batch_size)
return super().__call__(*args, **kwargs)
old_batch_callback = joblib.parallel.BatchCompletionCallBack
joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
try:
yield tqdm_object
finally:
joblib.parallel.BatchCompletionCallBack = old_batch_callback
tqdm_object.close()
这样你就可以像这样使用它,完成后不会留下修改过的代码:
from math import sqrt
from joblib import Parallel, delayed
with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))
我觉得这非常棒,而且看起来和tqdm与pandas的结合很相似。
为什么不能直接用 tqdm
呢?下面这个对我来说是有效的
from joblib import Parallel, delayed
from datetime import datetime
from tqdm import tqdm
def myfun(x):
return x**2
results = Parallel(n_jobs=8)(delayed(myfun)(i) for i in tqdm(range(1000))
100%|██████████| 1000/1000 [00:00<00:00, 10563.37it/s]