使用多进程时出现PicklingError

28 投票
1 回答
27898 浏览
提问于 2025-04-16 23:19

我在使用 Pool.map_async()(还有 Pool.map())这个多进程模块时遇到了一些麻烦。我实现了一个并行循环的功能,只要传给 Pool.map_async 的是一个“普通”的函数,它就能正常工作。但是当这个函数是某个类的方法时,我就会遇到 PicklingError 错误:

cPickle.PicklingError: Can't pickle <type 'function'>: attribute lookup __builtin__.function failed

我主要用Python进行科学计算,所以对“序列化”这个概念不太熟悉,今天刚学到一点。我看过一些之前的回答,比如使用多进程的 Pool.map() 时无法序列化 <type 'instancemethod'>,但即使跟着回答中的链接,我也搞不定。

我的代码是为了模拟一个正态随机变量的向量,目的是利用多个核心来运行。请注意,这只是一个例子,可能在多个核心上运行并不划算。

import multiprocessing as mp
import scipy as sp
import scipy.stats as spstat

def parfor(func, args, static_arg = None, nWorkers = 8, chunksize = None):
    """
    Purpose: Evaluate function using Multiple cores.

    Input:
        func       - Function to evaluate in parallel
        arg        - Array of arguments to evaluate func(arg)  
        static_arg - The "static" argument (if any), i.e. the variables that are      constant in the evaluation of func.
        nWorkers   - Number of Workers to process computations.
    Output:
        func(i, static_arg) for i in args.
    
    """
    # Prepare arguments for func: Collect arguments with static argument (if any)
    if static_arg != None:
        arguments = [[arg] + static_arg for arg in list(args)]
    else:
        arguments = args
    
    # Initialize workers
    pool = mp.Pool(processes = nWorkers) 

    # Evaluate function
    result = pool.map_async(func, arguments, chunksize = chunksize)
    pool.close()
    pool.join()

    return sp.array(result.get()).flatten() 

# First test-function. Freeze location and scale for the Normal random variates generator.
# This returns a function that is a method of the class Norm_gen. Methods cannot be pickled
# so this will give an error.
def genNorm(loc, scale):
    def subfunc(a):
        return spstat.norm.rvs(loc = loc, scale = scale, size = a)
    return subfunc

# Second test-function. The same as above but does not return a method of a class. This is a "plain" function and can be 
# pickled
def test(fargs):
    x, a, b = fargs
    return spstat.norm.rvs(size = x, loc = a, scale = b)

# Try it out.
N = 1000000

# Set arguments to function. args1 = [1, 1, 1,... ,1], the purpose is just to generate a random variable of size 1 for each 
# element in the output vector.
args1 = sp.ones(N)
static_arg = [0, 1] # standarized normal.

# This gives the PicklingError
func = genNorm(*static_arg)
sim = parfor(func, args1, static_arg = None, nWorkers = 12, chunksize = None)

# This is OK:
func = test
sim = parfor(func, args1, static_arg = static_arg, nWorkers = 12, chunksize = None)

在回答中提到的链接里,Steven Bethard(快到最后的时候)建议使用 copy_reg 模块。他的代码是:

def _pickle_method(method):
    func_name = method.im_func.__name__
    obj = method.im_self
    cls = method.im_class
    return _unpickle_method, (func_name, obj, cls)

def _unpickle_method(func_name, obj, cls):
    for cls in cls.mro():
        try:
            func = cls.__dict__[func_name]
        except KeyError:
            pass
        else:
            break
    return func.__get__(obj, cls)

import copy_reg
import types

copy_reg.pickle(types.MethodType, _pickle_method, _unpickle_method)

我其实不太明白该怎么用这个模块。我唯一想到的就是把它放在我的代码之前,但这样并没有帮助。当然,简单的解决办法就是直接用能正常工作的函数,避免涉及 copy_reg。不过我更想弄明白怎么让 copy_reg 正常工作,这样就能充分利用多进程,而不必每次都绕过这个问题。

1 个回答

23

这里的问题其实不是“pickle”错误信息本身,而是一个概念上的问题:多进程(multiprocess)会把你的代码分叉到不同的“工作”进程中去执行它的魔法。

它通过无缝地将数据进行序列化和反序列化来在不同的进程之间传递数据(这就是使用pickle的部分)。

当传递的数据中有一个函数时,它会假设在被调用的进程中存在一个同名的函数,并且(我猜)会把函数名作为字符串传过去。因为函数是无状态的,被调用的工作进程只需用它收到的数据调用那个函数。 (Python中的函数不能通过pickle进行序列化,所以在主进程和工作进程之间只传递了函数的引用)

当你的函数是一个实例中的方法时——虽然在编写Python代码时它看起来和函数差不多,并且有一个“自动”的self变量,但它在底层并不相同。因为实例(对象)是有状态的。这意味着工作进程并没有你想在另一边调用的方法的对象的副本。

尝试将你的方法作为函数传递给map_async调用也行不通——因为多进程只是使用函数的引用,而不是实际的函数来传递。

所以,你应该(1)要么改变你的代码,以便传递一个函数——而不是方法——给工作进程,把对象保持的任何状态转换为新的参数来调用。 (2)为map_async调用创建一个“目标”函数,在工作进程那边重建所需的对象,然后在其中调用函数。Python中最简单的类本身是可以被pickle的,所以你可以在map_async调用中传递拥有该函数的对象——而“目标”函数会在工作进程那边调用适当的方法。

(2)可能听起来“困难”,但实际上可能就是这样——除非你的对象的类不能被pickle:

import types

def target(object, *args, **kw):
    method_name = args[0]
    return getattr(object, method_name)(*args[1:])
(...)    
#And add these 3 lines prior to your map_async call:


    # Evaluate function
    if isinstance (func, types.MethodType):
        arguments.insert(0, func.__name__)
        func = target
    result = pool.map_async(func, arguments, chunksize = chunksize)

*免责声明:我还没有测试过这个

撰写回答