在Python中动态装饰递归函数
我遇到了一个情况,需要在Python中动态地给一个函数的递归调用添加装饰。关键要求是要做到这一点,而不修改当前作用域中的函数。让我来解释一下这个情况以及我目前尝试过的。
假设我有一个函数 traverse_tree
,它可以递归地遍历一个二叉树,并返回树中的值。现在,我想给这个函数里的递归调用添加一些额外的信息,比如递归的深度。当我直接用装饰器装饰这个函数时,它的效果是正常的。但是,我想要以动态的方式实现这个,而不修改当前作用域中的函数。
import functools
class Node:
def __init__(self, value, left=None, right=None):
self.value = value
self.left = left
self.right = right
def generate_tree():
root = Node(1)
root.left = Node(2)
root.right = Node(3)
root.left.left = Node(4)
root.left.right = Node(5)
root.right.left = Node(6)
root.right.right = Node(7)
return root
def with_recursion_depth(func):
"""Yield recursion depth alongside original values of an iterator."""
class Depth(int): pass
depth = Depth(-1)
def depth_in_value(value, depth) -> bool:
return isinstance(value, tuple) and len(value) == 2 and value[-1] is depth
@functools.wraps(func)
def wrapper(*args, **kwargs):
nonlocal depth
depth = Depth(depth + 1)
for value in func(*args, **kwargs):
if depth_in_value(value, depth):
yield value
else:
yield value, depth
depth = Depth(depth - 1)
return wrapper
# 1. using @-syntax
@with_recursion_depth
def traverse_tree(node):
"""Recursively yield values of the binary tree."""
yield node.value
if node.left:
yield from traverse_tree(node.left)
if node.right:
yield from traverse_tree(node.right)
root = generate_tree()
for item in traverse_tree(root):
print(item)
# Output:
# (1, 0)
# (2, 1)
# (4, 2)
# (5, 2)
# (3, 1)
# (6, 2)
# (7, 2)
# 2. Dynamically:
def traverse_tree(node):
"""Recursively yield values of the binary tree."""
yield node.value
if node.left:
yield from traverse_tree(node.left)
if node.right:
yield from traverse_tree(node.right)
root = generate_tree()
for item in with_recursion_depth(traverse_tree)(root):
print(item)
# Output:
# (1, 0)
# (2, 0)
# (4, 0)
# (5, 0)
# (3, 0)
# (6, 0)
# (7, 0)
看起来问题出在如何给函数内部的递归调用添加装饰。当我动态使用装饰器时,它只装饰了外部函数的调用,而没有装饰函数内部的递归调用。我可以通过重新赋值来实现这个功能(traverse_tree = with_recursion_depth(traverse_tree)
),但这样一来,函数在当前作用域中就被修改了。我希望能够动态地实现这个,这样我就可以选择使用未装饰的函数,或者选择性地包装它以获取递归深度的信息。
我希望保持事情简单,尽量避免像字节码操作这样的复杂技术,如果有其他解决方案的话。不过,如果这是必须的路径,我也愿意去探索。我在这方面尝试过,但还没有成功。
import ast
def modify_recursive_calls(func, decorator):
def decorate_recursive_calls(node):
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == func.__name__:
func_name = ast.copy_location(ast.Name(id=node.func.id, ctx=ast.Load()), node.func)
decorated_func = ast.Call(
func=ast.Name(id=decorator.__name__, ctx=ast.Load()),
args=[func_name],
keywords=[],
)
node.func = decorated_func
for field, value in ast.iter_fields(node):
if isinstance(value, list):
for item in value:
if isinstance(item, ast.AST):
decorate_recursive_calls(item)
elif isinstance(value, ast.AST):
decorate_recursive_calls(value)
tree = ast.parse(inspect.getsource(func))
decorate_recursive_calls(tree)
ast.fix_missing_locations(tree)
modified_code = compile(tree, filename="<ast>", mode="exec")
modified_function = types.FunctionType(modified_code.co_consts[1], func.__globals__)
return modified_function
3 个回答
这里有个小技巧,你可以在一个上下文管理器中修改函数的全局变量:
import contextlib
@contextlib.contextmanager
def dynamic_with_recursion_depth(f):
f_name = f.__name__
f_globals = f.__globals__
f_globals[f_name] = with_recursion_depth(f)
try:
yield
finally:
f_globals[f_name] = f
然后可以这样做:
root = generate_tree()
with dynamic_with_recursion_depth(traverse_tree):
for item in traverse_tree(root):
print(item)
当然,这种方法有点脆弱,因为如果 traverse_tree
是在另一个模块中定义的,你就得确保从那个模块调用它,比如 import mymodule; mymodule.traverse_tree(root)
,而 from mymodule import traverse_tree
是不行的。如果这个函数不是在全局范围内定义的,那也会出现问题。
为了实现你想要的功能,你的 traverse_tree
方法需要知道它是怎么被调用的,而 这并没有简单的方法。你可以尝试这样做...
def traverse_tree(node, func=None):
"""Recursively yield values of the binary tree."""
func = func if func else traverse_tree
yield node.value
if node.left:
yield from func(node.left, func=func)
if node.right:
yield from func(node.right, func=func)
root = generate_tree()
_traverse = with_recursion_depth(traverse_tree)
for item in _traverse(root, func=_traverse):
print(item)
...不过这有点像是变通的办法。
这是一个可以使用上下文变量的例子:“contextvars”是最近加入到语言中的一个功能,目的是让在任务中运行的异步代码可以传递一些“额外”的信息给嵌套调用,这样这些信息就不会被其他任务调用同样的函数时覆盖或混淆。你可以查看这个链接了解更多:https://docs.python.org/3/library/contextvars.html
它们有点像threading.locals
,但是接口使用起来比较麻烦——你最里面的调用可以从一个上下文变量中获取要调用的函数,而不是从全局范围中获取。因此,如果最外层的调用把这个上下文变量设置为装饰过的函数,那么只有在这个“下降”过程中调用的函数会受到影响。
上下文变量的功能很强大。不过,它的接口使用起来非常糟糕,因为你必须先创建一个上下文的副本,然后用这个副本来调用函数,只有这个被调用的函数可以改变上下文变量。而在这个调用之外的代码,总是能看到未改变的值,这在多线程、异步任务等情况下都是一致且可靠的。
用一个更简单的“mymul”递归函数和一个“logger”装饰器,代码可以像这样:
import contextvars
recursive_func = contextvars.ContextVar("recursive_func")
def logger(func):
def wrapper(*args, **kwargs):
print(f"{func.__name__} called with {args} and {kwargs}")
return func(*args, **kwargs)
return wrapper
def mymul(a, b):
if b == 0:
return 0
return a + recursive_func.get()(a, b - 1)
recursive_func.set(mymul)
# non logging call:
mymul(3, 4)
# logging call - there must be an "entry point" which
# can change the contextvar from within the new context.
def logging_call_maker(*args, **kwargs):
recursive_func.set(logger(mymul))
return mymul(*args, **kwargs)
contextvars.copy_context().run(logging_call_maker, 3, 4)
# Another non logging call:
mymul(3, 4)
这里的关键点是:递归函数是从ContextVar中获取要调用的函数,而不是使用它在全局范围中的名字。
如果你喜欢这种方法,但觉得ContextVar的使用太繁琐,可以让它更简单:只需在这里留言。我几年前开始了一个项目,目的是把这种行为封装成更友好的代码,但因为没有足够的使用案例,我就停止了这个项目。