Elegant ways to support equivalence ("equality") in Python classes

Posted 2024-04-20 12:23:26


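When writing custom classes, it is often important to allow equivalence by means of the == and != operators. In Python, this is made possible by implementing the __eq__ and __ne__ special methods, respectively. The easiest way I've found to do this is the following: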

class Foo:
    def __init__(self, item):
        self.item = item

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.__dict__ == other.__dict__
        else:
            return False

    def __ne__(self, other):
        return not self.__eq__(other)

Do you know of a more elegant way of doing this? Do you know of any particular disadvantages to comparing __dict__s as above?

Note: to clarify a bit, when __eq__ and __ne__ are undefined, you'll find this behavior:

>>> a = Foo(1)
>>> b = Foo(1)
>>> a is b
False
>>> a == b
False

That is, a == b evaluates to False because it really runs a is b, a test of identity (i.e., "Is a the same object as b?").

When __eq__ and __ne__ are defined, you'll find this behavior (which is the one we're after):

>>> a = Foo(1)
>>> b = Foo(1)
>>> a is b
False
>>> a == b
True

3 answers

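What you describe is the way I've always done it. Since it's totally generic, you can always break that functionality out into a mixin class and inherit it in classes where you want that functionality.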

class CommonEqualityMixin(object):

    def __eq__(self, other):
        return (isinstance(other, self.__class__)
            and self.__dict__ == other.__dict__)

    def __ne__(self, other):
        return not self.__eq__(other)

class Foo(CommonEqualityMixin):

    def __init__(self, item):
        self.item = item
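A quick usage check of the mixin, run under Python 3. Point and Color are hypothetical example classes; the mixin body is copied from above so the snippet runs on its own. Equality requires both the same class and identical __dict__ contents:

```python
class CommonEqualityMixin(object):
    """Equality = same class (or subclass) check + equal instance __dict__s."""

    def __eq__(self, other):
        return (isinstance(other, self.__class__)
                and self.__dict__ == other.__dict__)

    def __ne__(self, other):
        return not self.__eq__(other)


# Hypothetical example class using the mixin.
class Point(CommonEqualityMixin):
    def __init__(self, x, y):
        self.x = x
        self.y = y


# Hypothetical unrelated class with the same attribute names as Point.
class Color(CommonEqualityMixin):
    def __init__(self, x, y):
        self.x = x
        self.y = y


p1 = Point(1, 2)
p2 = Point(1, 2)
p3 = Point(2, 1)

print(p1 == p2)           # True: same class, same attributes
print(p1 == p3)           # False: attribute values differ
print(p1 == Color(1, 2))  # False: the isinstance check fails
```

In Python 3 the mixin's __ne__ is redundant, but it keeps the same behavior under Python 2.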

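You need to be careful with inheritance (this session is from Python 2, where Foo is a classic-style class):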

>>> class Foo:
    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.__dict__ == other.__dict__
        else:
            return False

>>> class Bar(Foo):pass

>>> b = Bar()
>>> f = Foo()
>>> f == b
True
>>> b == f
False
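For comparison, here is a sketch of the same classes under Python 3, where every class is new-style. Because Bar is a subclass of Foo, f == b tries Bar's reflected __eq__ first, so both orders return False and the asymmetry above does not appear:

```python
class Foo:
    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.__dict__ == other.__dict__
        else:
            return False


class Bar(Foo):
    pass


b = Bar()
f = Foo()

# For f == b, Python 3 calls Bar's __eq__ (inherited from Foo) as
# b.__eq__(f) first; isinstance(f, Bar) is False, and since False is not
# NotImplemented, Foo's own f.__eq__(b) is never consulted.
print(f == b)  # False
print(b == f)  # False
```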

Check types more strictly, like this:

def __eq__(self, other):
    if type(other) is type(self):
        return self.__dict__ == other.__dict__
    return False

Apart from that, your approach will work fine; that's what special methods are there for.
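A minimal sketch of the stricter check in use, assuming Python 3 and reusing the Foo/Bar names from the session above: with type(other) is type(self), a Foo and a Bar never compare equal, in either order.

```python
class Foo:
    def __eq__(self, other):
        # Exact type match only: a Bar is never equal to a plain Foo.
        if type(other) is type(self):
            return self.__dict__ == other.__dict__
        return False


class Bar(Foo):
    pass


f1, f2 = Foo(), Foo()
b = Bar()

print(f1 == f2)  # True: same type, both __dict__s empty
print(f1 == b)   # False
print(b == f1)   # False
```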

Consider this simple problem:

class Number:

    def __init__(self, number):
        self.number = number


n1 = Number(1)
n2 = Number(1)

n1 == n2 # False -- oops

This happens because, by default, Python uses object identifiers for comparison operations:

id(n1) # 140400634555856
id(n2) # 140400634555920
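The same identity fallback can be checked directly. Under Python 3, without an __eq__ override, == gives the same answer as the is operator:

```python
class Number:
    def __init__(self, number):
        self.number = number


n1 = Number(1)
n2 = Number(1)

# No __eq__ defined: == falls back to object.__eq__, i.e. identity.
print(n1 == n2, n1 is n2)  # False False
print(n1 == n1, n1 is n1)  # True True
```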

Overriding the __eq__ function seems to solve the problem:

def __eq__(self, other):
    """Overrides the default implementation"""
    if isinstance(other, Number):
        return self.number == other.number
    return False


n1 == n2 # True
n1 != n2 # True in Python 2 -- oops, False in Python 3

Python 2中,也要记住重写__ne__函数以及documentation状态:

There are no implied relationships among the comparison operators. The truth of x==y does not imply that x!=y is false. Accordingly, when defining __eq__(), one should also define __ne__() so that the operators will behave as expected.

def __ne__(self, other):
    """Overrides the default implementation (unnecessary in Python 3)"""
    return not self.__eq__(other)


n1 == n2 # True
n1 != n2 # False

Python 3中,由于documentation声明:

By default, __ne__() delegates to __eq__() and inverts the result unless it is NotImplemented. There are no other implied relationships among the comparison operators, for example, the truth of (x<y or x==y) does not imply x<=y.
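A small sketch of that delegation, assuming Python 3: the class below defines only the __eq__ shown earlier and no __ne__, yet != still behaves as the negation of ==.

```python
class Number:
    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        if isinstance(other, Number):
            return self.number == other.number
        return False

    # No __ne__ here: object.__ne__ calls __eq__ and inverts the result
    # (unless __eq__ returns NotImplemented).


n1 = Number(1)
n2 = Number(1)
n_two = Number(2)

print(n1 != n2)     # False
print(n1 != n_two)  # True
```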

But that does not solve all our problems. Let's add a subclass:

class SubNumber(Number):
    pass


n3 = SubNumber(1)

n1 == n3 # False for classic-style classes -- oops, True for new-style classes
n3 == n1 # True
n1 != n3 # True for classic-style classes -- oops, False for new-style classes
n3 != n1 # False

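Note: Python 2 has two kinds of classes:

  • classic-style (or old-style) classes, which do not inherit from object and are declared as class A:, class A(): or class A(B): where B is a classic-style class;

  • new-style classes, which inherit from object and are declared as class A(object) or class A(B): where B is a new-style class. Python 3 has only new-style classes, which are declared as class A:, class A(object): or class A(B):.

For classic-style classes, a comparison operation always calls the method of the first operand, while for new-style classes, when one operand's class is a subclass of the other's, it always calls the method of the subclass operand first, regardless of the order of the operands.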

So here, if Number is a classic-style class:

  • n1 == n3 calls n1.__eq__;
  • n3 == n1 calls n3.__eq__;
  • n1 != n3 calls n1.__ne__;
  • n3 != n1 calls n3.__ne__.

If Number is a new-style class:

  • both n1 == n3 and n3 == n1 call n3.__eq__;
  • both n1 != n3 and n3 != n1 call n3.__ne__.
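To see the new-style dispatch order, the sketch below (Python 3) records which operand's __eq__ actually runs. The calls list and the trace helper are hypothetical instrumentation, not part of the original answer:

```python
calls = []  # class names of the operands whose __eq__ ran, in call order


class Number:
    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        calls.append(type(self).__name__)  # record which operand is `self`
        if isinstance(other, Number):
            return self.number == other.number
        return NotImplemented


class SubNumber(Number):
    pass


def trace(compare):
    """Run a zero-argument comparison; return (result, calls made)."""
    del calls[:]
    result = compare()
    return result, list(calls)


n1 = Number(1)
n3 = SubNumber(1)

print(trace(lambda: n1 == n3))  # (True, ['SubNumber'])
print(trace(lambda: n3 == n1))  # (True, ['SubNumber'])
```

In both orders the method runs with the SubNumber instance as self, even though the method itself is inherited from Number.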

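To fix the non-commutativity of the == and != operators for Python 2 classic-style classes, the __eq__ and __ne__ methods should return the NotImplemented value when an operand type is not supported. The documentation defines the NotImplemented value as: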

Numeric methods and rich comparison methods may return this value if they do not implement the operation for the operands provided. (The interpreter will then try the reflected operation, or some other fallback, depending on the operator.) Its truth value is true.

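In this case the operator delegates the comparison operation to the reflected method of the other operand. The documentation defines reflected methods as: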

There are no swapped-argument versions of these methods (to be used when the left argument does not support the operation but the right argument does); rather, __lt__() and __gt__() are each other’s reflection, __le__() and __ge__() are each other’s reflection, and __eq__() and __ne__() are their own reflection.

The result looks like this:

def __eq__(self, other):
    """Overrides the default implementation"""
    if isinstance(other, Number):
        return self.number == other.number
    return NotImplemented

def __ne__(self, other):
    """Overrides the default implementation (unnecessary in Python 3)"""
    x = self.__eq__(other)
    if x is not NotImplemented:
        return not x
    return NotImplemented

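Returning NotImplemented instead of False is the right thing to do even for new-style classes if the == and != operators should be commutative when the operands are of unrelated types (no inheritance).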
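The sketch below illustrates why NotImplemented matters for unrelated types, under Python 3. Digit and StrictNumber are hypothetical classes: Digit knows how to compare itself with a number, while StrictNumber is a variant of Number that returns False instead of NotImplemented. Only the NotImplemented version lets Python fall back to Digit.__eq__, so only it gives the same answer in both orders:

```python
class Number:
    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        if isinstance(other, Number):
            return self.number == other.number
        return NotImplemented  # let the other operand's __eq__ decide


class StrictNumber:
    """Hypothetical variant that answers False for unknown types."""

    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        if isinstance(other, StrictNumber):
            return self.number == other.number
        return False  # blocks the reflected call


class Digit:
    """Hypothetical unrelated class that can compare with both."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        if isinstance(other, (Number, StrictNumber)):
            return self.value == other.number
        return NotImplemented


d = Digit(1)
print(Number(1) == d, d == Number(1))              # True True
print(StrictNumber(1) == d, d == StrictNumber(1))  # False True
```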

Are we there yet? Not quite. How many unique numbers do we have?

len(set([n1, n2, n3])) # 3 -- oops

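Sets use the hashes of objects, and by default Python returns a hash of the object's identifier. (In Python 3, a class that defines __eq__ without __hash__ becomes unhashable, so the set call above raises TypeError instead of returning 3.) Let's try to override it: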

def __hash__(self):
    """Overrides the default implementation"""
    return hash(tuple(sorted(self.__dict__.items())))

len(set([n1, n2, n3])) # 1
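A sketch of that Python 3 behavior: defining __eq__ without __hash__ sets __hash__ to None, so instances cannot go into a set until a __hash__ is added. HashableNumber is a hypothetical subclass that reuses the __dict__-based hash from this answer:

```python
class Number:
    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        if isinstance(other, Number):
            return self.number == other.number
        return NotImplemented


# Python 3 sets __hash__ to None for a class that defines only __eq__.
print(Number.__hash__)  # None

try:
    set([Number(1), Number(1)])
except TypeError as exc:
    print("TypeError:", exc)  # instances are unhashable


class HashableNumber(Number):
    # Inherits __eq__ from Number and restores hashability.
    def __hash__(self):
        return hash(tuple(sorted(self.__dict__.items())))


print(len(set([HashableNumber(1), HashableNumber(1)])))  # 1
```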

The end result looks like this (I added some assertions at the end for validation):

class Number:

    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        """Overrides the default implementation"""
        if isinstance(other, Number):
            return self.number == other.number
        return NotImplemented

    def __ne__(self, other):
        """Overrides the default implementation (unnecessary in Python 3)"""
        x = self.__eq__(other)
        if x is not NotImplemented:
            return not x
        return NotImplemented

    def __hash__(self):
        """Overrides the default implementation"""
        return hash(tuple(sorted(self.__dict__.items())))


class SubNumber(Number):
    pass


n1 = Number(1)
n2 = Number(1)
n3 = SubNumber(1)
n4 = SubNumber(4)

assert n1 == n2
assert n2 == n1
assert not n1 != n2
assert not n2 != n1

assert n1 == n3
assert n3 == n1
assert not n1 != n3
assert not n3 != n1

assert not n1 == n4
assert not n4 == n1
assert n1 != n4
assert n4 != n1

assert len(set([n1, n2, n3, ])) == 1
assert len(set([n1, n2, n3, n4])) == 2
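One caveat with this final version: __hash__ hashes the entire __dict__, while __eq__ compares only number. If a subclass adds attributes, two objects can compare equal yet (almost certainly) have different hashes, which breaks the rule that a == b must imply hash(a) == hash(b). A possible alternative, sketched here under Python 3 with a hypothetical LabeledNumber subclass, is to hash exactly the state that __eq__ compares:

```python
class Number:
    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        if isinstance(other, Number):
            return self.number == other.number
        return NotImplemented

    def __ne__(self, other):
        x = self.__eq__(other)
        if x is not NotImplemented:
            return not x
        return NotImplemented

    def __hash__(self):
        # Hash only what __eq__ compares, so equal objects hash equally.
        return hash(self.number)


class LabeledNumber(Number):
    """Hypothetical subclass with extra state that __eq__ ignores."""

    def __init__(self, number, label):
        Number.__init__(self, number)
        self.label = label


n1 = Number(1)
n5 = LabeledNumber(1, "one")

print(n1 == n5, hash(n1) == hash(n5))  # True True
print(len(set([n1, n5])))              # 1
```

With the __dict__-based hash, n1 and n5 would compare equal but land in a set as two separate elements.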
