如何从启动事件中解析不可反序列化对象到多进程池中的函数?

0 投票
1 回答
43 浏览
提问于 2025-04-14 16:33

这是我的代码:

from multiprocessing import Pool
from functools import partial
from fastapi import FastAPI

clientObject = package.client() # object not pickle able
    
app = FastAPI()

def wrapper(func, retries, data_point):
# wrapper function to add a retry mechanism
    retry = 0
    while retry<retries:
        try:
            result = func(data_point)
        except Exception as err:
            result = err
            time.sleep(5) 
        else:
            break
    return result

def get_response(data_point):
# function use clientObject to get data from an Azure endpoint
    data_point = some_other_processes(data_point)
    ans = clientObject.process(data_point)
    return ans

def main(raw_data):
# main function where I use multiprocessing pool
    list_data_point = preprocess(raw_data)
    with Pool() as pool:
        wrapped_workload = partial(wrapper, 
                                   get_response,
                                   3)
        results = pool.map(wrapped_workload, list_data_point)
        pool.close()
        pool.join()
    return results


@app.post("/get_answer")
def get_answer(raw_data):
    processed_data = main(raw_data)
    return processed_data

上面的代码如果一开始就把 clientObject 声明为全局变量的话,运行得很好。

但是如果我把它存储在 app.state 中,像这样:

@app.on_event("start_up")
def start_connection():
    app.state.clientObject = package.client()

然后在 get_response 函数里这样访问它:

def get_response(data_point): 
     data_point = some_other_processes(data_point)
     ans = app.state.clientObject.process(data_point)
     return ans

就会报错:'State' 对象没有属性 clientObject。不过,app.state.clientObjectmain() 函数里还是可以用的。

另外,由于 clientObject 不能被序列化(pickleable),我不能把它作为参数传给 get_response(data_point, clientObject) 函数。

有没有办法让我在启动时初始化 clientObject,把它存储在一个变量里,并且从一个在多进程池中使用的函数访问它?(不把它声明为全局变量)

编辑这是我根据 Frank Yellin 的建议找到的解决方案:

def initialize_workers():
    global clientObject
    clientObject = package.client()

def get_response(data_point):
    global clientObject
    data_point = some_other_processes(data_point)
    ans = clientObject.process(data_point)
    return ans

def main(raw_data):
    list_data_point = preprocess(raw_data)
    with Pool(initializer=initialize_workers) as pool:
        wrapped_workload = partial(wrapper, 
                                   get_response,
                                   3)
        results = pool.map(wrapped_workload, list_data_point)
        pool.close()
        pool.join()
    return results

1 个回答

1

你的代码之所以不工作,是因为你提到的 clientObject 不是一个可以被“腌制”的对象。简单来说,它不能像其他对象那样从一个进程复制到另一个进程。每个进程都需要自己独立的 clientObject

所以你需要确保每个进程在启动时都能创建自己的 clientObject。你现在的代码确实是一种实现方式。

另外一种方法是在创建 Pool 时使用 initializerinitargs 参数。这可以指定一个函数和它的参数,每当一个新进程启动时就会调用这个函数。在这个函数里创建你的 clientObject 并把它存储在全局变量或者你的 app.state 中是个不错的选择。

撰写回答