装饰器如何在不改变函数签名的情况下传递变量?

9 投票
4 回答
4056 浏览
提问于 2025-05-01 16:24

首先,我得承认我想做的事情可能会被认为是傻或者有点邪恶,但我还是想看看在Python中能不能做到。

假设我有一个函数装饰器,它可以接受一些关键字参数来定义变量,我想在被装饰的函数里访问这些变量。我可能会这样做:

def more_vars(**extras):
    def wrapper(f):
        @wraps(f)
        def wrapped(*args, **kwargs):
            return f(extras, *args, **kwargs)
        return wrapped
    return wrapper

现在我可以这样做:

@more_vars(a='hello', b='world')
def test(deco_vars, x, y):
    print(deco_vars['a'], deco_vars['b'])
    print(x, y)

test(1, 2)
# Output:
# hello world
# 1 2

我不喜欢的地方是,当你使用这个装饰器时,你必须修改函数的调用方式,除了加上装饰器外,还要多加一个变量。而且,如果你查看这个函数的帮助文档,你会看到一个额外的变量,这个变量在调用函数时并不需要使用:

help(test)
# Output:
# Help on function test in module __main__:
#
# test(deco_vars, x, y)

这让人感觉用户在调用这个函数时应该传入三个参数,但显然这样是行不通的。所以你还得在文档字符串里加一句,说明第一个参数并不是接口的一部分,只是实现细节,应该被忽略。不过这样感觉有点糟糕。有没有办法做到这一点,而不需要把这些变量放在全局范围内?理想情况下,我希望它看起来像这样:

@more_vars(a='hello', b='world')
def test(x, y):
    print(a, b)
    print(x, y)

test(1, 2)
# Output:
# hello world
# 1 2
help(test)
# Output:
# Help on function test in module __main__:
#
# test(x, y)

如果有的话,我愿意接受仅限于Python 3的解决方案。

暂无标签

4 个回答

0

听起来你遇到的问题是,help 显示的是原始 test 函数的签名,而不是你想要的包装函数的签名。

之所以会这样,是因为 wraps(其实是它调用的 update_wrapper)会把原函数的签名直接复制到包装函数上。

你可以决定到底要复制哪些内容,不要复制哪些内容。如果你想做的事情比较简单,只需要从默认的 WRAPPER_ASSIGNMENTSWRAPPER_UPDATES 中筛选掉一些东西就行。如果你想改动的内容比较复杂,可能就需要自己修改 update_wrapper,用你自己的版本。不过,functools 这个模块在文档顶部有链接到 源代码,因为它的设计就是为了让人容易阅读和理解。

在你的情况下,可能只需要用 wraps(f, updated=[]) 就能解决,或者你可能想做一些更复杂的事情,比如使用 inspect.signature 来获取 f 的签名,然后修改它,去掉第一个参数,专门围绕这个构建一个包装函数,甚至可以欺骗 inspect 模块。

3

编辑:为了更易读而修改了答案。最新的答案在最上面,原始内容在下面。

如果我理解得没错的话:

  • 你想在 @more_vars 装饰器中定义新的参数作为关键字。
  • 你想在被装饰的函数中使用这些参数。
  • 而且你希望这些参数对普通用户是隐藏的(公开的函数签名仍然应该是正常的签名)。

可以看看我库中的 @with_partial 装饰器,它可以直接提供这种功能:

from makefun import with_partial

@with_partial(a='hello', b='world')
def test(a, b, x, y):
    """Here is a doc"""
    print(a, b)
    print(x, y)

它会产生预期的输出,并且文档字符串也会相应修改:

test(1, 2)
help(test)

产生的结果是

hello world
1 2
Help on function test in module <...>:

test(x, y)
    <This function is equivalent to 'test(x, y, a=hello, b=world)', see original 'test' doc below.>
    Here is a doc

为了回答你评论中的问题,makefun 中的函数创建策略和著名的 decorator 库是完全一样的:compile + exec。这里没有什么魔法,但 decorator 多年来在实际应用中一直使用这个技巧,所以它非常可靠。可以查看 源代码 中的 def _make

注意,makefun 库还提供了一个 partial(f, *args, **kwargs) 函数,如果你出于某种原因想自己创建装饰器的话(下面有灵感的来源)。


如果你想手动实现,这里有一个解决方案,应该能按你期望的那样工作,它依赖于 makefun 提供的 wraps 函数来修改公开的函数签名。

from makefun import wraps, remove_signature_parameters

def more_vars(**extras):
    def wrapper(f):
        # (1) capture the signature of the function to wrap and remove the invisible
        func_sig = signature(f)
        new_sig = remove_signature_parameters(func_sig, 'invisible_args')

        # (2) create a wrapper with the new signature
        @wraps(f, new_sig=new_sig)
        def wrapped(*args, **kwargs):
            # inject the invisible args again
            kwargs['invisible_args'] = extras
            return f(*args, **kwargs)

        return wrapped
    return wrapper

你可以测试一下它是否有效:

@more_vars(a='hello', b='world')
def test(x, y, invisible_args):
    a = invisible_args['a']
    b = invisible_args['b']
    print(a, b)
    print(x, y)

test(1, 2)
help(test)

如果你使用 decopatch 来去掉多余的嵌套层级,甚至可以让装饰器的定义更简洁:

from decopatch import DECORATED
from makefun import wraps, remove_signature_parameters

