如果我创建一个包含Numpy ndarray的Python数据类,就不能再使用自动生成的__eq__
。在
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
这是因为ndarray.__eq__
有时通过将a[0]
与{ndarray
的真值,以此类推,以此类推。这是相当复杂和不直观的,实际上只会在数组的形状不同,或具有不同的值或其他值时引发错误。在
如何安全地比较@dataclass
持有Numpy数组的es?在
@dataclass
的__eq__
实现是使用eval()
生成的。它的源在stacktrace中丢失,无法使用inspect
查看,但它实际上使用了一个元组比较,它调用bool(foo)。在
节选:
3 12 LOAD_FAST 0 (self) 14 LOAD_ATTR 1 (foo) 16 LOAD_FAST 0 (self) 18 LOAD_ATTR 2 (bar) 20 BUILD_TUPLE 2 22 LOAD_FAST 1 (other) 24 LOAD_ATTR 1 (foo) 26 LOAD_FAST 1 (other) 28 LOAD_ATTR 2 (bar) 30 BUILD_TUPLE 2 32 COMPARE_OP 2 (==) 34 RETURN_VALUE
解决方案是放入您自己的
__eq__
方法并设置eq=False
,这样数据类就不会生成自己的(虽然最后一步检查docs是不必要的,但我认为显式还是很好的)。在编辑
通用数据类的通用快速解决方案,其中一些值是numpy数组,而另一些不是numpy数组
^{pr2}$相关问题 更多 >
编程相关推荐