如何重写Python对象的copy/deepcopy操作?

143 投票
10 回答
111143 浏览
提问于 2025-04-15 14:42

我明白在复制模块中,copydeepcopy 的区别。我之前成功使用过 copy.copycopy.deepcopy,但这是我第一次真正去重载 __copy____deepcopy__ 方法。我已经在网上搜索过,也查看了内置的 Python 模块,想找找 __copy____deepcopy__ 函数的例子(比如 sets.pydecimal.pyfractions.py),但我还是不太确定自己理解得对不对。

这是我的场景:

我有一个配置对象。最开始,我会创建一个配置对象,并给它一组默认值。这个配置会被传递给多个其他对象(这样可以确保所有对象都从相同的配置开始)。但是,一旦用户开始交互,每个对象都需要独立调整自己的配置,而不影响其他对象的配置(这就意味着我需要对初始配置进行深复制,以便传递给其他对象)。

这是一个示例对象:

class ChartConfig(object):

    def __init__(self):

        #Drawing properties (Booleans/strings)
        self.antialiased = None
        self.plot_style = None
        self.plot_title = None
        self.autoscale = None

        #X axis properties (strings/ints)
        self.xaxis_title = None
        self.xaxis_tick_rotation = None
        self.xaxis_tick_align = None

        #Y axis properties (strings/ints)
        self.yaxis_title = None
        self.yaxis_tick_rotation = None
        self.yaxis_tick_align = None

        #A list of non-primitive objects
        self.trace_configs = []

    def __copy__(self):
        pass

    def __deepcopy__(self, memo):
        pass 

那么,如何正确实现这个对象的 copydeepcopy 方法,以确保 copy.copycopy.deepcopy 能给我正确的行为呢?

10 个回答

21

根据Peter的精彩回答,如果你想实现一个自定义的深拷贝,而且只想对默认的实现做最小的修改(比如只改一个字段),可以参考以下代码:

class Foo(object):
    def __deepcopy__(self, memo):
        deepcopy_method = self.__deepcopy__
        self.__deepcopy__ = None
        cp = deepcopy(self, memo)
        self.__deepcopy__ = deepcopy_method
        cp.__deepcopy__ = deepcopy_method

        # custom treatments
        # for instance: cp.id = None

        return cp

补充说明:这种方法有一个局限性,正如Igor Kozyrenko所指出的,拷贝对象的__deepcopy__方法仍然会和原始对象绑定在一起,所以如果你对一个拷贝再进行拷贝,结果其实还是原始对象的拷贝。或许有办法可以重新绑定__deepcopy__cp,而不是仅仅用cp.__deepcopy__ = deepcopy_method来赋值。

139

结合Alex Martelli的回答和Rob Young的评论,你可以得到以下代码:

from copy import copy, deepcopy

class A(object):
    def __init__(self):
        print 'init'
        self.v = 10
        self.z = [2,3,4]

    def __copy__(self):
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        return result

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, deepcopy(v, memo))
        return result

a = A()
a.v = 11
b1, b2 = copy(a), deepcopy(a)
a.v = 12
a.z.append(5)
print b1.v, b1.z
print b2.v, b2.z

运行后会输出

init
11 [2, 3, 4, 5]
11 [2, 3, 4]

这里的 __deepcopy__ 方法会填充 memo 字典,以避免在对象的某个成员引用了对象本身的情况下进行多余的复制。

94

关于自定义的建议在文档页面的最后部分:

类可以使用与控制复制相同的接口来控制序列化(即将对象转换为可存储或传输的格式)。想了解这些方法,可以查看pickle模块的描述。复制模块并不使用copy_reg注册模块。

为了让一个类定义自己的复制实现,它可以定义特殊的方法__copy__()__deepcopy__()。前者用于实现浅复制操作;调用时不需要额外的参数。后者用于实现深复制操作;调用时会传入一个参数,即备忘录字典。如果__deepcopy__()的实现需要对某个组件进行深复制,它应该调用deepcopy()函数,第一个参数是组件,第二个参数是备忘录字典。

因为你似乎不关心序列化的自定义,所以定义__copy____deepcopy__绝对是适合你的方法。

具体来说,__copy__(浅复制)在你的情况下相对简单...:

def __copy__(self):
  newone = type(self)()
  newone.__dict__.update(self.__dict__)
  return newone

__deepcopy__也类似(同样接受一个memo参数),但在返回之前,它需要对任何需要深复制的属性self.foo调用self.foo = deepcopy(self.foo, memo)(基本上是那些容器类型的属性,比如列表、字典,以及通过它们的__dict__持有其他内容的非基本对象)。

撰写回答