测试Numpy数组是否包含指定行

75 投票
6 回答
66795 浏览
提问于 2025-04-17 15:12

有没有一种既符合Python风格又高效的方法来检查一个Numpy数组中是否至少包含一行特定的数据?这里的“高效”是指,一旦找到第一行匹配的数据就停止,而不是继续遍历整个数组,即使已经找到了结果。

在Python的普通数组中,可以很简单地用 if row in array: 来实现这个功能,但在Numpy数组中,这种方法的效果并不像我预期的那样,下面会有例子说明。

在Python数组中:

>>> a = [[1,2],[10,20],[100,200]]
>>> [1,2] in a
True
>>> [1,20] in a
False

但是Numpy数组的结果却不同,看起来有点奇怪。(ndarray__contains__ 方法似乎没有文档说明。)

>>> a = np.array([[1,2],[10,20],[100,200]])
>>> np.array([1,2]) in a
True
>>> np.array([1,20]) in a
True
>>> np.array([1,42]) in a
True
>>> np.array([42,1]) in a
False

6 个回答

9

我认为

equal([1,2], a).all(axis=1)   # also,  ([1,2]==a).all(axis=1)
# array([ True, False, False], dtype=bool)

会列出所有匹配的行。正如Jamie提到的,要知道是否至少有一行符合条件,可以使用 any

equal([1,2], a).all(axis=1).any()
# True

顺便说一下:
我猜 in(和 __contains__)的用法和上面的一样,只不过是用 any 代替了 all

69

你可以使用 .tolist() 方法来转换数据。

>>> a = np.array([[1,2],[10,20],[100,200]])
>>> [1,2] in a.tolist()
True
>>> [1,20] in a.tolist()
False
>>> [1,20] in a.tolist()
False
>>> [1,42] in a.tolist()
False
>>> [42,1] in a.tolist()
False

或者使用视图来处理数据:

>>> any((a[:]==[1,2]).all(1))
True
>>> any((a[:]==[1,20]).all(1))
False

或者通过生成 numpy 列表(这可能会非常慢):

any(([1,2] == x).all() for x in a)     # stops on first occurrence 

或者使用 numpy 的逻辑函数:

any(np.equal(a,[1,2]).all(1))

如果你对这些方法进行计时:

import numpy as np
import time

n=300000
a=np.arange(n*3).reshape(n,3)
b=a.tolist()

