如何比较数据类的相等性努比·恩达雷(bool(a==b)引发值错误)?

2024-03-28 15:36:51 发布

您现在位置:Python中文网/ 问答频道 /正文

如果我创建一个包含Numpy ndarray的Python数据类,就不能再使用自动生成的__eq__。在

^{1}$

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

这是因为ndarray.__eq__有时通过将a[0]与{}进行比较,从而返回一个ndarray的真值,以此类推,以此类推。这是相当复杂和不直观的,实际上只会在数组的形状不同,或具有不同的值或其他值时引发错误。在

如何安全地比较@dataclass持有Numpy数组的es?在


@dataclass__eq__实现是使用eval()生成的。它的源在stacktrace中丢失,无法使用inspect查看,但它实际上使用了一个元组比较,它调用bool(foo)。在

^{pr2}$

节选:

  3          12 LOAD_FAST                0 (self)
             14 LOAD_ATTR                1 (foo)
             16 LOAD_FAST                0 (self)
             18 LOAD_ATTR                2 (bar)
             20 BUILD_TUPLE              2
             22 LOAD_FAST                1 (other)
             24 LOAD_ATTR                1 (foo)
             26 LOAD_FAST                1 (other)
             28 LOAD_ATTR                2 (bar)
             30 BUILD_TUPLE              2
             32 COMPARE_OP               2 (==)
             34 RETURN_VALUE

Tags: 数据buildselfnumpyfoobarload数组
1条回答
网友
1楼 · 发布于 2024-03-28 15:36:51

解决方案是放入您自己的__eq__方法并设置eq=False,这样数据类就不会生成自己的(虽然最后一步检查docs是不必要的,但我认为显式还是很好的)。在

import numpy as np

def array_eq(arr1, arr2):
    return (isinstance(arr1, np.ndarray) and
            isinstance(arr2, np.ndarray) and
            arr1.shape == arr2.shape and
            (arr1 == arr2).all())

@dataclass(eq=False)
class Instr:

    foo: np.ndarray
    bar: np.ndarray

    def __eq__(self, other):
        if not isinstance(other, Instr):
            return NotImplemented
        return array_eq(self.foo, other.foo) and array_eq(self.bar, other.bar)

编辑

通用数据类的通用快速解决方案,其中一些值是numpy数组,而另一些不是numpy数组

^{pr2}$

相关问题 更多 >