functools.partial是如何工作的?

342 投票
9 回答
285891 浏览
提问于 2025-04-17 18:38

我一直搞不懂在 functools 里,partial 是怎么工作的。这里有一段代码来自这个链接

>>> sum = lambda x, y : x + y
>>> sum(1, 2)
3
>>> incr = lambda y : sum(1, y)
>>> incr(2)
3
>>> def sum2(x, y):
    return x + y

>>> incr2 = functools.partial(sum2, 1)
>>> incr2(4)
5

在这一行

incr = lambda y : sum(1, y)

我明白无论我传给 incr 什么参数,它都会作为 y 传给 lambda,然后返回 sum(1, y),也就是 1 + y

这一点我明白了。但是我不太理解 incr2(4) 是怎么回事。

这个 4 是怎么作为 x 传入到部分函数里的?在我看来,4 应该替代 sum2。那么 x4 之间有什么关系呢?

9 个回答

87

简单来说,partial 是用来给一个函数的参数设置默认值的,这些参数如果不设置的话就没有默认值。

from functools import partial

def foo(a,b):
    return a+b

bar = partial(foo, a=1) # equivalent to: foo(a=1, b)
bar(b=10)
#11 = 1+10
bar(a=101, b=10)
#111=101+10
176

部分函数非常有用。

比如,在一个“管道式”的函数调用序列中(一个函数的返回值作为下一个函数的输入)。

有时候,在这样的管道中,一个函数需要一个单一的参数,但它前面的函数却返回了两个值

在这种情况下,functools.partial可以帮助你保持这个函数管道的完整性。

这里有一个具体的例子:假设你想根据每个数据点与某个目标的距离来排序数据:

# create some data
import random as RND
fnx = lambda: RND.randint(0, 10)
data = [ (fnx(), fnx()) for c in range(10) ]
target = (2, 4)

import math
def euclid_dist(v1, v2):
    x1, y1 = v1
    x2, y2 = v2
    return math.sqrt((x2 - x1)**2 + (y2 - y1)**2)

为了根据距离目标来排序数据,你当然想这样做:

data.sort(key=euclid_dist)

但你不能这样做——sort方法的key参数只接受一个单一参数的函数。

所以你需要把euclid_dist重写成一个只接受单一参数的函数:

from functools import partial

p_euclid_dist = partial(euclid_dist, target)

p_euclid_dist现在接受一个单一的参数,

>>> p_euclid_dist((3, 3))
  1.4142135623730951

这样你就可以通过传入部分函数来根据距离排序了,作为sort方法的key参数:

data.sort(key=p_euclid_dist)

# verify that it works:
for p in data:
    print(round(p_euclid_dist(p), 3))

    1.0
    2.236
    2.236
    3.606
    4.243
    5.0
    5.831
    6.325
    7.071
    8.602

另外,比如说某个函数的参数在外层循环中变化,但在内层循环中是固定的。通过使用部分函数,你在内层循环迭代时就不需要传入额外的参数,因为修改后的(部分)函数不需要这个参数。

>>> from functools import partial

>>> def fnx(a, b, c):
      return a + b + c

>>> fnx(3, 4, 5)
      12

创建一个部分函数(使用关键字参数)

>>> pfnx = partial(fnx, a=12)

>>> pfnx(b=4, c=5)
     21

你也可以用位置参数来创建部分函数

>>> pfnx = partial(fnx, 12)

>>> pfnx(4, 5)
      21

但这样会出错(例如,用关键字参数创建部分函数后,再用位置参数调用)

>>> pfnx = partial(fnx, a=12)

>>> pfnx(4, 5)
      Traceback (most recent call last):
      File "<pyshell#80>", line 1, in <module>
      pfnx(4, 5)
      TypeError: fnx() got multiple values for keyword argument 'a'

另一个用例:使用Python的multiprocessing库编写分布式代码。使用Pool方法创建一个进程池:

>>> import multiprocessing as MP

>>> # create a process pool:
>>> ppool = MP.Pool()

Pool有一个map方法,但它只接受一个可迭代对象,所以如果你需要传入一个参数列表较长的函数,可以把这个函数重新定义为部分函数,固定住除了一个以外的所有参数:

>>> ppool.map(pfnx, [4, 6, 7, 8])
400

大致来说,partial 的作用类似于这个(除了支持关键字参数等):

def partial(func, *part_args):
    def wrapper(*extra_args):
        return func(*args, *extra_args)            
    return wrapper

所以,当你调用 partial(sum2, 4) 时,你实际上是在创建一个新的函数(准确来说是一个可调用的对象),这个新函数的行为和 sum2 一样,但少了一个位置参数。这个缺失的参数总是用 4 来替代,因此 partial(sum2, 4)(2) == sum2(4, 2)

至于为什么需要这个功能,有很多情况。举个例子,假设你需要把一个函数传递到某个地方,而这个地方期望这个函数有两个参数:

class EventNotifier(object):
    def __init__(self):
        self._listeners = []

    def add_listener(self, callback):
        ''' callback should accept two positional arguments, event and params '''
        self._listeners.append(callback)
        # ...
    
    def notify(self, event, *params):
        for f in self._listeners:
            f(event, params)

但是你手头的函数需要访问一个第三个 context 对象才能正常工作:

def log_event(context, event, params):
    context.log_event("Something happened %s, %s", event, params)

所以,这里有几种解决方案:

一种是自定义对象:

class Listener(object):
   def __init__(self, context):
       self._context = context

   def __call__(self, event, params):
       self._context.log_event("Something happened %s, %s", event, params)


 notifier.add_listener(Listener(context))

另一种是使用 Lambda 表达式:

log_listener = lambda event, params: log_event(context, event, params)
notifier.add_listener(log_listener)

还有一种是使用 partial:

context = get_context()  # whatever
notifier.add_listener(partial(log_event, context))

在这三种方案中,partial 是最简洁和最快的。不过在更复杂的情况下,你可能会想用自定义对象。

撰写回答