在Python中动态装饰递归函数

1 投票
3 回答
61 浏览
提问于 2025-04-12 13:59

我遇到了一个情况,需要在Python中动态地给一个函数的递归调用添加装饰。关键要求是要做到这一点,而不修改当前作用域中的函数。让我来解释一下这个情况以及我目前尝试过的。

假设我有一个函数 traverse_tree,它可以递归地遍历一个二叉树,并返回树中的值。现在,我想给这个函数里的递归调用添加一些额外的信息,比如递归的深度。当我直接用装饰器装饰这个函数时,它的效果是正常的。但是,我想要以动态的方式实现这个,而不修改当前作用域中的函数。


import functools


class Node:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def generate_tree():
    root = Node(1)
    root.left = Node(2)
    root.right = Node(3)
    root.left.left = Node(4)
    root.left.right = Node(5)
    root.right.left = Node(6)
    root.right.right = Node(7)
    return root


def with_recursion_depth(func):
    """Yield recursion depth alongside original values of an iterator."""
    
    class Depth(int): pass
    depth = Depth(-1)

    def depth_in_value(value, depth) -> bool:
        return isinstance(value, tuple) and len(value) == 2 and value[-1] is depth

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal depth
        depth = Depth(depth + 1)
        for value in func(*args, **kwargs):
            if depth_in_value(value, depth):
                yield value
            else:
                yield value, depth
        depth = Depth(depth - 1)

    return wrapper

# 1. using @-syntax
@with_recursion_depth
def traverse_tree(node):
    """Recursively yield values of the binary tree."""
    yield node.value
    if node.left:
        yield from traverse_tree(node.left)
    if node.right:
        yield from traverse_tree(node.right)


root = generate_tree()
for item in traverse_tree(root):
    print(item)

# Output:
# (1, 0)
# (2, 1)
# (4, 2)
# (5, 2)
# (3, 1)
# (6, 2)
# (7, 2)


# 2. Dynamically:  
def traverse_tree(node):
    """Recursively yield values of the binary tree."""
    yield node.value
    if node.left:
        yield from traverse_tree(node.left)
    if node.right:
        yield from traverse_tree(node.right)


root = generate_tree()
for item in with_recursion_depth(traverse_tree)(root):
    print(item)

# Output:
# (1, 0)
# (2, 0)
# (4, 0)
# (5, 0)
# (3, 0)
# (6, 0)
# (7, 0)

看起来问题出在如何给函数内部的递归调用添加装饰。当我动态使用装饰器时,它只装饰了外部函数的调用,而没有装饰函数内部的递归调用。我可以通过重新赋值来实现这个功能(traverse_tree = with_recursion_depth(traverse_tree)),但这样一来,函数在当前作用域中就被修改了。我希望能够动态地实现这个,这样我就可以选择使用未装饰的函数,或者选择性地包装它以获取递归深度的信息。

我希望保持事情简单,尽量避免像字节码操作这样的复杂技术,如果有其他解决方案的话。不过,如果这是必须的路径,我也愿意去探索。我在这方面尝试过,但还没有成功。

import ast


def modify_recursive_calls(func, decorator):

    def decorate_recursive_calls(node):
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == func.__name__:
            func_name = ast.copy_location(ast.Name(id=node.func.id, ctx=ast.Load()), node.func)
            decorated_func = ast.Call(
                func=ast.Name(id=decorator.__name__, ctx=ast.Load()),
                args=[func_name],
                keywords=[],
            )
            node.func = decorated_func
        for field, value in ast.iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        decorate_recursive_calls(item)
            elif isinstance(value, ast.AST):
                decorate_recursive_calls(value)

    tree = ast.parse(inspect.getsource(func))
    decorate_recursive_calls(tree)
    ast.fix_missing_locations(tree)
    modified_code = compile(tree, filename="<ast>", mode="exec")
    modified_function = types.FunctionType(modified_code.co_consts[1], func.__globals__)
    return modified_function

3 个回答

-1

这里有个小技巧,你可以在一个上下文管理器中修改函数的全局变量:

import contextlib

@contextlib.contextmanager 
def dynamic_with_recursion_depth(f):
    f_name = f.__name__
    f_globals = f.__globals__
    f_globals[f_name] = with_recursion_depth(f)
    try:
        yield
    finally:
        f_globals[f_name] = f

然后可以这样做:

root = generate_tree()
with dynamic_with_recursion_depth(traverse_tree):
    for item in traverse_tree(root):
        print(item)

当然,这种方法有点脆弱,因为如果 traverse_tree 是在另一个模块中定义的,你就得确保从那个模块调用它,比如 import mymodule; mymodule.traverse_tree(root),而 from mymodule import traverse_tree 是不行的。如果这个函数不是在全局范围内定义的,那也会出现问题。

0

为了实现你想要的功能,你的 traverse_tree 方法需要知道它是怎么被调用的,而 这并没有简单的方法。你可以尝试这样做...

def traverse_tree(node, func=None):
    """Recursively yield values of the binary tree."""
    func = func if func else traverse_tree
    yield node.value
    if node.left:
        yield from func(node.left, func=func)
    if node.right:
        yield from func(node.right, func=func)


root = generate_tree()
_traverse = with_recursion_depth(traverse_tree)
for item in _traverse(root, func=_traverse):
    print(item)

...不过这有点像是变通的办法。

1

这是一个可以使用上下文变量的例子:“contextvars”是最近加入到语言中的一个功能,目的是让在任务中运行的异步代码可以传递一些“额外”的信息给嵌套调用,这样这些信息就不会被其他任务调用同样的函数时覆盖或混淆。你可以查看这个链接了解更多:https://docs.python.org/3/library/contextvars.html

它们有点像threading.locals,但是接口使用起来比较麻烦——你最里面的调用可以从一个上下文变量中获取要调用的函数,而不是从全局范围中获取。因此,如果最外层的调用把这个上下文变量设置为装饰过的函数,那么只有在这个“下降”过程中调用的函数会受到影响。

上下文变量的功能很强大。不过,它的接口使用起来非常糟糕,因为你必须先创建一个上下文的副本,然后用这个副本来调用函数,只有这个被调用的函数可以改变上下文变量。而在这个调用之外的代码,总是能看到未改变的值,这在多线程、异步任务等情况下都是一致且可靠的。

用一个更简单的“mymul”递归函数和一个“logger”装饰器,代码可以像这样:

import contextvars

recursive_func = contextvars.ContextVar("recursive_func")


def logger(func):
    def wrapper(*args, **kwargs):
        print(f"{func.__name__} called with {args}  and {kwargs}")
        return func(*args, **kwargs)
    return wrapper


def mymul(a, b):
    if b == 0:
        return 0
    return a + recursive_func.get()(a, b - 1)

recursive_func.set(mymul)

# non logging call:
mymul(3, 4)

# logging call - there must be an "entry point" which
# can change the contextvar from within the new context.
def logging_call_maker(*args, **kwargs):
    recursive_func.set(logger(mymul))
    return mymul(*args, **kwargs)


contextvars.copy_context().run(logging_call_maker, 3, 4)


# Another non logging call:
mymul(3, 4)

这里的关键点是:递归函数是从ContextVar中获取要调用的函数,而不是使用它在全局范围中的名字。

如果你喜欢这种方法,但觉得ContextVar的使用太繁琐,可以让它更简单:只需在这里留言。我几年前开始了一个项目,目的是把这种行为封装成更友好的代码,但因为没有足够的使用案例,我就停止了这个项目。

撰写回答