测试numpy数组中是否所有值相等

22 投票
3 回答
54694 浏览
提问于 2025-04-16 23:50

我有一个一维的numpy数组 c,它应该装着 a + b 的结果。我首先在一个设备上用 PyOpenCL 执行 a + b

我想快速用 numpy 的切片功能来检查结果数组 c 是否正确。

这是我现在的代码:

def python_kernel(a, b, c):
    temp = a + b
    if temp[:] != c[:]:
        print "Error"
    else:
        print "Success!"

但是我遇到了一个错误:

ValueError: 数组的真值在有多个元素时是模糊的。请使用 a.any() 或 a.all()。

不过,似乎 a.anya.all 只是用来判断值是否不为0。

如果我想测试 numpy 数组 temp 中的所有元素是否都等于 numpy 数组 c 中的每一个值,我该怎么做呢?

3 个回答

7

你可以在比较的结果上使用 any 函数,比如这样写:if np.any(a+b != c):,或者换个写法:if np.all(a+b == c):。这里,a+b != c 会生成一个布尔数组,也就是一个只包含真(True)和假(False)的数组,然后 any 会检查这个数组,看看里面有没有哪个值是 True

>>> import numpy as np
>>> a = np.array([1,2,3])
>>> b = np.array([4,5,2])
>>> c = a+b
>>> c
array([5, 7, 5]) # <---- numeric, so any/all not useful
>>> a+b == c
array([ True,  True,  True], dtype=bool) # <---- BOOLEAN result, not numeric
>>> all(a+b == c)
True

不过,正如上面所说的,Amber 的解决方案 可能会更快,因为它不需要生成整个布尔结果数组。

15

np.allclose 是一个不错的选择,特别是当你使用 np.array 这种浮点数类型的数据时。np.array_equal 有时候可能会出现问题,不能正确工作。举个例子:

import numpy as np
def get_weights_array(n_recs):
    step = - 0.5 / n_recs
    stop = 0.5
    return np.arange(1, stop, step)

a = get_weights_array(5)
b = np.array([1.0, 0.9, 0.8, 0.7, 0.6])

结果:

>>> a
array([ 1. ,  0.9,  0.8,  0.7,  0.6])
>>> b
array([ 1. ,  0.9,  0.8,  0.7,  0.6])
>>> np.array_equal(a, b)
False
>>> np.allclose(a, b)
True

>>> import sys
>>> sys.version
'2.7.3 (default, Apr 10 2013, 05:13:16) \n[GCC 4.7.2]'
>>> np.version.version
'1.6.2'
54

为什么不直接用 numpy.array_equal(a1, a2)[文档] 这个NumPy的函数呢?

撰写回答