Python:从子类调用父类方法时超出最大递归深度

3 投票
1 回答
2232 浏览
提问于 2025-04-16 20:44

我有一个继承关系的三个类:Baz 继承自 Bar,Bar 又继承自 Foo。我想在 Foo 类里定义一个方法,当在子类中调用时,能够返回父类的输出,并加上子类自己的内容。这是我想要的输出:

>>> foo = Foo()
>>> bar = Bar()
>>> baz = Baz()
>>> print foo.get_defining_fields()
['in Foo']
>>> print bar.get_defining_fields()
['in Foo', 'in Bar']
>>> print baz.get_defining_fields()
['in Foo', 'in Bar', 'in Baz']

问题是,我对在子类调用的父类方法中使用 super() 的理解有些错误,或者是其他关于子类的一些细节。我这部分代码运行得很好:

>>> print foo.get_defining_fields()
['in Foo']

但是 bar.get_defining_fields() 产生了一个无限循环,它一直在自己运行,直到抛出一个 RuntimeError,而不是像我想要的那样调用 foo.get_defining_fields 并停止。

这是代码。

class Foo(object):
    def _defining_fields(self):
        return ['in Foo']

    def get_defining_fields(self):
        if self.__class__ == Foo:
            # Top of the chain, don't call super()
            return self._defining_fields()
        return super(self.__class__, self).get_defining_fields() + self._defining_fields()

class Bar(Foo):
    def _defining_fields(self):
        return ['in Bar']

class Baz(Bar):
    def _defining_fields(self):
        return ['in Baz']

所以 get_defining_fields 是在父类中定义的,而其中的 super() 调用使用了 self.__class__,试图将正确的子类名称传递给每个子类中的调用。当在 Bar 中调用时,它解析为 super(Bar, self).get_defining_fields(),这样从 foo.get_defining_fields() 返回的列表就会被加到 far.get_defining_fields() 返回的列表前面。

这可能是个简单的错误,如果你对 Python 的继承机制和 super() 的内部工作原理理解得很透彻,但显然我并没有,所以如果有人能告诉我正确的做法,我会很感激。


编辑:根据 Daniel Roseman 的回答,我尝试将 super() 的调用替换为这个形式:return super(Foo, self).get_defining_fields() + self._defining_fields()

现在无限递归不再发生,但在调用 bar.get_defining_fields() 时我得到了一个不同的错误:

AttributeError: 'super' object has no attribute 'get_defining_fields'

还有其他问题。


编辑:是的,我终于明白我遗漏了什么。把 Daniel 更新的回答标记为接受的答案。

1 个回答

8

问题出在这里:

 return super(self.__class__, self)...

self.__class__ 总是指向 当前 的具体类。所以在 Bar 类中,它指的是 Bar,而在 Baz 类中,它指的是 Baz。因此,Bar 调用了父类的方法,但 self.__class__ 仍然指向 Bar,而不是 Foo。这就导致了无限递归。

这就是为什么你必须在使用 super 时,明确 指定类名。像这样:

return super(Foo, self)...

编辑 对,因为只有顶层类定义了 get_defining_fields

你这样做其实是错的。这根本不是 super 的工作。我觉得你可以尝试通过 self.__class__.__mro__ 来获取更好的结果,这样可以访问父类的方法解析顺序:

class Foo(object):
    def _defining_fields(self):
        return ['in Foo']

    def get_defining_fields(self):
        fields = []
        for cls in self.__class__.__mro__:
            if hasattr(cls, '_defining_fields'):
                fields.append(cls._defining_fields(self))
        return fields

¬

撰写回答