t1,t2,t3=a[n//100][0],a[n//2][0],a[-10][0]

tests=[ ('early hit',[t1, t1+1, t1+2]),
        ('middle hit',[t2,t2+1,t2+2]),
        ('late hit', [t3,t3+1,t3+2]),
        ('miss',[0,2,0])]

fmt='\t{:20}{:.5f} seconds and is {}'     

for test, tgt in tests:
    print('\n{}: {} in {:,} elements:'.format(test,tgt,n))

    name='view'
    t1=time.time()
    result=(a[...]==tgt).all(1).any()
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

    name='python list'
    t1=time.time()
    result = True if tgt in b else False
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

    name='gen over numpy'
    t1=time.time()
    result=any((tgt == x).all() for x in a)
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

    name='logic equal'
    t1=time.time()
    np.equal(a,tgt).all(1).any()
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

你会发现无论成功与否,numpy 的方法在搜索数组时速度是差不多的。而 Python 的 in 操作符在找到结果时可能会快很多,但如果你需要遍历整个数组,使用生成器就不太好了。

这是对于一个 300,000 x 3 元素数组的结果:

early hit: [9000, 9001, 9002] in 300,000 elements:
    view                0.01002 seconds and is True
    python list         0.00305 seconds and is True
    gen over numpy      0.06470 seconds and is True
    logic equal         0.00909 seconds and is True

middle hit: [450000, 450001, 450002] in 300,000 elements:
    view                0.00915 seconds and is True
    python list         0.15458 seconds and is True
    gen over numpy      3.24386 seconds and is True
    logic equal         0.00937 seconds and is True

late hit: [899970, 899971, 899972] in 300,000 elements:
    view                0.00936 seconds and is True
    python list         0.30604 seconds and is True
    gen over numpy      6.47660 seconds and is True
    logic equal         0.00965 seconds and is True

miss: [0, 2, 0] in 300,000 elements:
    view                0.00936 seconds and is False
    python list         0.01287 seconds and is False
    gen over numpy      6.49190 seconds and is False
    logic equal         0.00965 seconds and is False

还有一个 3,000,000 x 3 的数组:

early hit: [90000, 90001, 90002] in 3,000,000 elements:
    view                0.10128 seconds and is True
    python list         0.02982 seconds and is True
    gen over numpy      0.66057 seconds and is True
    logic equal         0.09128 seconds and is True

middle hit: [4500000, 4500001, 4500002] in 3,000,000 elements:
    view                0.09331 seconds and is True
    python list         1.48180 seconds and is True
    gen over numpy      32.69874 seconds and is True
    logic equal         0.09438 seconds and is True

late hit: [8999970, 8999971, 8999972] in 3,000,000 elements:
    view                0.09868 seconds and is True
    python list         3.01236 seconds and is True
    gen over numpy      65.15087 seconds and is True
    logic equal         0.09591 seconds and is True

miss: [0, 2, 0] in 3,000,000 elements:
    view                0.09588 seconds and is False
    python list         0.12904 seconds and is False
    gen over numpy      64.46789 seconds and is False
    logic equal         0.09671 seconds and is False

这似乎表明 np.equal 是使用 numpy 处理这个问题的最快方法……

24

Numpy的 __contains__ 方法在写这段话的时候,其实是 (a == b).any()。这说起来有点复杂,只有在 b 是一个单一的数值时才算正确(这有点麻烦,不过我觉得——从1.7版本开始是这样——这个方法 (a == b).all(np.arange(a.ndim - b.ndim, a.ndim)).any() 对于所有 ab 的维度组合都适用)...

补充一下:这并不一定是涉及广播时的预期结果。还有人可能会说,它应该像 np.in1d 那样分别处理 a 中的每个项目。我不太确定应该有一个明确的方式来处理这个问题。

现在你希望numpy在找到第一个匹配项时就停止。根据我所知,目前并没有这样的功能。这很困难,因为numpy主要是基于ufuncs(通用函数),它们会对整个数组执行相同的操作。虽然numpy确实优化了这类操作,但实际上只有在被处理的数组已经是布尔数组时(比如 np.ones(10, dtype=bool).any())才有效。

否则,它就需要一个专门的 __contains__ 函数,但这个函数并不存在。这听起来可能有点奇怪,但你要记住,numpy支持多种数据类型,并且有更复杂的机制来选择正确的数据类型和相应的函数。因此,ufunc的机制无法做到这一点,专门实现 __contains__ 之类的功能并不简单,因为涉及到数据类型的问题。

当然,你可以用python自己写这个功能,或者如果你已经知道你的数据类型,自己用Cython/C写也是非常简单的。


话说回来,通常情况下,使用基于排序的方法会更好。虽然这有点繁琐,因为没有类似于 searchsortedlexsort,但它是可行的(如果你愿意的话,也可以利用 scipy.spatial.cKDTree)。这假设你只想沿着最后一个轴进行比较:

# Unfortunatly you need to use structured arrays:
sorted = np.ascontiguousarray(a).view([('', a.dtype)] * a.shape[-1]).ravel()

# Actually at this point, you can also use np.in1d, if you already have many b
# then that is even better.

sorted.sort()

b_comp = np.ascontiguousarray(b).view(sorted.dtype)
ind = sorted.searchsorted(b_comp)

result = sorted[ind] == b_comp

这同样适用于数组 b,如果你保持排序后的数组,针对 b 中的单个值(行)进行处理时效果会更好,尤其是当 a 保持不变时(否则我会在将其视为记录数组后直接使用 np.in1d)。重要:你必须使用 np.ascontiguousarray 来确保安全。通常这不会有什么影响,但如果有的话,可能会导致严重的潜在错误。

撰写回答