如何正确(或最佳)地子类化Python集合类,并添加新的实例变量?

18 投票
9 回答
16619 浏览
提问于 2025-04-15 11:18

我正在实现一个几乎和集合(set)一样的对象,但需要多一个实例变量,所以我打算在内置的集合对象基础上进行扩展。请问有什么好的方法可以确保在复制我的对象时,这个变量的值也能被复制?

在使用旧的集合模块时,以下代码运行得很好:

import sets
class Fooset(sets.Set):
    def __init__(self, s = []):
        sets.Set.__init__(self, s)
        if isinstance(s, Fooset):
            self.foo = s.foo
        else:
            self.foo = 'default'
f = Fooset([1,2,4])
f.foo = 'bar'
assert( (f | f).foo == 'bar')

但在使用内置的集合模块时,这个方法就不行了。

我能想到的唯一解决办法就是重写每一个返回复制集合对象的方法……这样的话,我还不如不去扩展集合对象。难道没有标准的方法可以做到这一点吗?

(为了更清楚,以下代码是可行的(断言失败):

class Fooset(set):
    def __init__(self, s = []):
        set.__init__(self, s)
        if isinstance(s, Fooset):
            self.foo = s.foo
        else:
            self.foo = 'default'

f = Fooset([1,2,4])
f.foo = 'bar'
assert( (f | f).foo == 'bar')

9 个回答

4

可惜的是,集合(set)并不遵循规则,创建新的集合对象时并不会调用 __new__ 方法,尽管它们保持了类型。这显然是 Python 的一个bug(问题 #1721812,在2.x版本中不会修复)。你不应该在没有调用创建该类型对象的 type 的情况下,得到类型为 X 的对象!如果 set.__or__ 不会调用 __new__,那么它就应该返回 set 对象,而不是子类对象。

不过,实际上,参考上面 nosklo 的帖子,你最初的行为是没有意义的。Set.__or__ 操作符不应该重复使用任何源对象来构造结果,它应该创建一个新的对象,这样它的 foo 应该是 "default"

所以,实际上,任何这样做的人都应该重载这些操作符,以便知道使用的是哪个 foo 的副本。如果它不依赖于被组合的 Foosets,你可以将其设置为类的默认值,这样它就会被尊重,因为新对象认为自己是子类类型。

我的意思是,如果你这样做,你的例子就能工作,算是可以:

class Fooset(set):
  foo = 'default'
  def __init__(self, s = []):
    if isinstance(s, Fooset):
      self.foo = s.foo

f = Fooset([1,2,5])
assert (f|f).foo == 'default'
11

我觉得推荐的做法是,不要直接从内置的 set 类继承,而是应该使用 抽象基类 Set,这个类在 collections.abc 模块里可以找到。

使用 ABC Set 可以让你免费获得一些方法,这样你只需要定义 __contains__()__len__()__iter__() 就能创建一个简单的 Set 类。如果你想要一些更好用的集合方法,比如 intersection()(交集)和 difference()(差集),那么你可能需要自己去封装这些方法。

这是我尝试写的代码(这个例子是类似于 frozenset 的,但你也可以从 MutableSet 继承,来得到一个可变的版本):

from collections.abc import Set, Hashable

class CustomSet(Set, Hashable):
    """An example of a custom frozenset-like object using
    Abstract Base Classes.
    """
    __hash__ = Set._hash

    wrapped_methods = ('difference',
                       'intersection',
                       'symetric_difference',
                       'union',
                       'copy')

    def __repr__(self):
        return "CustomSet({0})".format(list(self._set))

    def __new__(cls, iterable=None):
        selfobj = super(CustomSet, cls).__new__(CustomSet)
        selfobj._set = frozenset() if iterable is None else frozenset(iterable)
        for method_name in cls.wrapped_methods:
            setattr(selfobj, method_name, cls._wrap_method(method_name, selfobj))
        return selfobj

    @classmethod
    def _wrap_method(cls, method_name, obj):
        def method(*args, **kwargs):
            result = getattr(obj._set, method_name)(*args, **kwargs)
            return CustomSet(result)
        return method

    def __getattr__(self, attr):
        """Make sure that we get things like issuperset() that aren't provided
        by the mix-in, but don't need to return a new set."""
        return getattr(self._set, attr)

    def __contains__(self, item):
        return item in self._set

    def __len__(self):
        return len(self._set)

    def __iter__(self):
        return iter(self._set)
21

我最喜欢的方式是给内置集合的方法加上一层包装:

class Fooset(set):
    def __init__(self, s=(), foo=None):
        super(Fooset,self).__init__(s)
        if foo is None and hasattr(s, 'foo'):
            foo = s.foo
        self.foo = foo



    @classmethod
    def _wrap_methods(cls, names):
        def wrap_method_closure(name):
            def inner(self, *args):
                result = getattr(super(cls, self), name)(*args)
                if isinstance(result, set) and not hasattr(result, 'foo'):
                    result = cls(result, foo=self.foo)
                return result
            inner.fn_name = name
            setattr(cls, name, inner)
        for name in names:
            wrap_method_closure(name)

Fooset._wrap_methods(['__ror__', 'difference_update', '__isub__', 
    'symmetric_difference', '__rsub__', '__and__', '__rand__', 'intersection',
    'difference', '__iand__', 'union', '__ixor__', 
    'symmetric_difference_update', '__or__', 'copy', '__rxor__',
    'intersection_update', '__xor__', '__ior__', '__sub__',
])

基本上和你自己答案里做的差不多,但代码行数更少。如果你想对列表和字典做同样的事情,也很容易加一个元类。

撰写回答