@function_decorator
def more_vars(f=DECORATED, **extras):
    # (1) capture the signature of the function to wrap and remove the invisible
    func_sig = signature(f)
    new_sig = remove_signature_parameters(func_sig, 'invisible_args')

    # (2) create a wrapper with the new signature
    @wraps(f, new_sig=new_sig)
    def wrapped(*args, **kwargs):
        kwargs['invisible_args'] = extras
        return f(*args, **kwargs)

    return wrapped

最后,如果你不想依赖任何外部库,最符合 Python 风格的方法是创建一个函数工厂(但这样你就不能把它作为装饰器使用了):

def make_test(a, b, name=None):
    def test(x, y):
        print(a, b)
        print(x, y)
    if name is not None:
        test.__name__ = name
    return test

test = make_test(a='hello', b='world')
test2 = make_test(a='hello', b='there', name='test2')

顺便说一下,我是 makefundecopatch 的作者;)

3

你可以通过一些小技巧,把传给装饰器的变量插入到函数的局部变量中:

import sys
from functools import wraps
from types import FunctionType


def is_python3():
    return sys.version_info >= (3, 0)


def more_vars(**extras):
    def wrapper(f):
        @wraps(f)
        def wrapped(*args, **kwargs):
            fn_globals = {}
            fn_globals.update(globals())
            fn_globals.update(extras)
            if is_python3():
                func_code = '__code__'
            else:
                func_code = 'func_code'
            call_fn = FunctionType(getattr(f, func_code), fn_globals)
            return call_fn(*args, **kwargs)
        return wrapped
    return wrapper


@more_vars(a="hello", b="world")
def test(x, y):
    print("locals: {}".format(locals()))
    print("x: {}".format(x))
    print("y: {}".format(y))
    print("a: {}".format(a))
    print("b: {}".format(b))


if __name__ == "__main__":
    test(1, 2)

这样做可以吗?当然可以!这样做应该吗?可能不太应该!

(代码可以在 这里 找到。)

0

我找到了解决这个问题的方法,虽然这个方法在大多数标准下几乎肯定比问题本身还要糟糕。通过一些巧妙的方式重写被装饰函数的字节码,你可以把所有对某个变量名的引用重定向到一个你可以动态创建的新闭包。这种方法只适用于标准的CPython,我只在3.7版本上测试过。

import inspect

from dis import opmap, Bytecode
from types import FunctionType, CodeType

def more_vars(**vars):
    '''Decorator to inject more variables into a function.'''

    def wrapper(f):
        code = f.__code__
        new_freevars = code.co_freevars + tuple(vars.keys())
        new_globals = [var for var in code.co_names if var not in vars.keys()]
        new_locals = [var for var in code.co_varnames if var not in vars.keys()]
        payload = b''.join(
            filtered_bytecode(f, new_freevars, new_globals, new_locals))
        new_code = CodeType(code.co_argcount,
                            code.co_kwonlyargcount,
                            len(new_locals),
                            code.co_stacksize,
                            code.co_flags & ~inspect.CO_NOFREE,
                            payload,
                            code.co_consts,
                            tuple(new_globals),
                            tuple(new_locals),
                            code.co_filename,
                            code.co_name,
                            code.co_firstlineno,
                            code.co_lnotab,
                            code.co_freevars + tuple(vars.keys()),
                            code.co_cellvars)
        closure = tuple(get_cell(v) for (k, v) in vars.items())
        return FunctionType(new_code, f.__globals__, f.__name__, f.__defaults__,
                            (f.__closure__ or ()) + closure)
    return wrapper

def get_cell(val=None):
    '''Create a closure cell object with initial value.'''

    # If you know a better way to do this, I'd like to hear it.
    x = val
    def closure():
        return x  # pragma: no cover
    return closure.__closure__[0]

def filtered_bytecode(func, freevars, globals, locals):
    '''Get the bytecode for a function with adjusted closed variables

    Any references to globlas or locals in the bytecode which exist in the
    freevars are modified to reference the freevars instead.

    '''
    opcode_map = {
        opmap['LOAD_FAST']: opmap['LOAD_DEREF'],
        opmap['STORE_FAST']: opmap['STORE_DEREF'],
        opmap['LOAD_GLOBAL']: opmap['LOAD_DEREF'],
        opmap['STORE_GLOBAL']: opmap['STORE_DEREF']
    }
    freevars_map = {var: idx for (idx, var) in enumerate(freevars)}
    globals_map = {var: idx for (idx, var) in enumerate(globals)}
    locals_map = {var: idx for (idx, var) in enumerate(locals)}

    for instruction in Bytecode(func):
        if instruction.opcode not in opcode_map:
            yield bytes([instruction.opcode, instruction.arg or 0])
        elif instruction.argval in freevars_map:
            yield bytes([opcode_map[instruction.opcode],
                         freevars_map[instruction.argval]])
        elif 'GLOBAL' in instruction.opname:
            yield bytes([instruction.opcode,
                         globals_map[instruction.argval]])
        elif 'FAST' in instruction.opname:
            yield bytes([instruction.opcode,
                         locals_map[instruction.argval]])

这个方法的表现正是我想要的:

In [1]: @more_vars(a='hello', b='world')
   ...: def test(x, y):
   ...:     print(a, b)
   ...:     print(x, y)
   ...:

In [2]: test(1, 2)
hello world
1 2

In [3]: help(test)
Help on function test in module __main__:

test(x, y)

不过,这个方法几乎肯定不适合在生产环境中使用。如果没有一些边缘情况表现得出乎意料,我会感到惊讶,甚至可能会导致程序崩溃。我可能会把这个归类为“教育性好奇心”的内容。

撰写回答