在Python中以通用方式为所有子类实现__neg__

3 投票
3 回答
1840 浏览
提问于 2025-04-17 04:03

抱歉问题有点长。

我正在实现可调用对象,想让它们的行为像数学函数一样。我有一个基类,它的 __call__ 方法会抛出 NotImplementedError,所以用户必须继承这个类来定义 __call__ 方法。我的问题是:我该如何在 基类 中定义特殊方法 __neg__,这样子类就能立即拥有预期的行为,而不需要在每个子类中都实现 __neg__ 呢?我对预期行为的理解是,如果 f 是基类的一个实例(或者它的子类),并且 __call__ 方法定义得当,那么 -f 应该是和 f 同一类的实例,拥有和 f 一样的所有属性,唯一不同的就是 __call__ 方法,它应该返回 f__call__ 的负值。

下面是我想表达的一个例子:

class Base(object):
    def __call__(self, *args, **kwargs):
        raise NotImplementedError, 'Please subclass'

    def __neg__(self):
        def call(*args, **kwargs):
            return -self(*args, **kwargs)
        mBase = type('mBase', (Base,), {'__call__': call})
        return mBase()                                                                                                                                                                                                  

class One(Base):
    def __init__(self data):
        self.data = data

    def __call__(self, *args, **kwargs):
        return 1

这个例子有预期的行为:

one = One()
print one()        # Prints  1
minus_one = -one
print minus_one()  # Prints -1

不过这并不是我想要的,因为 minus_one 不是和 one 同一类的实例(但我可以接受这个)。

现在我希望新的实例 minus_one 能继承 one 的所有属性和方法;只有 __call__ 方法应该改变。所以我可以把 __neg__ 改成:

    def __neg__(self):
        def call(*args, **kwargs):
            return -self(*args, **kwargs)
        mBase = type('mBase', (Base,), {'__call__': call})
        new = mBase()

        for n, v in inspect.getmembers(self):
            if n != '__call__':
                setattr(new, n, v)

        return new

这个方法似乎有效。我的问题是:这种策略有什么缺点吗?实现一个通用的 __neg__ 应该是个标准练习,但我在网上找不到相关信息。有没有推荐的替代方案?

提前感谢任何评论。

3 个回答

1

与其创建一个新类型,不如在实例上加一个标记,来表示调用的结果是否需要取反。然后,你可以把真正可以被重写的调用行为放到一个单独的方法中,这个方法不需要特别的处理,作为你自己协议的一部分。

class Base(object):
    def __init__(self):
        self._negate_call = False

    def call_impl(self, *args, **kwargs):
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        result = self.call_impl(*args, **kwargs)
        return -result if self._negate_call else result

    def __neg__(self):
        other = copy.copy(self)
        other._negate_call = not other._negate_call
        return other
4

你的方法有几个缺点。比如说,你把原始实例的所有成员都复制到新的实例里——如果你的类重写了除了 __call__ 之外的其他特殊方法,这样做就不行了,因为特殊方法在被隐式调用时只会在对象类型的字典里查找。此外,这样还会复制很多其实是从 object 继承来的东西,这些东西其实不需要放在实例的 __dict__ 里。

一个更简单的方法,可以满足你的具体需求,就是让新类型成为原始类型的子类。你可以在 __neg__() 方法里定义一个局部类来实现:

def __neg__(self):
    class Neg(self.__class__):
        def __call__(self_, *args, **kwargs):
            return -self(*args, **kwargs)
    neg = Base.__new__(Neg)
    neg.__dict__ = self.__dict__.copy()
    return neg

这段代码定义了一个新的类 Neg,它是从原始函数的类型派生的,并重写了它的 __call__() 方法。然后,它使用 Base 的构造函数创建这个类的实例——这样做是为了处理 self 的类需要构造参数的情况。最后,我们把直接存储在实例 self 中的所有内容复制到新的实例里。

如果是我来设计这个系统,我会采取完全不同的方法。我会为函数固定一个接口,并且只依赖这个固定的接口来处理每一个函数。我不会费心去把实例的所有属性复制到取反的函数里,而是会这样做:

class Function(object):
    def __neg__(self):
        return NegatedFunction(self)
    def __add__(self, other):
        return SumFunction(self, other)

class NegatedFunction(Function):
    def __init__(self, f):
        self.f = f
    def __call__(self, *args, **kwargs):
        return -self.f(*args, **kwargs)

