Python中的重载装饰器

21 投票
4 回答
35631 浏览
提问于 2025-04-17 09:08

我知道在Python中,写函数时不应该过于关注参数的类型,但有时候类型是无法忽视的,因为不同类型的处理方式不同。

在你的函数里加一堆isinstance检查看起来很糟糕;有没有什么函数装饰器可以实现函数重载呢?类似这样的:

@overload(str)
def func(val):
    print('This is a string')

@overload(int)
def func(val):
    print('This is an int')

更新:

这是我在David Zaslavsky的回答下留下的一些评论:

经过一些修改,这个方法对我来说挺合适的。我注意到你实现中的一个限制,因为你用func.__name__作为字典的键,这样会导致模块之间的名称冲突,这并不总是理想的。[继续]

[继续] 比如说,如果我有一个模块重载了func,而另一个完全不相关的模块也重载了func,那么这些重载就会冲突,因为函数分发字典是全局的。这个字典应该以某种方式局部化到模块中。而且不仅如此,它还应该支持某种“继承”。[继续]

[继续] 这里的“继承”是指:假设我有一个模块first,里面有一些重载。然后还有两个不相关的模块,它们都导入了first;这两个模块都在原有的基础上添加新的重载。这两个模块应该能够使用first中的重载,但它们新添加的重载不应该在模块之间冲突。(想想其实这真的很难做到。)

这些问题可能通过稍微改变装饰器的语法来解决:

first.py

@overload(str, str)
def concatenate(a, b):
    return a + b

@concatenate.overload(int, int)
def concatenate(a, b):
    return str(a) + str(b)

second.py

from first import concatenate

@concatenate.overload(float, str)
def concatenate(a, b):
    return str(a) + b

4 个回答

17

是的,在typing库中有一个叫做重载装饰器的东西,可以帮助我们让复杂的类型提示变得更简单。

from collections.abc import Sequence
from typing import overload


@overload
def double(input_: int) -> int:
    ...


@overload
def double(input_: Sequence[int]) -> list[int]:
    ...


def double(input_: int | Sequence[int]) -> int | list[int]:
    if isinstance(input_, Sequence):
        return [i * 2 for i in input_]
    return input_ * 2

想了解更多细节,可以查看这个链接

34

从Python 3.4开始,functools模块支持一个叫做@singledispatch的装饰器。它的工作原理是这样的:

from functools import singledispatch


@singledispatch
def func(val):
    raise NotImplementedError


@func.register
def _(val: str):
    print('This is a string')


@func.register
def _(val: int):
    print('This is an int')

用法

func("test") --> "This is a string"
func(1)      --> "This is an int"
func(None)   --> NotImplementedError
13

快速回答:在PyPI上有一个overload包,它实现了比我下面描述的更强大的功能,虽然语法稍有不同。这个包声明只支持Python 3,但看起来只需要稍微修改(如果需要的话,我还没尝试过)就可以让它在Python 2上工作。


详细回答:在一些编程语言中,你可以重载函数,也就是说可以用相同的函数名来定义不同的功能。函数的名字会和它的类型信息一起使用,这样在定义和调用函数时都能知道具体的类型。当编译器或解释器查找函数定义时,它会用函数名和参数的类型来决定调用哪个函数。因此,在Python中实现重载的合理方法是创建一个包装器,这个包装器会同时使用函数名和参数类型来找到正确的函数。

下面是一个简单的实现:

from collections import defaultdict

def determine_types(args, kwargs):
    return tuple([type(a) for a in args]), \
           tuple([(k, type(v)) for k,v in kwargs.iteritems()])

function_table = defaultdict(dict)
def overload(arg_types=(), kwarg_types=()):
    def wrap(func):
        named_func = function_table[func.__name__]
        named_func[arg_types, kwarg_types] = func
        def call_function_by_signature(*args, **kwargs):
            return named_func[determine_types(args, kwargs)](*args, **kwargs)
        return call_function_by_signature
    return wrap

overload函数应该接收两个可选参数,一个是元组,表示所有位置参数的类型,另一个是元组的元组,表示所有关键字参数的名称和类型的对应关系。这里有一个使用示例:

>>> @overload((str, int))
... def f(a, b):
...     return a * b

>>> @overload((int, int))
... def f(a, b):
...     return a + b

>>> print f('a', 2)
aa
>>> print f(4, 2)
6

>>> @overload((str,), (('foo', int), ('bar', float)))
... def g(a, foo, bar):
...     return foo*a + str(bar)

>>> @overload((str,), (('foo', float), ('bar', float)))
... def g(a, foo, bar):
...     return a + str(foo*bar)

>>> print g('a', foo=7, bar=4.4)
aaaaaaa4.4
>>> print g('b', foo=7., bar=4.4)
b30.8

这个实现的不足之处包括:

  • 它并没有实际检查被装饰的函数是否与传给装饰器的参数兼容。你可以写

    @overload((str, int))
    def h():
        return 0
    

    但在调用这个函数时会出错。

  • 它没有优雅地处理当没有对应于传入参数类型的重载版本时的情况(如果能抛出更详细的错误信息就好了)。

  • 它区分了命名参数和位置参数,所以像

    g('a', 7, bar=4.4)
    

    这样的用法是行不通的。

  • 使用这个实现时,涉及到很多嵌套的括号,比如在g的定义中。
  • 正如评论中提到的,这个实现没有处理在不同模块中同名的函数。

我认为这些问题都可以通过一些调整来解决。特别是,名称冲突的问题可以通过将调度表作为装饰器返回的函数的一个属性来轻松解决。但正如我所说,这只是一个简单的示例,用来演示如何实现重载的基本概念。

撰写回答