如何正确(或最佳)地子类化Python集合类,并添加新的实例变量?
我正在实现一个几乎和集合(set)一样的对象,但需要多一个实例变量,所以我打算在内置的集合对象基础上进行扩展。请问有什么好的方法可以确保在复制我的对象时,这个变量的值也能被复制?
在使用旧的集合模块时,以下代码运行得很好:
import sets
class Fooset(sets.Set):
def __init__(self, s = []):
sets.Set.__init__(self, s)
if isinstance(s, Fooset):
self.foo = s.foo
else:
self.foo = 'default'
f = Fooset([1,2,4])
f.foo = 'bar'
assert( (f | f).foo == 'bar')
但在使用内置的集合模块时,这个方法就不行了。
我能想到的唯一解决办法就是重写每一个返回复制集合对象的方法……这样的话,我还不如不去扩展集合对象。难道没有标准的方法可以做到这一点吗?
(为了更清楚,以下代码是不可行的(断言失败):
class Fooset(set):
def __init__(self, s = []):
set.__init__(self, s)
if isinstance(s, Fooset):
self.foo = s.foo
else:
self.foo = 'default'
f = Fooset([1,2,4])
f.foo = 'bar'
assert( (f | f).foo == 'bar')
)
9 个回答
可惜的是,集合(set)并不遵循规则,创建新的集合对象时并不会调用 __new__
方法,尽管它们保持了类型。这显然是 Python 的一个bug(问题 #1721812,在2.x版本中不会修复)。你不应该在没有调用创建该类型对象的 type
的情况下,得到类型为 X 的对象!如果 set.__or__
不会调用 __new__
,那么它就应该返回 set
对象,而不是子类对象。
不过,实际上,参考上面 nosklo 的帖子,你最初的行为是没有意义的。Set.__or__
操作符不应该重复使用任何源对象来构造结果,它应该创建一个新的对象,这样它的 foo
应该是 "default"
!
所以,实际上,任何这样做的人都应该重载这些操作符,以便知道使用的是哪个 foo
的副本。如果它不依赖于被组合的 Foosets,你可以将其设置为类的默认值,这样它就会被尊重,因为新对象认为自己是子类类型。
我的意思是,如果你这样做,你的例子就能工作,算是可以:
class Fooset(set):
foo = 'default'
def __init__(self, s = []):
if isinstance(s, Fooset):
self.foo = s.foo
f = Fooset([1,2,5])
assert (f|f).foo == 'default'
我觉得推荐的做法是,不要直接从内置的 set
类继承,而是应该使用 抽象基类 Set
,这个类在 collections.abc 模块里可以找到。
使用 ABC Set 可以让你免费获得一些方法,这样你只需要定义 __contains__()
、__len__()
和 __iter__()
就能创建一个简单的 Set 类。如果你想要一些更好用的集合方法,比如 intersection()
(交集)和 difference()
(差集),那么你可能需要自己去封装这些方法。
这是我尝试写的代码(这个例子是类似于 frozenset 的,但你也可以从 MutableSet
继承,来得到一个可变的版本):
from collections.abc import Set, Hashable
class CustomSet(Set, Hashable):
"""An example of a custom frozenset-like object using
Abstract Base Classes.
"""
__hash__ = Set._hash
wrapped_methods = ('difference',
'intersection',
'symetric_difference',
'union',
'copy')
def __repr__(self):
return "CustomSet({0})".format(list(self._set))
def __new__(cls, iterable=None):
selfobj = super(CustomSet, cls).__new__(CustomSet)
selfobj._set = frozenset() if iterable is None else frozenset(iterable)
for method_name in cls.wrapped_methods:
setattr(selfobj, method_name, cls._wrap_method(method_name, selfobj))
return selfobj
@classmethod
def _wrap_method(cls, method_name, obj):
def method(*args, **kwargs):
result = getattr(obj._set, method_name)(*args, **kwargs)
return CustomSet(result)
return method
def __getattr__(self, attr):
"""Make sure that we get things like issuperset() that aren't provided
by the mix-in, but don't need to return a new set."""
return getattr(self._set, attr)
def __contains__(self, item):
return item in self._set
def __len__(self):
return len(self._set)
def __iter__(self):
return iter(self._set)
我最喜欢的方式是给内置集合的方法加上一层包装:
class Fooset(set):
def __init__(self, s=(), foo=None):
super(Fooset,self).__init__(s)
if foo is None and hasattr(s, 'foo'):
foo = s.foo
self.foo = foo
@classmethod
def _wrap_methods(cls, names):
def wrap_method_closure(name):
def inner(self, *args):
result = getattr(super(cls, self), name)(*args)
if isinstance(result, set) and not hasattr(result, 'foo'):
result = cls(result, foo=self.foo)
return result
inner.fn_name = name
setattr(cls, name, inner)
for name in names:
wrap_method_closure(name)
Fooset._wrap_methods(['__ror__', 'difference_update', '__isub__',
'symmetric_difference', '__rsub__', '__and__', '__rand__', 'intersection',
'difference', '__iand__', 'union', '__ixor__',
'symmetric_difference_update', '__or__', 'copy', '__rxor__',
'intersection_update', '__xor__', '__ior__', '__sub__',
])
基本上和你自己答案里做的差不多,但代码行数更少。如果你想对列表和字典做同样的事情,也很容易加一个元类。