通过属性比较对象实例的相等性

347 投票
17 回答
365414 浏览
提问于 2025-04-15 13:22

我有一个类叫做 MyClass,里面有两个成员变量 foobar

class MyClass:
    def __init__(self, foo, bar):
        self.foo = foo
        self.bar = bar

我创建了这个类的两个实例,它们的 foobar 的值是完全一样的:

x = MyClass('foo', 'bar')
y = MyClass('foo', 'bar')

但是,当我比较这两个实例是否相等时,Python却返回 False

>>> x == y
False

我该怎么做才能让Python认为这两个对象是相等的呢?

17 个回答

29

如果你在处理一些你无法从内部修改的类,有一些简单而通用的方法可以做到这一点,而且这些方法不依赖于特定的库:

最简单,但对复杂对象不安全的方法

pickle.dumps(a) == pickle.dumps(b)

pickle 是一个非常常用的 Python 对象序列化库,它几乎可以序列化任何东西。在上面的代码片段中,我将序列化后的 astrbstr 进行了比较。与下一个方法不同,这个方法的优点是可以对自定义类进行类型检查。

最大的麻烦是:由于特定的排序和 [解/编码] 方法,pickle 可能会对相等的对象产生不同的结果,尤其是在处理更复杂的对象时(例如,嵌套自定义类实例的列表),这在一些第三方库中很常见。对于这些情况,我建议使用另一种方法:

全面、安全的方法

你可以写一个递归反射函数,这样就能得到可序列化的对象,然后进行比较。

from collections.abc import Iterable

BASE_TYPES = [str, int, float, bool, type(None)]


def base_typed(obj):
    """Recursive reflection method to convert any object property into a comparable form.
    """
    T = type(obj)
    from_numpy = T.__module__ == 'numpy'

    if T in BASE_TYPES or callable(obj) or (from_numpy and not isinstance(T, Iterable)):
        return obj

    if isinstance(obj, Iterable):
        base_items = [base_typed(item) for item in obj]
        return base_items if from_numpy else T(base_items)

    d = obj if T is dict else obj.__dict__

    return {k: base_typed(v) for k, v in d.items()}


def deep_equals(*args):
    return all(base_typed(args[0]) == base_typed(other) for other in args[1:])

现在无论你的对象是什么,深度相等都能保证有效。

>>> from sklearn.ensemble import RandomForestClassifier
>>>
>>> a = RandomForestClassifier(max_depth=2, random_state=42)
>>> b = RandomForestClassifier(max_depth=2, random_state=42)
>>> 
>>> deep_equals(a, b)
True

可比较的数量也无关紧要。

>>> c = RandomForestClassifier(max_depth=2, random_state=1000)
>>> deep_equals(a, b, c)
False

我使用这个方法的场景是检查在 BDD 测试中多种已经训练过的机器学习模型之间的深度相等。这些模型来自不同的第三方库。显然,像这里其他答案建议的那样实现 __eq__ 对我来说并不是一个选项。

覆盖所有情况

你可能会遇到一种情况,其中一个或多个自定义类在比较时没有 __dict__ 实现。这并不常见,但在 sklearn 的随机森林分类器中确实存在这种情况:<type 'sklearn.tree._tree.Tree'>。对于这些情况,逐个处理 - 例如,具体来说,我决定用一个方法的内容替换受影响类型的内容,这个方法能给我关于实例的代表性信息(在这种情况下是 __getstate__ 方法)。因此,base_typed 中倒数第二行变成了:

d = obj if T is dict else obj.__dict__ if '__dict__' in dir(obj) else obj.__getstate__()

编辑:为了组织更好,我用 return dict_from(obj) 替换了上面的丑陋单行代码。在这里,dict_from 是一个非常通用的反射函数,旨在适应一些比较冷门的库(我在看你,Doc2Vec)。

def isproperty(prop, obj):
    return not callable(getattr(obj, prop)) and not prop.startswith('_')


def dict_from(obj):
    """Converts dict-like objects into dicts
    """
    if isinstance(obj, dict):
        # Dict and subtypes are directly converted
        d = dict(obj)

    elif '__dict__' in dir(obj):
        # Use standard dict representation when available
        d = obj.__dict__

    elif str(type(obj)) == 'sklearn.tree._tree.Tree':
        # Replaces sklearn trees with their state metadata
        d = obj.__getstate__()

    else:
        # Extract non-callable, non-private attributes with reflection
        kv = [(p, getattr(obj, p)) for p in dir(obj) if isproperty(p, obj)]
        d = {k: v for k, v in kv}

    return {k: base_typed(v) for k, v in d.items()}

请注意,上述任何方法都不会对具有相同键值对但顺序不同的对象返回 True,就像:

>>> a = {'foo':[], 'bar':{}}
>>> b = {'bar':{}, 'foo':[]}
>>> pickle.dumps(a) == pickle.dumps(b)
False

但如果你想要这样的效果,你可以在此之前使用 Python 内置的 sorted 方法。

55

你在你的对象中重写了丰富的比较运算符

class MyClass:
 def __lt__(self, other):
      # return comparison
 def __le__(self, other):
      # return comparison
 def __eq__(self, other):
      # return comparison
 def __ne__(self, other):
      # return comparison
 def __gt__(self, other):
      # return comparison
 def __ge__(self, other):
      # return comparison

像这样:

    def __eq__(self, other):
        return self._id == other._id
494

你应该实现一个叫做 __eq__ 的方法:

class MyClass:
    def __init__(self, foo, bar):
        self.foo = foo
        self.bar = bar
        
    def __eq__(self, other): 
        if not isinstance(other, MyClass):
            # don't attempt to compare against unrelated types
            return NotImplemented

        return self.foo == other.foo and self.bar == other.bar

现在它的输出是:

>>> x == y
True

注意,实施 __eq__ 方法会让你的类的实例变得不可哈希,这意味着它们不能被存储在集合和字典中。如果你不是在建模一个不可变的类型(也就是说,如果属性 foobar 在对象的生命周期内可能会改变值),那么建议你就让你的实例保持不可哈希。

如果你是在建模一个不可变的类型,你还应该实现一个数据模型钩子 __hash__

class MyClass:
    ...

    def __hash__(self):
        # necessary for instances to behave sanely in dicts and sets.
        return hash((self.foo, self.bar))

像是遍历 __dict__ 并比较值的这种通用解决方案并不推荐——因为它永远无法真正通用,因为 __dict__ 里面可能包含无法比较或不可哈希的类型。

注意:在 Python 3 之前,你可能需要使用 __cmp__ 而不是 __eq__。使用 Python 2 的用户可能还想实现 __ne__,因为在 Python 2 中,默认的“不等于”行为(也就是反转相等的结果)不会自动创建。

撰写回答