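Overloading (or alternatives) in Python API design
I have a large existing program library that currently has a .NET binding, and I'm thinking about writing a Python binding. The existing API makes extensive use of signature-based overloading. That is, it has a large collection of static functions such as: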
Circle(p1, p2, p3) -- Creates a circle through three points
Circle(p, r) -- Creates a circle with given center point and radius
Circle(c1, c2, c3) -- Creates a circle tangent to three curves
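There are a few cases where the same inputs must be used in different ways, so signature-based overloading doesn't work and I have to use different function names instead. For example: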
BezierCurve(p1,p2,p3,p4) -- Bezier curve using given points as control points
BezierCurveThroughPoints(p1,p2,p3,p4) -- Bezier curve passing through given points
I suppose this second technique (using different function names) could be used everywhere in the Python API. So I would have:
CircleThroughThreePoints(p1, p2, p3)
CircleCenterRadius(p, r)
CircleTangentThreeCurves(c1, c2, c3)
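But these names look unpleasantly verbose (I don't like abbreviations), and inventing all of them will be quite a challenge, because the library has thousands of functions.
Low priorities:
Effort on my part -- I don't care if I have to write a lot of code.
Performance
High priorities:
Ease of use and understanding for callers (many will be programming novices).
Easy for me to write good documentation.
Simplicity -- avoid the need for advanced concepts in the caller's code.
I'm sure I'm not the first person who has wished for signature-based overloading in Python. What work-arounds do people typically use?
5 Answers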
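One option is to write the argument-parsing code yourself, so you don't have to change the API at all. You can even write a decorator so that the logic is reusable: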
import functools

def overload(func):
    '''Creates a signature from the arguments passed to the decorated function and passes it as the first argument'''
    @functools.wraps(func)
    def inner(*args):
        signature = tuple(map(type, args))
        return func(signature, *args)
    return inner

def matches(collection, sig):
    '''Returns True if each type in collection is a subclass of its respective item in sig'''
    if len(sig) != len(collection):
        return False
    return all(issubclass(i, j) for i, j in zip(collection, sig))

@overload
def Circle1(sig, *args):
    if matches(sig, (Point,)*3):
        # do stuff with args
        print("3 points")
    elif matches(sig, (Point, float)):
        # as before
        print("point, float")
    elif matches(sig, (Curve,)*3):
        # and again
        print("3 curves")
    else:
        raise TypeError("Invalid argument signature")

# or even better
@overload
def Circle2(sig, *args):
    valid_sigs = {(Point,)*3: CircleThroughThreePoints,
                  (Point, float): CircleCenterRadius,
                  (Curve,)*3: CircleTangentThreeCurves
                 }
    try:
        # take the first function whose signature matches
        func = next(f for s, f in valid_sigs.items() if matches(sig, s))
    except StopIteration:
        raise TypeError("Invalid argument signature")
    return func(*args)
How it appears to API users:
This is the best part. To an API user, all they see is this:
>>> help(Circle)
Circle(*args)
Whatever's in Circle's docstring. You should put info here about valid signatures.
They can call Circle directly, just as you showed in your question.
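How it works:
The whole idea is to hide the signature matching from the API user. A decorator builds a signature, which is basically a tuple containing the type of each argument, and passes it to the function as the first argument.
overload:
When you decorate a function with @overload, overload is called with that function as its argument. Whatever it returns (in this case inner) replaces the decorated function. functools.wraps ensures that the new function has the same name, docstring, etc.
overload is a fairly simple decorator. All it does is build a tuple of the type of each argument and pass that tuple as the first argument to the decorated function.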
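Circle, first implementation:
This is the simplest approach. At the start of the function, just test the signature against each of the valid signatures in turn.
Circle, second implementation:
This one is a little fancier. The benefit is that you can define all of the valid signatures together in one place. The return statement uses a generator to filter the dict down to the valid signatures that match, and taking the next item from that generator gets just the first one. If no valid signature matches, that raises a StopIteration exception.
All in all, this function simply returns the result of the function that corresponds to the matching signature.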
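To try the table-driven variant end to end, here is a minimal self-contained sketch. The Point and Curve classes and the three lower-case helper functions are hypothetical stand-ins (the real library would supply its own), and the table is a list rather than a dict so that the order in which signatures are tried is explicit.

```python
import functools


class Point(object):
    """Hypothetical stand-in for the library's point type."""


class Curve(object):
    """Hypothetical stand-in for the library's curve type."""


def overload(func):
    """Pass the tuple of argument types to func as its first argument."""
    @functools.wraps(func)
    def inner(*args):
        return func(tuple(map(type, args)), *args)
    return inner


def matches(actual, expected):
    """True if each type in actual is a subclass of its counterpart in expected."""
    return len(actual) == len(expected) and all(
        issubclass(a, e) for a, e in zip(actual, expected))


# Placeholder implementations that only report which overload was chosen.
def circle_through_three_points(p1, p2, p3):
    return "3 points"


def circle_center_radius(p, r):
    return "point, float"


def circle_tangent_three_curves(c1, c2, c3):
    return "3 curves"


VALID_SIGS = [
    ((Point, Point, Point), circle_through_three_points),
    ((Point, float), circle_center_radius),
    ((Curve, Curve, Curve), circle_tangent_three_curves),
]


@overload
def Circle(sig, *args):
    """Circle(p1, p2, p3) | Circle(p, r) | Circle(c1, c2, c3)"""
    for expected, func in VALID_SIGS:
        if matches(sig, expected):
            return func(*args)
    raise TypeError("Invalid argument signature")


print(Circle(Point(), Point(), Point()))  # 3 points
print(Circle(Point(), 2.5))               # point, float
print(Circle(Curve(), Curve(), Curve()))  # 3 curves
```

Note that with float in the table, an int radius such as Circle(Point(), 2) is rejected with a TypeError; listing numbers.Real instead of float would accept both.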
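Final notes:
You'll see *args used a lot in this code. In a function definition, it collects all the positional arguments into a tuple named "args". Elsewhere, it expands a sequence named args so that each item becomes a separate argument to the function (e.g. a = func(*args)).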
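A tiny sketch of both uses, with made-up function names (collect and add are only for illustration). Strictly speaking, the collected arguments arrive as a tuple:

```python
def collect(*args):
    # In a definition, *args gathers the positional arguments into a tuple.
    return args


def add(a, b, c):
    return a + b + c


packed = collect(1, 2, 3)
print(packed)        # (1, 2, 3)
# In a call, *packed spreads the items out again: same as add(1, 2, 3).
print(add(*packed))  # 6
```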
I don't think it's particularly unusual in Python to do slightly odd things like this in order to present a clean API.
You could use a dictionary, like so:
Circle({'points':[p1,p2,p3]})
Circle({'radius':r})
Circle({'curves':[c1,c2,c3]})
The initializer could then look like this:
def __init__(self, args):
    if len(args) > 1:
        raise ValueError("only pass one of points, radius, curves")
    if 'points' in args:
        pass  # blah
    elif 'radius' in args:
        pass  # blahblah
    elif 'curves' in args:
        pass  # evenmoreblah
    else:
        raise ValueError("only pass one of points, radius, curves")
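For something you can run, here is a rough sketch of the same idea inside a class. The kind and data attributes are assumptions made only so the result can be inspected; a real binding would call into the library instead.

```python
class Circle(object):
    """Circle({'points': [...]}), Circle({'radius': r}) or Circle({'curves': [...]})."""

    def __init__(self, args):
        if len(args) != 1:
            raise ValueError("pass exactly one of points, radius, curves")
        # The dict holds a single entry, so pull out its key and value.
        key, value = next(iter(args.items()))
        if key == 'points':
            self.kind, self.data = 'points', list(value)
        elif key == 'radius':
            self.kind, self.data = 'radius', value
        elif key == 'curves':
            self.kind, self.data = 'curves', list(value)
        else:
            raise ValueError("unknown key: %r" % (key,))


c = Circle({'radius': 2.0})
print(c.kind)  # radius
```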
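There are a number of modules on PyPI that can help with signature-based overloading and dispatch, such as multipledispatch, multimethods and Dispatching. I have no real experience with any of them, but multipledispatch looks like a good fit for what you want and it is well documented. Using your circle example: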
from multipledispatch import dispatch

class Point(tuple):
    pass

class Curve(object):
    pass

@dispatch(Point, Point, Point)
def Circle(point1, point2, point3):
    print("Circle(point1, point2, point3): point1 = %r, point2 = %r, point3 = %r" % (point1, point2, point3))

@dispatch(Point, int)
def Circle(centre, radius):
    print("Circle(centre, radius): centre = %r, radius = %r" % (centre, radius))

@dispatch(Curve, Curve, Curve)
def Circle(curve1, curve2, curve3):
    print("Circle(curve1, curve2, curve3): curve1 = %r, curve2 = %r, curve3 = %r" % (curve1, curve2, curve3))
>>> Circle(Point((10,10)), Point((20,20)), Point((30,30)))
Circle(point1, point2, point3): point1 = (10, 10), point2 = (20, 20), point3 = (30, 30)
>>> p1 = Point((25,10))
>>> p1
(25, 10)
>>> Circle(p1, 100)
Circle(centre, radius): centre = (25, 10), radius = 100
>>> Circle(*(Curve(),)*3)
Circle(curve1, curve2, curve3): curve1 = <__main__.Curve object at 0xa954d0>, curve2 = <__main__.Curve object at 0xa954d0>, curve3 = <__main__.Curve object at 0xa954d0>
>>> Circle()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/mhawke/virtualenvs/urllib3/lib/python2.7/site-packages/multipledispatch/dispatcher.py", line 143, in __call__
    func = self.resolve(types)
  File "/home/mhawke/virtualenvs/urllib3/lib/python2.7/site-packages/multipledispatch/dispatcher.py", line 184, in resolve
    (self.name, str_signature(types)))
NotImplementedError: Could not find signature for Circle: <>
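You can also decorate instance methods, so you can provide multiple implementations of __init__(), which is quite nice. If you were implementing any actual behaviour within the class, e.g. Circle.draw(), you would need some logic to work out which values are available for drawing the circle (centre and radius, three points, etc.). But since this is only meant to provide a set of bindings, you probably only need to call the correct native code function and pass the arguments on: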
from numbers import Number
from multipledispatch import dispatch

class Point(tuple):
    pass

class Curve(object):
    pass

class Circle(object):
    "A circle class"

    # dispatch(Point, (int, float, Decimal....))
    @dispatch(Point, Number)
    def __init__(self, centre, radius):
        """Circle(Point, Number): create a circle from a Point and radius."""
        print("Circle.__init__(): centre %r, radius %r" % (centre, radius))

    @dispatch(Point, Point, Point)
    def __init__(self, point1, point2, point3):
        """Circle(Point, Point, Point): create a circle from 3 points."""
        print("Circle.__init__(): point1 %r, point2 %r, point3 = %r" % (point1, point2, point3))

    @dispatch(Curve, Curve, Curve)
    def __init__(self, curve1, curve2, curve3):
        """Circle(Curve, Curve, Curve): create a circle from 3 curves."""
        print("Circle.__init__(): curve1 %r, curve2 %r, curve3 = %r" % (curve1, curve2, curve3))

    # append the docstring of every registered __init__ to the class docstring
    __doc__ = '' if __doc__ is None else '{}\n\n'.format(__doc__)
    __doc__ += '\n'.join(f.__doc__ for f in __init__.funcs.values())
>>> print(Circle.__doc__)
A circle class

Circle(Point, Number): create a circle from a Point and radius.
Circle(Point, Point, Point): create a circle from 3 points.
Circle(Curve, Curve, Curve): create a circle from 3 curves.
>>> from decimal import Decimal
>>> for num in 10, 10.22, complex(10.22), True, Decimal(100):
...     Circle(Point((10,20)), num)
...
Circle.__init__(): centre (10, 20), radius 10
<__main__.Circle object at 0x1d42fd0>
Circle.__init__(): centre (10, 20), radius 10.22
<__main__.Circle object at 0x1e3d890>
Circle.__init__(): centre (10, 20), radius (10.22+0j)
<__main__.Circle object at 0x1d42fd0>
Circle.__init__(): centre (10, 20), radius True
<__main__.Circle object at 0x1e3d890>
Circle.__init__(): centre (10, 20), radius Decimal('100')
<__main__.Circle object at 0x1d42fd0>
>>> Circle(Curve(), Curve(), Curve())
Circle.__init__(): curve1 <__main__.Curve object at 0x1e3db50>, curve2 <__main__.Curve object at 0x1d42fd0>, curve3 = <__main__.Curve object at 0x1d4b1d0>
<__main__.Circle object at 0x1d4b4d0>
>>> p1=Point((10,20))
>>> Circle(*(p1,)*3)
Circle.__init__(): point1 (10, 20), point2 (10, 20), point3 = (10, 20)
<__main__.Circle object at 0x1e3d890>
>>> Circle()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/mhawke/virtualenvs/urllib3/lib/python2.7/site-packages/multipledispatch/dispatcher.py", line 235, in __call__
    func = self.resolve(types)
  File "/home/mhawke/virtualenvs/urllib3/lib/python2.7/site-packages/multipledispatch/dispatcher.py", line 184, in resolve
    (self.name, str_signature(types)))
NotImplementedError: Could not find signature for __init__: <>
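If you'd rather see the general idea without a third-party dependency, the following is a deliberately naive sketch of type-based dispatch using only the standard library. It is not how multipledispatch is implemented: the _registry dict, the lookup by function name, and the "first registered match wins" rule are all simplifying assumptions made here.

```python
from numbers import Real

_registry = {}  # function name -> list of (types, implementation)


def dispatch(*types):
    """Register the decorated function under its name for these argument types."""
    def decorator(func):
        overloads = _registry.setdefault(func.__name__, [])
        overloads.append((types, func))

        def dispatcher(*args):
            # Try the registered overloads in order; the first match wins.
            for expected, candidate in overloads:
                if len(args) == len(expected) and all(
                        isinstance(a, t) for a, t in zip(args, expected)):
                    return candidate(*args)
            raise NotImplementedError(
                "Could not find signature for %s" % func.__name__)

        dispatcher.__name__ = func.__name__
        return dispatcher
    return decorator


class Point(tuple):
    pass


class Curve(object):
    pass


@dispatch(Point, Point, Point)
def Circle(point1, point2, point3):
    return "through three points"


@dispatch(Point, Real)
def Circle(centre, radius):
    return "centre and radius"


@dispatch(Curve, Curve, Curve)
def Circle(curve1, curve2, curve3):
    return "tangent to three curves"


print(Circle(Point((25, 10)), 100))  # centre and radius
```

Because every overload shares one list per name, the last Circle binding can see all three registrations. A real library also has to rank ambiguous matches, which this sketch ignores.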
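There are a few options here.
You can write a constructor that accepts an arbitrary number of arguments (using the *args and/or **kwargs syntax) and does different things depending on the number and type of the arguments.
Or, you can write the secondary constructors as class methods. These are known as "factory" methods. If you have multiple constructors that take the same number of objects of the same classes (as in your BezierCurve example), this is probably your only option.
If you don't mind overriding __new__ rather than __init__, you can even do both, with the __new__ method handling one form of arguments itself and handing the other kinds off to the factory methods. Here's an example of what that might look like, including a docstring covering the multiple signatures of __new__: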
class Circle(object):
    """Circle(center, radius) -> Circle object
    Circle(point1, point2, point3) -> Circle object
    Circle(curve1, curve2, curve3) -> Circle object

    Return a Circle with the provided center and radius. If three points are given,
    the center and radius will be computed so that the circle will pass through each
    of the points. If three curves are given, the circle's center and radius will
    be chosen so that the circle will be tangent to each of them."""

    def __new__(cls, *args):
        if len(args) == 2:
            self = super(Circle, cls).__new__(cls)
            self.center, self.radius = args
            return self
        elif len(args) == 3:
            if all(isinstance(arg, Point) for arg in args):
                return cls.through_points(*args)
            elif all(isinstance(arg, Curve) for arg in args):
                return cls.tangent_to_curves(*args)
        raise TypeError("Invalid arguments to Circle()")

    @classmethod
    def through_points(cls, point1, point2, point3):
        """through_points(point1, point2, point3) -> Circle object

        Return a Circle that touches three points."""
        # compute center and radius from the points...
        # then call back to the main constructor:
        return cls(center, radius)

    @classmethod
    def tangent_to_curves(cls, curve1, curve2, curve3):
        """tangent_to_curves(curve1, curve2, curve3) -> Circle object

        Return a Circle that is tangent to three curves."""
        # here too, compute center and radius from curves ...
        # then call back to the main constructor:
        return cls(center, radius)
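To make the __new__ plus factory-method pattern concrete, here is a small runnable sketch for the three-point case only. The namedtuple Point and the circumcircle arithmetic are assumptions added for illustration; the real binding would delegate to the library rather than compute anything itself.

```python
import math
from collections import namedtuple

# Hypothetical point type; the real library would supply its own.
Point = namedtuple("Point", "x y")


class Circle(object):
    """Circle(center, radius) or Circle(point1, point2, point3)."""

    def __new__(cls, *args):
        if len(args) == 2:
            self = super(Circle, cls).__new__(cls)
            self.center, self.radius = args
            return self
        if len(args) == 3 and all(isinstance(a, Point) for a in args):
            return cls.through_points(*args)
        raise TypeError("Invalid arguments to Circle()")

    @classmethod
    def through_points(cls, a, b, c):
        """Return the circle passing through three non-collinear points."""
        d = 2.0 * (a.x * (b.y - c.y) + b.x * (c.y - a.y) + c.x * (a.y - b.y))
        if d == 0:
            raise ValueError("points are collinear")
        # Squared distance of each point from the origin.
        sa, sb, sc = (p.x ** 2 + p.y ** 2 for p in (a, b, c))
        ux = (sa * (b.y - c.y) + sb * (c.y - a.y) + sc * (a.y - b.y)) / d
        uy = (sa * (c.x - b.x) + sb * (a.x - c.x) + sc * (b.x - a.x)) / d
        # Call back into the two-argument form of the constructor.
        return cls(Point(ux, uy), math.hypot(a.x - ux, a.y - uy))


unit = Circle(Point(1, 0), Point(0, 1), Point(-1, 0))
print(unit.center, unit.radius)
```

Both spellings, Circle(p1, p2, p3) and Circle.through_points(p1, p2, p3), end up in the same code path, so the docs can recommend whichever reads better for novices.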
One option is to use only keyword arguments in the constructor, and include some logic to work out which ones should be used:
class Circle(object):
    def __init__(self, points=(), radius=None, curves=()):
        if radius is not None and len(points) == 1:
            center_point = points[0]
            # Create from radius/center point
        elif curves and len(curves) == 3:
            pass  # create from curves
        elif points and len(points) == 3:
            pass  # create from points
        else:
            raise ValueError("Must provide a tuple of three points, a point and a radius, or a tuple of three curves")
You can also use classmethods to make things easier for users of the API:
class Circle(object):
    def __init__(self, points=(), radius=None, curves=()):
        pass  # same as above

    @classmethod
    def from_points(cls, p1, p2, p3):
        return cls(points=(p1, p2, p3))

    @classmethod
    def from_point_and_radius(cls, point, radius):
        return cls(points=(point,), radius=radius)

    @classmethod
    def from_curves(cls, c1, c2, c3):
        return cls(curves=(c1, c2, c3))
Usage:
c = Circle.from_points(p1, p2, p3)
c = Circle.from_point_and_radius(p1, r)
c = Circle.from_curves(c1, c2, c3)
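Putting the two pieces together, a runnable sketch might look like the following. The kind attribute is an assumption added so the chosen branch can be observed, and the checks are a bit stricter than above (mixed arguments are rejected); plain tuples and strings stand in for real point and curve objects.

```python
class Circle(object):
    """Records which construction was requested; the geometry is omitted."""

    def __init__(self, points=(), radius=None, curves=()):
        points, curves = tuple(points), tuple(curves)
        if radius is not None and len(points) == 1 and not curves:
            self.kind = "center_radius"
        elif len(curves) == 3 and not points and radius is None:
            self.kind = "curves"
        elif len(points) == 3 and not curves and radius is None:
            self.kind = "points"
        else:
            raise ValueError(
                "Must provide three points, a point and a radius, "
                "or three curves")
        self.points, self.radius, self.curves = points, radius, curves

    @classmethod
    def from_points(cls, p1, p2, p3):
        return cls(points=(p1, p2, p3))

    @classmethod
    def from_point_and_radius(cls, point, radius):
        return cls(points=(point,), radius=radius)

    @classmethod
    def from_curves(cls, c1, c2, c3):
        return cls(curves=(c1, c2, c3))


print(Circle.from_points((0, 0), (1, 0), (0, 1)).kind)  # points
print(Circle.from_point_and_radius((0, 0), 2.0).kind)   # center_radius
print(Circle.from_curves("c1", "c2", "c3").kind)        # curves
```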