Context managers and multiprocessing pools
Suppose you are using a multiprocessing.Pool object, and you use the initializer parameter of the constructor to pass an initializer function that creates a resource in the global namespace. Assume that resource has a context manager. How would you handle the life-cycle of the context-managed resource? It has to live as long as the process, but it should be properly cleaned up at the end.
So far, I've got roughly this:
resource_cm = None
resource = None

def _worker_init(args):
    global resource, resource_cm  # both must be declared global, or resource_cm stays local
    resource_cm = open_resource(args)
    resource = resource_cm.__enter__()
From here on, the pool processes can use the resource. So far so good. But handling the clean-up is trickier, since the multiprocessing.Pool class does not provide a destructor or deinitializer parameter.
One of my ideas is to use the atexit module, and register the clean-up from within the initializer. Something like this:
import atexit

def _worker_init(args):
    global resource, resource_cm
    resource_cm = open_resource(args)
    resource = resource_cm.__enter__()

    def _clean_up():
        # __exit__ takes three arguments (exc_type, exc_value, traceback)
        resource_cm.__exit__(None, None, None)

    atexit.register(_clean_up)
Is this a viable approach? Is there an easier way to do this?
EDIT: atexit does not seem to work. At least not the way I'm using it above, so as of now I still don't have a solution to this problem.
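A small demo of why atexit fails here (this is illustrative, not part of the original question): under the default fork start method, CPython terminates pool workers with os._exit(), which skips atexit handlers entirely, whereas a normal sys.exit() runs them. The runs_cleanup helper name is invented for this sketch:

```python
import subprocess
import sys

# A tiny child program that registers an atexit handler, then exits
# either normally (sys.exit) or hard (os._exit).
PROG = """\
import atexit, os, sys
atexit.register(lambda: print("cleanup ran"))
{exit_call}
"""

def runs_cleanup(exit_call):
    # Did the atexit handler fire in the child interpreter?
    out = subprocess.run(
        [sys.executable, "-c", PROG.format(exit_call=exit_call)],
        capture_output=True, text=True,
    )
    return "cleanup ran" in out.stdout

print(runs_cleanup("sys.exit(0)"))  # True: a normal exit runs atexit handlers
print(runs_cleanup("os._exit(0)"))  # False: a hard exit skips them
```

Forked workers take the os._exit() path once their bootstrap code finishes, which is why an atexit registration made inside the pool initializer never fires.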
3 Answers
This is the solution I came up with. It uses billiard, a fork of Python's multiprocessing package. This solution relies on a private API, Worker._ensure_messages_consumed, so I would not recommend it in production. I just needed this for a small side project, so it's good enough for me. Use it at your own risk.
from billiard.pool import Pool, Worker

class SafeWorker(Worker):
    # This function is called just before a worker process exits.
    def _ensure_messages_consumed(self, *args, **kwargs):
        # Not necessary, but you can move `Pool.initializer` logic here if you want.
        out = super()._ensure_messages_consumed(*args, **kwargs)
        # Do clean-up work here
        return out

class SafePool(Pool):
    Worker = SafeWorker
The other approach I tried was to implement my clean-up logic as a signal handler, but that didn't work, because both multiprocessing and billiard use exit() to terminate their worker processes. I'm not sure exactly how atexit works, but that is probably also why it doesn't fire.
You can create a new class that inherits from Process and overrides its run() method, so that some clean-up happens before the process exits. Then you create a new class inheriting from Pool that uses your new process class:
from multiprocessing import Process
from multiprocessing.pool import Pool

class SafeProcess(Process):
    """Process that will clean up before exit."""

    def run(self, *args, **kw):
        result = super().run(*args, **kw)
        # clean up however you want here
        return result

class SafePool(Pool):
    Process = SafeProcess

pool = SafePool(4)  # use it as a standard Pool
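One caveat with this answer (my addition, not part of the answer itself): since Python 3.8, Pool.Process is a static method that receives the pool's context as its first argument, so the plain class-attribute override needs a small tweak. Here is a runnable sketch under that assumption, pinned to the fork start method; the CLEANUP_LOG marker file and the square helper are invented for the demo:

```python
import os
import tempfile
import multiprocessing as mp
from multiprocessing.pool import Pool

# The fork start method keeps module globals (like CLEANUP_LOG) visible
# in the workers without re-importing the module.
ctx = mp.get_context("fork")

# Marker file the demo cleanup appends to; a stand-in for real cleanup work.
CLEANUP_LOG = os.path.join(tempfile.gettempdir(), "safe_pool_demo.log")
open(CLEANUP_LOG, "w").close()  # truncate leftovers from previous runs

class SafeProcess(ctx.Process):
    """Worker process that does cleanup after its main loop returns."""
    def run(self, *args, **kw):
        result = super().run(*args, **kw)
        with open(CLEANUP_LOG, "a") as f:  # stand-in for real cleanup
            f.write("cleaned up pid %d\n" % os.getpid())
        return result

class SafePool(Pool):
    # Python 3.8+ signature: a static method taking the context first.
    @staticmethod
    def Process(ctx, *args, **kwds):
        return SafeProcess(*args, **kwds)

def square(x):
    return x * x

pool = SafePool(2, context=ctx)
results = pool.map(square, range(4))
pool.close()
pool.join()  # workers finish run(), so the cleanup lines get written
print(results)  # [0, 1, 4, 9]
```

Note that pool.terminate() (and the Pool context manager, whose __exit__ calls it) kills the workers instead of letting run() return, so cleanup written this way only happens on the close() + join() path.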
First off, this is a really great question! After digging around in the multiprocessing code a bit, I think I've found a way to do this:

When you start a multiprocessing.Pool, internally it creates a multiprocessing.Process object for each member of the pool. When those sub-processes start up, they call an internal _bootstrap function, which looks roughly like this:
def _bootstrap(self):
    from . import util
    global _current_process
    try:
        # ... (stuff we don't care about)
        util._finalizer_registry.clear()
        util._run_after_forkers()
        util.info('child process calling self.run()')
        try:
            self.run()
            exitcode = 0
        finally:
            util._exit_function()
        # ... (more stuff we don't care about)
The run method is what actually executes the target you gave the Process object. For Pool processes, that is a long-running loop that waits for work items to arrive over an internal queue. What's most interesting for us is what happens after self.run: util._exit_function() is called.

As it turns out, that function does some clean-up that sounds a lot like what you're looking for:
def _exit_function(info=info, debug=debug, _run_finalizers=_run_finalizers,
                   active_children=active_children,
                   current_process=current_process):
    # NB: we hold on to references to functions in the arglist due to the
    # situation described below, where this function is called after this
    # module's globals are destroyed.
    global _exiting
    info('process shutting down')
    debug('running all "atexit" finalizers with priority >= 0')  # Very interesting!
    _run_finalizers(0)
Here is the docstring of _run_finalizers:
def _run_finalizers(minpriority=None):
    '''
    Run all finalizers whose exit priority is not None and at least minpriority

    Finalizers with highest priority are called first; finalizers with
    the same priority will be called in reverse order of creation.
    '''
The method itself walks a list of finalizer callbacks and executes them:
items = [x for x in _finalizer_registry.items() if f(x)]
items.sort(reverse=True)

for key, finalizer in items:
    sub_debug('calling %s', finalizer)
    try:
        finalizer()
    except Exception:
        import traceback
        traceback.print_exc()
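That ordering is easy to check by hand (a single-process sketch of my own, not from the answer): register two finalizers with different priorities and then drain the registry with the same private _run_finalizers helper:

```python
from multiprocessing import util

order = []

class Holder:
    """Plain object; Finalize needs something weakref-able."""

a, b = Holder(), Holder()
util.Finalize(a, lambda: order.append("low"), exitpriority=1)
util.Finalize(b, lambda: order.append("high"), exitpriority=10)

# What _exit_function does at shutdown: run everything with priority >= 0.
util._run_finalizers(0)
print(order)  # ['high', 'low']: highest priority first
```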
Great. So how do we get a hook into the _finalizer_registry? There is an undocumented object in multiprocessing.util called Finalize, and it is responsible for adding callbacks to the registry:
class Finalize(object):
    '''
    Class which supports object finalization using weakrefs
    '''
    def __init__(self, obj, callback, args=(), kwargs=None, exitpriority=None):
        assert exitpriority is None or type(exitpriority) is int

        if obj is not None:
            self._weakref = weakref.ref(obj, self)
        else:
            assert exitpriority is not None

        self._callback = callback
        self._args = args
        self._kwargs = kwargs or {}
        self._key = (exitpriority, _finalizer_counter.next())
        self._pid = os.getpid()

        _finalizer_registry[self._key] = self  # That's what we're looking for!
OK, so putting it all together into an example:
import multiprocessing
from multiprocessing.util import Finalize

resource_cm = None
resource = None

class Resource(object):
    def __init__(self, args):
        self.args = args

    def __enter__(self):
        print("in __enter__ of %s" % multiprocessing.current_process())
        return self

    def __exit__(self, *args, **kwargs):
        print("in __exit__ of %s" % multiprocessing.current_process())

def open_resource(args):
    return Resource(args)

def _worker_init(args):
    global resource
    print("calling init")
    resource_cm = open_resource(args)
    resource = resource_cm.__enter__()
    # Register a finalizer
    Finalize(resource, resource.__exit__, exitpriority=16)

def hi(*args):
    print("we're in the worker")

if __name__ == "__main__":
    pool = multiprocessing.Pool(initializer=_worker_init, initargs=("abc",))
    pool.map(hi, range(pool._processes))
    pool.close()
    pool.join()
The output:
calling init
in __enter__ of <Process(PoolWorker-1, started daemon)>
calling init
calling init
in __enter__ of <Process(PoolWorker-2, started daemon)>
in __enter__ of <Process(PoolWorker-3, started daemon)>
calling init
in __enter__ of <Process(PoolWorker-4, started daemon)>
we're in the worker
we're in the worker
we're in the worker
we're in the worker
in __exit__ of <Process(PoolWorker-1, started daemon)>
in __exit__ of <Process(PoolWorker-2, started daemon)>
in __exit__ of <Process(PoolWorker-3, started daemon)>
in __exit__ of <Process(PoolWorker-4, started daemon)>
As you can see, __exit__ gets called in all our worker processes when we join() the pool.