如何在对象的多个方法上使用functools.partial,并无序固定参数?

2 投票
1 回答
1176 浏览
提问于 2025-04-15 20:27

我发现functools.partial非常有用,但我希望能够按顺序冻结参数(想冻结的参数不一定总是第一个),而且我想能同时应用到一个类的多个方法上,制作一个代理对象,这个对象的方法和原始对象一样,只是有些方法的参数被固定了(可以把它理解为将部分功能推广到类上)。我希望这样做的时候不修改原始对象,就像partial不会改变它的原始函数一样。

我拼凑出了一个叫做'bind'的版本,类似于functools.partial,它让我可以通过关键字参数来指定参数的顺序。这个部分是有效的:

>>> def foo(x, y):
...     print x, y
...
>>> bar = bind(foo, y=3)
>>> bar(2)
2 3

但是我的代理类不工作,我也不太确定为什么:

>>> class Foo(object):
...     def bar(self, x, y):
...             print x, y
...
>>> a = Foo()
>>> b = PureProxy(a, bar=bind(Foo.bar, y=3))
>>> b.bar(2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: bar() takes exactly 3 arguments (2 given)

我可能做错了很多事情,因为我只是根据从随机文档、博客中拼凑的信息,以及在所有组件上运行dir()得到的结果来进行操作。如果有人能给我一些建议,告诉我怎么让它工作,或者更好的实现方式,我会非常感激 ;) 有一个细节我不太确定的是,这一切应该如何与描述符互动。代码如下。

from types import MethodType

class PureProxy(object):
    def __init__(self, underlying, **substitutions):
        self.underlying = underlying

        for name in substitutions:
            subst_attr = substitutions[name]
            if hasattr(subst_attr, "underlying"):
                setattr(self, name, MethodType(subst_attr, self, PureProxy))

    def __getattribute__(self, name):
        return getattr(object.__getattribute__(self, "underlying"), name)

def bind(f, *args, **kwargs):
    """ Lets you freeze arguments of a function be certain values. Unlike
    functools.partial, you can freeze arguments by name, which has the bonus
    of letting you freeze them out of order. args will be treated just like
    partial, but kwargs will properly take into account if you are specifying
    a regular argument by name. """
    argspec = inspect.getargspec(f)
    argdict = copy(kwargs)

    if hasattr(f, "im_func"):
        f = f.im_func

    args_idx = 0
    for arg in argspec.args:
        if args_idx >= len(args):
            break

        argdict[arg] = args[args_idx]
        args_idx += 1

    num_plugged = args_idx

    def new_func(*inner_args, **inner_kwargs):
        args_idx = 0
        for arg in argspec.args[num_plugged:]:
            if arg in argdict:
                continue
            if args_idx >= len(inner_args):
                # We can't raise an error here because some remaining arguments
                # may have been passed in by keyword.
                break
            argdict[arg] = inner_args[args_idx]
            args_idx += 1

        f(**dict(argdict, **inner_kwargs))

    new_func.underlying = f

    return new_func

更新:如果有人能从中受益,这里是我最终采用的实现:

from types import MethodType

class PureProxy(object):
    """ Intended usage:
    >>> class Foo(object):
    ...     def bar(self, x, y):
    ...             print x, y
    ...
    >>> a = Foo()
    >>> b = PureProxy(a, bar=FreezeArgs(y=3))
    >>> b.bar(1)
    1 3
    """

    def __init__(self, underlying, **substitutions):
        self.underlying = underlying

        for name in substitutions:
            subst_attr = substitutions[name]
            if isinstance(subst_attr, FreezeArgs):
                underlying_func = getattr(underlying, name)
                new_method_func = bind(underlying_func, *subst_attr.args, **subst_attr.kwargs)
                setattr(self, name, MethodType(new_method_func, self, PureProxy))

    def __getattr__(self, name):
        return getattr(self.underlying, name)

class FreezeArgs(object):
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

def bind(f, *args, **kwargs):
    """ Lets you freeze arguments of a function be certain values. Unlike
    functools.partial, you can freeze arguments by name, which has the bonus
    of letting you freeze them out of order. args will be treated just like
    partial, but kwargs will properly take into account if you are specifying
    a regular argument by name. """
    argspec = inspect.getargspec(f)
    argdict = copy(kwargs)

    if hasattr(f, "im_func"):
        f = f.im_func

    args_idx = 0
    for arg in argspec.args:
        if args_idx >= len(args):
            break

        argdict[arg] = args[args_idx]
        args_idx += 1

    num_plugged = args_idx

    def new_func(*inner_args, **inner_kwargs):
        args_idx = 0
        for arg in argspec.args[num_plugged:]:
            if arg in argdict:
                continue
            if args_idx >= len(inner_args):
                # We can't raise an error here because some remaining arguments
                # may have been passed in by keyword.
                break
            argdict[arg] = inner_args[args_idx]
            args_idx += 1

        f(**dict(argdict, **inner_kwargs))

    return new_func

1 个回答

3

你现在的代码有点问题,叫做“绑定太深了”。你需要把类 PureProxy 里的 def __getattribute__(self, name): 改成 def __getattr__(self, name):。因为 __getattribute__ 会拦截你访问的每一个属性,这样就会绕过你用 setattr(self, name, ... 设置的属性,导致这些设置没有任何效果,这显然不是你想要的。而 __getattr__ 只会在你访问那些“没有定义”的属性时被调用,这样你的 setattr 调用就能正常工作了。

在你重写的这个方法里,你还应该把 object.__getattribute__(self, "underlying") 改成 self.underlying(因为你不再重写 __getattribute__ 了)。我还有其他一些建议,比如用 enumerate 替代你现在用的低级逻辑来处理计数器等等,但这些改动不会改变代码的基本意思。

按照我建议的改动,你的示例代码就能正常工作了(当然你还得继续测试一些更复杂的情况)。顺便说一下,我调试这个问题的方法就是在合适的地方加上 print 语句(虽然这种方法有点老派,但我还是最喜欢这种方式;-)。

撰写回答