Python 函数重载

334 投票
20 回答
385104 浏览
提问于 2025-04-16 20:03

我知道Python不支持方法重载,但我遇到了一个问题,感觉用Python的方式解决起来有点麻烦。

我正在制作一个游戏,角色需要发射各种子弹,但我该如何写不同的函数来创建这些子弹呢?比如说,我有一个函数可以创建从A点飞到B点的子弹,并且有一个给定的速度。我会写一个这样的函数:

def add_bullet(sprite, start, headto, speed):
    # Code ...

但是我还想写其他函数来创建不同类型的子弹,比如:

def add_bullet(sprite, start, direction, speed):
def add_bullet(sprite, start, headto, spead, acceleration):
def add_bullet(sprite, script): # For bullets that are controlled by a script
def add_bullet(sprite, curve, speed): # for bullets with curved paths
# And so on ...

还有很多其他的变化。我想知道有没有更好的方法来处理这些,而不是使用那么多的关键字参数,因为这样看起来有点乱。每个函数都改个名字也不好,因为你最后可能会得到 add_bullet1add_bullet2,或者 add_bullet_with_really_long_name 这样的名字。

关于一些回答:

  1. 不,我不能创建一个子弹类的层次结构,因为那样太慢了。实际管理子弹的代码是在C语言中,我的函数只是C API的封装。

  2. 我知道有关键字参数,但检查各种参数组合真的很烦人,不过像 acceleration=0 这样的默认参数确实帮了我不少忙。

20 个回答

107

你可以自己动手实现一个函数重载的解决方案。这个方法是从Guido van Rossum的文章中复制过来的,文章讲的是多方法(因为在Python中,多方法和重载之间几乎没有区别):

registry = {}

class MultiMethod(object):
    def __init__(self, name):
        self.name = name
        self.typemap = {}
    def __call__(self, *args):
        types = tuple(arg.__class__ for arg in args) # a generator expression!
        function = self.typemap.get(types)
        if function is None:
            raise TypeError("no match")
        return function(*args)
    def register(self, types, function):
        if types in self.typemap:
            raise TypeError("duplicate registration")
        self.typemap[types] = function


def multimethod(*types):
    def register(function):
        name = function.__name__
        mm = registry.get(name)
        if mm is None:
            mm = registry[name] = MultiMethod(name)
        mm.register(types, function)
        return mm
    return register

使用方法如下:

from multimethods import multimethod
import unittest

# 'overload' makes more sense in this case
overload = multimethod

class Sprite(object):
    pass

class Point(object):
    pass

class Curve(object):
    pass

@overload(Sprite, Point, Direction, int)
def add_bullet(sprite, start, direction, speed):
    # ...

@overload(Sprite, Point, Point, int, int)
def add_bullet(sprite, start, headto, speed, acceleration):
    # ...

@overload(Sprite, str)
def add_bullet(sprite, script):
    # ...

@overload(Sprite, Curve, speed)
def add_bullet(sprite, curve, speed):
    # ...

目前最严格的限制有:

  • 不支持方法,只支持不是类成员的函数;
  • 不处理继承;
  • 不支持可变关键字参数(kwargs);
  • 注册新函数必须在导入时进行,这个操作不是线程安全的。
126

Python确实支持你所说的“方法重载”。实际上,你刚才描述的内容在Python中很简单,可以用很多不同的方法来实现,但我会选择以下这种方式:

class Character(object):
    # your character __init__ and other methods go here

    def add_bullet(self, sprite=default, start=default, 
                 direction=default, speed=default, accel=default, 
                  curve=default):
        # do stuff with your arguments

在上面的代码中,default是一个合理的默认值,或者你也可以用None。这样,你可以只传入你感兴趣的参数,Python会自动使用默认值。

你也可以这样做:

class Character(object):
    # your character __init__ and other methods go here

    def add_bullet(self, **kwargs):
        # here you can unpack kwargs as (key, values) and
        # do stuff with them, and use some global dictionary
        # to provide default values and ensure that ``key``
        # is a valid argument...

        # do stuff with your arguments

另一种选择是直接把想要的函数连接到类或实例上:

