如何从启动事件中解析不可反序列化对象到多进程池中的函数?
这是我的代码:
from multiprocessing import Pool
from functools import partial
from fastapi import FastAPI
clientObject = package.client() # object not pickle able
app = FastAPI()
def wrapper(func, retries, data_point):
# wrapper function to add a retry mechanism
retry = 0
while retry<retries:
try:
result = func(data_point)
except Exception as err:
result = err
time.sleep(5)
else:
break
return result
def get_response(data_point):
# function use clientObject to get data from an Azure endpoint
data_point = some_other_processes(data_point)
ans = clientObject.process(data_point)
return ans
def main(raw_data):
# main function where I use multiprocessing pool
list_data_point = preprocess(raw_data)
with Pool() as pool:
wrapped_workload = partial(wrapper,
get_response,
3)
results = pool.map(wrapped_workload, list_data_point)
pool.close()
pool.join()
return results
@app.post("/get_answer")
def get_answer(raw_data):
processed_data = main(raw_data)
return processed_data
上面的代码如果一开始就把 clientObject
声明为全局变量的话,运行得很好。
但是如果我把它存储在 app.state
中,像这样:
@app.on_event("start_up")
def start_connection():
app.state.clientObject = package.client()
然后在 get_response
函数里这样访问它:
def get_response(data_point):
data_point = some_other_processes(data_point)
ans = app.state.clientObject.process(data_point)
return ans
就会报错:'State' 对象没有属性 clientObject
。不过,app.state.clientObject
在 main()
函数里还是可以用的。
另外,由于 clientObject
不能被序列化(pickleable),我不能把它作为参数传给 get_response(data_point, clientObject)
函数。
有没有办法让我在启动时初始化 clientObject
,把它存储在一个变量里,并且从一个在多进程池中使用的函数访问它?(不把它声明为全局变量)
编辑:这是我根据 Frank Yellin 的建议找到的解决方案:
def initialize_workers():
global clientObject
clientObject = package.client()
def get_response(data_point):
global clientObject
data_point = some_other_processes(data_point)
ans = clientObject.process(data_point)
return ans
def main(raw_data):
list_data_point = preprocess(raw_data)
with Pool(initializer=initialize_workers) as pool:
wrapped_workload = partial(wrapper,
get_response,
3)
results = pool.map(wrapped_workload, list_data_point)
pool.close()
pool.join()
return results
1 个回答
1
你的代码之所以不工作,是因为你提到的 clientObject
不是一个可以被“腌制”的对象。简单来说,它不能像其他对象那样从一个进程复制到另一个进程。每个进程都需要自己独立的 clientObject
。
所以你需要确保每个进程在启动时都能创建自己的 clientObject
。你现在的代码确实是一种实现方式。
另外一种方法是在创建 Pool
时使用 initializer
和 initargs
参数。这可以指定一个函数和它的参数,每当一个新进程启动时就会调用这个函数。在这个函数里创建你的 clientObject
并把它存储在全局变量或者你的 app.state
中是个不错的选择。