Python:从子类调用父类方法时超出最大递归深度
我有一个继承关系的三个类:Baz 继承自 Bar,Bar 又继承自 Foo。我想在 Foo 类里定义一个方法,当在子类中调用时,能够返回父类的输出,并加上子类自己的内容。这是我想要的输出:
>>> foo = Foo()
>>> bar = Bar()
>>> baz = Baz()
>>> print foo.get_defining_fields()
['in Foo']
>>> print bar.get_defining_fields()
['in Foo', 'in Bar']
>>> print baz.get_defining_fields()
['in Foo', 'in Bar', 'in Baz']
问题是,我对在子类调用的父类方法中使用 super()
的理解有些错误,或者是其他关于子类的一些细节。我这部分代码运行得很好:
>>> print foo.get_defining_fields()
['in Foo']
但是 bar.get_defining_fields()
产生了一个无限循环,它一直在自己运行,直到抛出一个 RuntimeError
,而不是像我想要的那样调用 foo.get_defining_fields
并停止。
这是代码。
class Foo(object):
def _defining_fields(self):
return ['in Foo']
def get_defining_fields(self):
if self.__class__ == Foo:
# Top of the chain, don't call super()
return self._defining_fields()
return super(self.__class__, self).get_defining_fields() + self._defining_fields()
class Bar(Foo):
def _defining_fields(self):
return ['in Bar']
class Baz(Bar):
def _defining_fields(self):
return ['in Baz']
所以 get_defining_fields
是在父类中定义的,而其中的 super()
调用使用了 self.__class__
,试图将正确的子类名称传递给每个子类中的调用。当在 Bar 中调用时,它解析为 super(Bar, self).get_defining_fields()
,这样从 foo.get_defining_fields()
返回的列表就会被加到 far.get_defining_fields()
返回的列表前面。
这可能是个简单的错误,如果你对 Python 的继承机制和 super()
的内部工作原理理解得很透彻,但显然我并没有,所以如果有人能告诉我正确的做法,我会很感激。
编辑:根据 Daniel Roseman 的回答,我尝试将 super()
的调用替换为这个形式:return super(Foo, self).get_defining_fields() + self._defining_fields()
现在无限递归不再发生,但在调用 bar.get_defining_fields()
时我得到了一个不同的错误:
AttributeError: 'super' object has no attribute 'get_defining_fields'
还有其他问题。
编辑:是的,我终于明白我遗漏了什么。把 Daniel 更新的回答标记为接受的答案。
1 个回答
问题出在这里:
return super(self.__class__, self)...
self.__class__
总是指向 当前 的具体类。所以在 Bar 类中,它指的是 Bar,而在 Baz 类中,它指的是 Baz。因此,Bar 调用了父类的方法,但 self.__class__
仍然指向 Bar,而不是 Foo。这就导致了无限递归。
这就是为什么你必须在使用 super 时,明确 指定类名。像这样:
return super(Foo, self)...
编辑 对,因为只有顶层类定义了 get_defining_fields
。
你这样做其实是错的。这根本不是 super
的工作。我觉得你可以尝试通过 self.__class__.__mro__
来获取更好的结果,这样可以访问父类的方法解析顺序:
class Foo(object):
def _defining_fields(self):
return ['in Foo']
def get_defining_fields(self):
fields = []
for cls in self.__class__.__mro__:
if hasattr(cls, '_defining_fields'):
fields.append(cls._defining_fields(self))
return fields
¬