def some_implementation(self, arg1, arg2, arg3):
  # implementation
my_class.add_bullet = some_implementation_of_add_bullet

还有一种方法是使用抽象工厂模式:

class Character(object):
   def __init__(self, bfactory, *args, **kwargs):
       self.bfactory = bfactory
   def add_bullet(self):
       sprite = self.bfactory.sprite()
       speed = self.bfactory.speed()
       # do stuff with your sprite and speed

class pretty_and_fast_factory(object):
    def sprite(self):
       return pretty_sprite
    def speed(self):
       return 10000000000.0

my_character = Character(pretty_and_fast_factory(), a1, a2, kw1=v1, kw2=v2)
my_character.add_bullet() # uses pretty_and_fast_factory

# now, if you have another factory called "ugly_and_slow_factory" 
# you can change it at runtime in python by issuing
my_character.bfactory = ugly_and_slow_factory()

# In the last example you can see abstract factory and "method
# overloading" (as you call it) in action 
266

你问的这个问题叫做 多重调度。可以看看 Julia 语言的例子,它展示了不同类型的调度。

不过,在看这些之前,我们先来聊聊为什么在 Python 中 重载 其实并不是你想要的。

为什么不重载?

首先,需要理解重载的概念,以及为什么它不适用于 Python。

在一些可以在编译时区分数据类型的语言中,选择不同的函数可以在编译时进行。创建这些替代函数以便在编译时选择的行为通常被称为重载函数。 (维基百科)

Python 是一种 动态类型 的语言,所以重载的概念在这里并不适用。不过,也不是说就没有办法,因为我们可以在运行时创建这样的 替代函数

在那些在运行时才确定数据类型的编程语言中,选择替代函数必须在运行时进行,基于动态确定的函数参数类型。以这种方式选择替代实现的函数通常被称为 多方法。 (维基百科)

所以我们应该能够在 Python 中实现 多方法,或者说 多重调度

多重调度

多方法也被称为 多重调度

多重调度或多方法是一些面向对象编程语言的一个特性,在这种特性下,函数或方法可以根据多个参数的运行时(动态)类型进行动态调度。 (维基百科)

Python 默认不支持这个功能1,但实际上,有一个很棒的 Python 包叫 multipledispatch,它正好可以实现这个功能。

解决方案

下面是我们如何使用 multipledispatch2 包来实现你的方法:

>>> from multipledispatch import dispatch
>>> from collections import namedtuple
>>> from types import *  # we can test for lambda type, e.g.:
>>> type(lambda a: 1) == LambdaType
True

>>> Sprite = namedtuple('Sprite', ['name'])
>>> Point = namedtuple('Point', ['x', 'y'])
>>> Curve = namedtuple('Curve', ['x', 'y', 'z'])
>>> Vector = namedtuple('Vector', ['x','y','z'])

>>> @dispatch(Sprite, Point, Vector, int)
... def add_bullet(sprite, start, direction, speed):
...     print("Called Version 1")
...
>>> @dispatch(Sprite, Point, Point, int, float)
... def add_bullet(sprite, start, headto, speed, acceleration):
...     print("Called version 2")
...
>>> @dispatch(Sprite, LambdaType)
... def add_bullet(sprite, script):
...     print("Called version 3")
...
>>> @dispatch(Sprite, Curve, int)
... def add_bullet(sprite, curve, speed):
...     print("Called version 4")
...

>>> sprite = Sprite('Turtle')
>>> start = Point(1,2)
>>> direction = Vector(1,1,1)
>>> speed = 100 #km/h
>>> acceleration = 5.0 #m/s**2
>>> script = lambda sprite: sprite.x * 2
>>> curve = Curve(3, 1, 4)
>>> headto = Point(100, 100) # somewhere far away

>>> add_bullet(sprite, start, direction, speed)
Called Version 1

>>> add_bullet(sprite, start, headto, speed, acceleration)
Called version 2

>>> add_bullet(sprite, script)
Called version 3

>>> add_bullet(sprite, curve, speed)
Called version 4
  1. Python 3 目前支持 单重调度
  2. 注意不要在多线程环境中使用 multipledispatch,否则会出现奇怪的行为。

撰写回答