检查numpy.array相等的最佳方法是什么?

153 投票
9 回答
101176 浏览
提问于 2025-04-16 01:39

我想为我的应用程序写一些单元测试,我需要比较两个数组。因为 array.__eq__ 会返回一个新数组(所以 TestCase.assertEqual 会失败),那么检查两个数组是否相等的最好方法是什么呢?

目前我在用

self.assertTrue((arr1 == arr2).all())

但我其实不太喜欢这个方法。

9 个回答

25

我发现用 self.assertEqual(arr1.tolist(), arr2.tolist()) 这个方法来比较数组,在使用unittest的时候是最简单的。

我同意这不是最优雅的解决方案,可能也不是最快的,但它和你其他的测试用例比较起来,方式是比较一致的。你能得到所有unittest的错误描述,而且实现起来真的很简单。

35

我觉得 (arr1 == arr2).all() 看起来挺不错的。不过你也可以用:

numpy.allclose(arr1, arr2)

但这和之前的写法不完全一样。

还有一种替代方法,几乎和你的例子一样:

numpy.alltrue(arr1 == arr2)

需要注意的是,scipy.array 实际上是一个引用的 numpy.array。这让查找相关文档变得更简单。

169

可以看看 numpy.testing 里的断言函数,比如说:

assert_array_equal

对于浮点数数组来说,直接比较是否相等可能会失败,这时候 assert_almost_equal 更可靠一些。

更新

在之前的几个版本中,numpy 新增了 assert_allclose 这个函数,现在我最喜欢用它,因为它可以让我们同时指定绝对误差和相对误差,而且不需要像之前那样用小数点来判断接近程度。

撰写回答