如何使用 copy_reg 解决 classmethod 的序列化(pickle)问题

9 投票
3 回答
4120 浏览
提问于 2025-04-16 14:23

我在使用 multiprocessing(多进程)时遇到了一个序列化错误(PicklingError):

 from multiprocessing import Pool

 def test_func(x):
     return x**2

 class Test:
     """Holder for a classmethod that squares its argument."""

     @classmethod
     def func(cls, x):
         """Return ``x`` squared; ``cls`` is not used."""
         return pow(x, 2)

 def mp_run(n, func, args):
     """Map ``func`` over ``args`` using a pool of ``n`` worker processes.

     :param n: number of worker processes handed to ``Pool``.
     :param func: callable applied to every item; it is sent to the workers,
         so it has to be picklable.
     :param args: iterable of inputs.
     :returns: whatever ``Pool.map`` returns for ``func`` and ``args``.
     """
     pool = Pool(n)
     try:
         return pool.map(func, args)
     finally:
         # The pool was previously abandoned without being shut down; close
         # it and wait for the workers so no processes are left behind,
         # whether map() returned or raised.
         pool.close()
         pool.join()

 if __name__ == '__main__':
     # Inputs 1 through 5 (the upper bound of range is exclusive).
     args = range(1,6)

     # A module-level function goes through the pool without trouble.
     print mp_run(5, test_func, args)
     # [1, 4, 9, 16, 25]

     # The same call with a classmethod fails; the string literal below holds
     # the traceback reported for it.
     print mp_run(5, Test.func, args)
     """
     Exception in thread Thread-3:
     Traceback (most recent call last):
       File "/usr/lib64/python2.6/threading.py", line 532, in __bootstrap_inner
         self.run()
       File "/usr/lib64/python2.6/threading.py", line 484, in run
         self.__target(*self.__args, **self.__kwargs)
       File "/usr/lib64/python2.6/multiprocessing/pool.py", line 225, in _handle_tasks
         put(task)
     PicklingError: Can't pickle <type 'instancemethod'>: attribute lookup __builtin__.instancemethod failed
     """

我找到了一条很有用的讨论,其中的解决方案对带有“self”的实例方法非常有效,但把它应用到 @classmethod 时我遇到了问题:

 def _pickle_method(method):
     """Split a method into the pieces needed to rebuild it.

     Returns the reconstructor ``_unpickle_method`` together with the
     ``(function name, im_self, im_class)`` tuple it will be called with.
     """
     reduced = (method.im_func.__name__, method.im_self, method.im_class)
     return _unpickle_method, reduced

 def _unpickle_method(func_name, obj, cls):
     """Rebuild a method from the pieces returned by ``_pickle_method``.

     ``func_name`` is looked up in the ``__dict__`` of each class yielded by
     ``cls.mro()``; the first match is bound to ``obj`` through ``__get__``.
     """
     try:
         # The loop variable rebinds ``cls``, so after ``break`` it names the
         # class whose ``__dict__`` holds ``func_name``.
         for cls in cls.mro():
             try:
                 func = cls.__dict__[func_name]
             except KeyError:
                 # Not defined on this class; try the next one in the MRO.
                 pass
             else:
                 break
     except AttributeError:
         # Reached when an attribute access above fails, e.g. ``cls`` has no
         # ``mro`` (presumably an old-style class -- confirm); fall back to
         # the class's own ``__dict__``.
         func = cls.__dict__[func_name]
     # NOTE(review): if no class along the MRO defines ``func_name``, ``func``
     # is never assigned and this line raises.
     return func.__get__(obj, cls)

 # Register _pickle_method as the reduction function for MethodType objects.
 # NOTE(review): copy_reg, pickle and MethodType are not imported in this excerpt.
 copy_reg.pickle(MethodType, _pickle_method, _unpickle_method)
 # Round-trip the classmethod; this is the call that produces the traceback
 # quoted below.
 new_func = pickle.loads(pickle.dumps(Test.func))
 """
 Traceback (most recent call last):
 File "test3.py", line 45, in <module>
   new_func = pickle.loads(pickle.dumps(Test.func))
 File "/usr/lib64/python2.6/pickle.py", line 1366, in dumps
   Pickler(file, protocol).dump(obj)
 File "/usr/lib64/python2.6/pickle.py", line 224, in dump
   self.save(obj)
 File "/usr/lib64/python2.6/pickle.py", line 331, in save
   self.save_reduce(obj=obj, *rv)
 File "/usr/lib64/python2.6/pickle.py", line 401, in save_reduce
   save(args)
 File "/usr/lib64/python2.6/pickle.py", line 286, in save
   f(self, obj) # Call unbound method with explicit self
 File "/usr/lib64/python2.6/pickle.py", line 562, in save_tuple
   save(element)
 File "/usr/lib64/python2.6/pickle.py", line 286, in save
   f(self, obj) # Call unbound method with explicit self
 File "/usr/lib64/python2.6/pickle.py", line 748, in save_global
   (obj, module, name)) 
pickle.PicklingError: Can't pickle <type 'classobj'>: it's not found as __builtin__.classobj
"""

有没有办法改几行代码,让它也能适用于类方法呢?

3 个回答

0

与其直接从 _pickle_method 返回实际的类对象,不如返回一个字符串,这个字符串可以在反序列化时用来导入这个类。然后在 _unpickle_method 中执行这个导入。

4

下面这个解决方案现在也能正确处理类方法(classmethod)了。如果还有什么遗漏的地方,请告诉我。