class SumFunction(Function):
    def __init__(self, *funcs):
        self.funcs = funcs
    def __call__(self, *args, **kwargs):
        return sum(f(*args, **kwargs) for f in self.funcs)

这种方法虽然没有满足你要求的 __neg__() 返回的函数拥有原始函数的所有属性和方法,但我认为这个要求在设计上是有点问题的。我觉得放弃这个要求会让你的设计更简洁、更通用(就像上面例子中包含的 __add__() 操作符所示)。

3

你遇到的基本问题是,__xxx__ 方法只在类上查找,这意味着同一个类的所有实例都会使用相同的 __xxx__ 方法。这就意味着可以使用类似 Cat Plus Plus 提出的某种方法;不过,你也不希望用户还要担心更多特殊名称(比如 _call_impl_negate)。

如果你不介意使用元类这种可能让人头疼的强大功能,那就可以走这条路。元类可以自动添加 _negate 属性(并且会对名称进行处理以避免冲突),还可以把用户写的 __call__ 重命名为 _call,然后创建一个新的 __call__,这个新的 __call__ 会调用旧的 __call__(现在叫 _call ;)),如果需要的话,还会对结果进行取反,然后再返回。

下面是代码:

import copy
import inspect

class MetaFunction(type):
    def __new__(metacls, cls_name, cls_bases, cls_dict):
        result_class = type.__new__(metacls, cls_name, cls_bases, cls_dict)
        if '__call__' in cls_dict:
            original_call = cls_dict['__call__']
            args, varargs, kwargs, defaults = inspect.getargspec(original_call)
            args = args[1:]
            if defaults is None:
                defaults = [''] * len(args)
            else:
                defaults = [''] * (len(args) - len(defaults)) + list(defaults)
            signature = []
            for arg, default in zip(args, defaults):
                if default:
                    signature.append('%s=%s' % (arg, default))
                else:
                    signature.append(arg)
            if varargs is not None:
                signature.append(varargs)
            if kwargs is not None:
                signature.append(kwargs)
            signature = ', '.join(signature)
            passed_args = ', '.join(args)
            new_call = (
                    """def __call__(self, %(signature)s):
                           result = self._call(%(passed_args)s)
                           if self._%(cls_name)s__negate:
                               result = -result
                           return result"""
                           % {
                               'cls_name':cls_name,
                               'signature':signature,
                               'passed_args':passed_args, 
                              })
            eval_dict = {}
            exec new_call in eval_dict
            new_call = eval_dict['__call__']
            new_call.__doc__ = original_call.__doc__
            new_call.__module__ = original_call.__module__
            new_call.__dict__ = original_call.__dict__
            setattr(result_class, '__call__', new_call)
            setattr(result_class, '_call', original_call)
            setattr(result_class, '_%s__negate' % cls_name, False)
            negate = """def __neg__(self):
                            "returns an instance of the same class that returns the negation of __call__"
                            negated = copy.copy(self)
                            negated._%(cls_name)s__negate = not self._%(cls_name)s__negate
                            return negated""" % {'cls_name':cls_name}
            eval_dict = {'copy':copy}
            exec negate in eval_dict
            negate = eval_dict['__neg__']
            negate.__module__ = new_call.__module__
            setattr(result_class, '__neg__', eval_dict['__neg__'])
        return result_class

class Base(object):
    __metaclass__ = MetaFunction

class Power(Base):
    def __init__(self, power):
        "power = the power to raise to"
        self.power = power
    def __call__(self, number):
        "raises number to power"
        return number ** self.power

还有一个例子:

--> square = Power(2)
--> neg_square = -square
--> square(9)
81
--> neg_square(9)
-81

虽然元类的代码本身可能比较复杂,但生成的对象使用起来非常简单。公平地说,MetaFunction 中的大部分代码和复杂性都是因为要重写 __call__,以保持调用签名并使得反射功能有用……所以在帮助文档中,你看到的不是 __call__(*args, *kwargs),而是:

Help on Power in module test object:

class Power(Base)
 |  Method resolution order:
 |      Power
 |      Base
 |      __builtin__.object
 |
 |  Methods defined here:
 |
 |  __call__(self, number)
 |      raises number to power
 |
 |  __init__(self, power)
 |      power = the power to raise to
 |
 |  __neg__(self)
 |      returns an instance of the same class that returns the negation of __call__

撰写回答