Numpy: 如何检查ndarray中元组的存在性

3 投票
1 回答
4339 浏览
提问于 2025-04-17 16:58

我在使用numpy数组处理元组时发现了一个奇怪的情况。我想要得到一个布尔值的表格,告诉我数组a中的哪些元组也存在于数组b中。通常,我会使用in或者in1d来实现这个功能。但它们都不管用,而tuple(a[1]) == b[1,1]却返回了True

我这样填充我的ab

a = numpy.array([(0,0)(1,1)(2,2)], dtype=tuple)

b = numpy.zeros((3,3), dtype=tuple)
for i in range(0,3):
    for j in range(0,3):
        b[i,j] = (i,j)

有没有人能告诉我解决这个问题的方法,并解释一下为什么会出现这种情况?

(顺便说一下,我这里用的是python2.7和numpy1.6.2。)

1 个回答

6

为什么这样不行

简单来说,numpy中 array.__contains__() 的实现似乎有问题。在Python中,in 操作符在后台会调用 __contains__()

这意味着 a in bb.__contains__(a) 是一样的。

我在一个交互式环境中加载了你的数组,并尝试了以下操作:

>>> b[:,0]
array([(0, 0), (1, 0), (2, 0)], dtype=object)
>>> (0,0) in b[:,0] # we expect it to be true
False
>>> (0,0) in list(b[:,0]) # this shouldn't be different from the above but it is
True
>>> 

怎么解决这个问题

我觉得你的列表推导式可能不太对,因为 a[x] 是一个元组,而 b[:,:] 是一个二维矩阵,所以它们当然不相等。不过我猜你是想用 in 而不是 ==。如果我理解错了,请纠正我。

第一步是把 b 从一个二维数组转换成一维数组,这样我们就可以线性地遍历它,并把它转换成一个列表,以避免numpy中 array.__contains() 的问题,像这样:

bb = list(b.reshape(b.size))

或者,更好的是,把它变成一个 set,因为元组是不可变的,而在集合中检查 in 的时间复杂度是 O(1),而列表是 O(n)。

>>> bb = set(b.reshape(b.size))
>>> print bb
set([(0, 1), (1, 2), (0, 0), (2, 1), (1, 1), (2, 0), (2, 2), (1, 0), (0, 2)])
>>> 

接下来,我们简单地使用列表推导式来生成布尔值的表格。

>>> truth_table = [tuple(aa) in bb for aa in a]
>>> print truth_table
[True, True, True]
>>> 

完整代码:

def contained(a,b):
    bb = set(b.flatten())
    return [tuple(aa) in bb for aa in a]

撰写回答