def _pickle_method(method):
    """
    Pickle methods properly, including class methods.

    Meant to be registered as the reduction function for method objects.

    :param method: a method object exposing ``im_func``, ``im_self`` and
        ``im_class``.
    :returns: ``(_unpickle_method, (func_name, obj, cls))``. ``obj`` is
        ``None`` for class methods, and ``cls`` is the class the lookup
        starts from when the method is rebuilt.
    """
    import inspect  # local import keeps the function self-contained

    func_name = method.im_func.__name__
    obj = method.im_self
    cls = method.im_class
    if inspect.isclass(obj):
        # handle classmethods differently: the method is bound to a class,
        # and im_class is then the type of that class (e.g. classobj for an
        # old-style class), which pickle cannot always store by name.
        # Testing im_self rather than im_class matters because
        # isinstance(cls, type) is also true for the class of every
        # ordinary new-style instance method.
        cls = obj
        obj = None
    if func_name.startswith('__') and not func_name.endswith('__'):
        # deal with mangled names: the prefix comes from the class that
        # *defines* the method, which may be a base class of cls.
        for klass in inspect.getmro(cls):
            mangled = '_%s%s' % (klass.__name__.lstrip('_'), func_name)
            if mangled in klass.__dict__:
                func_name = mangled
                break
        else:
            # Nothing matched; fall back to mangling with cls itself.
            func_name = '_%s%s' % (cls.__name__.lstrip('_'), func_name)

    return _unpickle_method, (func_name, obj, cls)

def _unpickle_method(func_name, obj, cls):
    """
    Unpickle methods properly, including class methods.

    :param func_name: attribute name (already mangled if private).
    :param obj: instance to bind to, or ``None`` for class methods.
    :param cls: class whose hierarchy is searched for ``func_name``.
    :returns: the attribute found along the MRO of ``cls``, bound through
        its ``__get__``.
    :raises AttributeError: if no class along the MRO defines ``func_name``.
    """
    import inspect  # local import keeps the function self-contained

    # inspect.getmro also copes with classes that have no __mro__ attribute.
    for klass in inspect.getmro(cls):
        try:
            func = klass.__dict__[func_name]
        except KeyError:
            continue
        # Bind against the requested class rather than the defining one, so
        # an inherited classmethod is rebound to the subclass it came from.
        return func.__get__(obj, cls)
    raise AttributeError(
        '%r has no method named %r' % (cls, func_name))
4

我修改了这个方法,让它也能用于类方法(classmethod)。代码如下。

import copy_reg
import types

def _pickle_method(method):
    """Reduce a method object to picklable pieces, classmethods included.

    Meant to be registered as the reduction function for method objects.

    :param method: a method object exposing ``im_func``, ``im_self`` and
        ``im_class``.
    :returns: ``(_unpickle_method, (func_name, obj, cls))``. For a
        classmethod both ``obj`` and ``cls`` are the class the method is
        bound to.
    """
    import inspect  # local import keeps the function self-contained

    func_name = method.im_func.__name__
    obj = method.im_self
    cls = method.im_class
    # A method bound to a class may live on that class (classmethod) or on
    # its type, so search the class's own hierarchy first.
    hierarchy = inspect.getmro(cls)
    if inspect.isclass(obj):
        hierarchy = inspect.getmro(obj) + hierarchy
    if func_name.startswith('__') and not func_name.endswith('__'):
        # deal with mangled names: the prefix comes from the class that
        # *defines* the method, not necessarily from im_class.
        for klass in hierarchy:
            mangled = '_%s%s' % (klass.__name__.lstrip('_'), func_name)
            if mangled in klass.__dict__:
                func_name = mangled
                break
        else:
            # Nothing matched; fall back to mangling with im_class itself.
            func_name = '_%s%s' % (cls.__name__.lstrip('_'), func_name)
    if inspect.isclass(obj) and any(
            func_name in klass.__dict__ for klass in inspect.getmro(obj)):
        # classmethod: im_class is the type of the class (classobj for an
        # old-style class), which pickle cannot always store by name. Ship
        # the class itself instead.
        cls = obj
    return _unpickle_method, (func_name, obj, cls)

def _unpickle_method(func_name, obj, cls):
    """Rebuild a method from the pieces produced by ``_pickle_method``.

    :param func_name: attribute name (already mangled if private).
    :param obj: bound instance, or the class itself for a classmethod.
    :param cls: class recorded at pickling time.
    :returns: the attribute found along the relevant MRO, bound through
        its ``__get__``.
    :raises AttributeError: if ``func_name`` cannot be found.
    """
    import inspect  # local import keeps the function self-contained

    if inspect.isclass(obj):
        # if func_name is a classmethod it is defined somewhere along the
        # hierarchy of obj; rebind it to obj so inherited classmethods keep
        # their subclass. Testing isclass() rather than obj.__dict__ also
        # works for instances that have no __dict__.
        for klass in inspect.getmro(obj):
            if func_name in klass.__dict__:
                return klass.__dict__[func_name].__get__(None, obj)
    # Ordinary methods, and methods defined on the type of a class.
    # inspect.getmro also copes with classes that have no __mro__ attribute.
    for klass in inspect.getmro(cls):
        if func_name in klass.__dict__:
            return klass.__dict__[func_name].__get__(obj, cls)
    raise AttributeError(
        '%r has no method named %r' % (cls, func_name))

# Register _pickle_method as the reduction function for types.MethodType
# objects; _unpickle_method is passed as copy_reg.pickle's optional
# constructor argument.
copy_reg.pickle(types.MethodType, _pickle_method, _unpickle_method)

撰写回答