Python方法能否检查自己是否被内部调用过?

11 投票
3 回答
6394 浏览
提问于 2025-04-17 05:01

假设我有一个Python函数叫做 f 和另一个叫做 fhelp。这个 fhelp 函数是用来自己调用自己的,也就是递归调用。而 f 函数不应该被递归调用。有没有办法让 f 知道自己是否被递归调用了呢?

3 个回答

0

我对jro的回答进行了改进,让它可以处理对象,同时确保它在多线程环境下是安全的:

import threading

class NoRecurse:
    def __init__(self):
        self.lock = threading.RLock()
        self.seen = set()

    def __call__(no_recurse, f):

        def func(self, *args, **kwargs):
            with no_recurse.lock:
                if len([l[2] for l in traceback.extract_stack() if l[2] == f.__name__]) > 0 and self in no_recurse.seen:
                    raise Exception('Recursed')

            with no_recurse.lock:
                no_recurse.seen.add(self)

            r = f(self, *args, **kwargs)

            with no_recurse.lock:
                no_recurse.seen.remove(self)

            return r

        return func
3

你可以使用一个由装饰器设置的标志:

def norecurse(func):
    func.called = False
    def f(*args, **kwargs):
        if func.called:
            print "Recursion!"
            # func.called = False # if you are going to continue execution
            raise Exception
        func.called = True
        result = func(*args, **kwargs)
        func.called = False
        return result
    return f

然后你可以这样做:

@norecurse
def f(some, arg, s):
    do_stuff()

如果在函数运行的时候再次调用了f,那么called会变成True,这时就会抛出一个异常。

16

要实现这个功能,可以使用 traceback 模块:

>>> import traceback
>>> def f(depth=0):
...     print depth, traceback.print_stack()
...     if depth < 2:
...         f(depth + 1)
...
>>> f()
0  File "<stdin>", line 1, in <module>
  File "<stdin>", line 2, in f
 None
1  File "<stdin>", line 1, in <module>
  File "<stdin>", line 4, in f
  File "<stdin>", line 2, in f
 None
2  File "<stdin>", line 1, in <module>
  File "<stdin>", line 4, in f
  File "<stdin>", line 4, in f
  File "<stdin>", line 2, in f
 None

如果调用栈中的任何一项显示代码是从 f 函数调用的,那么这个调用就是(直接或间接)递归的。traceback.extract_stack 方法可以让你轻松获取这些信息。下面示例中的 if len(l[2] ... 语句只是用来计算函数名称完全匹配的次数。为了让这个功能更好看(感谢 agf 提出的这个想法),你可以把它做成一个装饰器:

>>> def norecurse(f):
...     def func(*args, **kwargs):
...         if len([l[2] for l in traceback.extract_stack() if l[2] == f.__name__]) > 0:
...             raise Exception('Recursed')
...         return f(*args, **kwargs)
...     return func
...
>>> @norecurse
... def foo(depth=0):
...     print depth
...     foo(depth + 1)
...
>>> foo()
0
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 5, in func
  File "<stdin>", line 4, in foo
  File "<stdin>", line 5, in func
Exception: Recursed

撰写回答