Numpy: 如何检查ndarray中元组的存在性
我在使用numpy数组处理元组时发现了一个奇怪的情况。我想要得到一个布尔值的表格,告诉我数组a
中的哪些元组也存在于数组b
中。通常,我会使用in
或者in1d
来实现这个功能。但它们都不管用,而tuple(a[1]) == b[1,1]
却返回了True
。
我这样填充我的a
和b
:
a = numpy.array([(0,0)(1,1)(2,2)], dtype=tuple)
b = numpy.zeros((3,3), dtype=tuple)
for i in range(0,3):
for j in range(0,3):
b[i,j] = (i,j)
有没有人能告诉我解决这个问题的方法,并解释一下为什么会出现这种情况?
(顺便说一下,我这里用的是python2.7和numpy1.6.2。)
1 个回答
6
为什么这样不行
简单来说,numpy中 array.__contains__()
的实现似乎有问题。在Python中,in
操作符在后台会调用 __contains__()
。
这意味着 a in b
和 b.__contains__(a)
是一样的。
我在一个交互式环境中加载了你的数组,并尝试了以下操作:
>>> b[:,0]
array([(0, 0), (1, 0), (2, 0)], dtype=object)
>>> (0,0) in b[:,0] # we expect it to be true
False
>>> (0,0) in list(b[:,0]) # this shouldn't be different from the above but it is
True
>>>
怎么解决这个问题
我觉得你的列表推导式可能不太对,因为 a[x]
是一个元组,而 b[:,:]
是一个二维矩阵,所以它们当然不相等。不过我猜你是想用 in
而不是 ==
。如果我理解错了,请纠正我。
第一步是把 b
从一个二维数组转换成一维数组,这样我们就可以线性地遍历它,并把它转换成一个列表,以避免numpy中 array.__contains()
的问题,像这样:
bb = list(b.reshape(b.size))
或者,更好的是,把它变成一个 set
,因为元组是不可变的,而在集合中检查 in
的时间复杂度是 O(1),而列表是 O(n)。
>>> bb = set(b.reshape(b.size))
>>> print bb
set([(0, 1), (1, 2), (0, 0), (2, 1), (1, 1), (2, 0), (2, 2), (1, 0), (0, 2)])
>>>
接下来,我们简单地使用列表推导式来生成布尔值的表格。
>>> truth_table = [tuple(aa) in bb for aa in a]
>>> print truth_table
[True, True, True]
>>>
完整代码:
def contained(a,b):
bb = set(b.flatten())
return [tuple(aa) in bb for aa in a]