np.isin测试Numpy数组是否包含考虑顺序的给定行

2024-04-20 12:17:12 发布

您现在位置:Python中文网/ 问答频道 /正文

我使用下面的行来查找b的行是否在a

 a[np.all(np.isin(a[:, 0:3], b[:, 0:3]), axis=1), 3]

数组沿着axis=1有更多的条目,我只比较前3个条目并返回a的第四个条目(idx=3)

我意识到的可能错误是,没有考虑条目的顺序。因此,下面是ab的示例:

a = np.array([[...],
              [1, 2, 3, 1000],
              [2, 1, 3, 2000],
              [...]])

b = np.array([[1, 2, 3]])

将返回[1000, 2000],而不是只返回[1000]

<>我怎么能考虑行的顺序呢?


Tags: 目的示例顺序错误np条目数组all
1条回答
网友
1楼 · 发布于 2024-04-20 12:17:12

对于较小的b(小于100行),请尝试以下方法:

a[(a[:, :3] == b[:, None]).all(axis=-1).any(axis=0)]

例如:

a = np.array([[1, 0, 5, 0],
              [1, 2, 3, 1000],
              [2, 1, 3, 2000],
              [0, 0, 1, 1]])

b = np.array([[1, 2, 3], [0, 0, 1]])

>>> a[(a[:, :3] == b[:, None]).all(axis=-1).any(axis=0), 3]
array([1000,    1])

说明:

关键是将a的所有行(前3列)的相等性测试“分发”到b的所有行:

# on the example above

>>> a[:, :3] == b[:, None]
array([[[ True, False, False],
        [ True,  True,  True],  # <  a[1,:3] matches b[0]
        [False, False,  True],
        [False, False, False]],

       [[False,  True, False],
        [False, False, False],
        [False, False, False],
        [ True,  True,  True]]])  # <  a[3, :3] matches b[1]

请注意,这可能很大:形状为(len(b), len(a), 3)

然后,第一个.all(axis=-1)表示我们希望所有整行都匹配:

>>> (a[:, :3] == b[:, None]).all(axis=-1)
array([[False,  True, False, False],
       [False, False, False,  True]])

最后一位.any(axis=0)表示:“匹配b中的任何行”:

>>> (a[:, :3] == b[:, None]).all(axis=-1).any(axis=0)
array([False,  True, False,  True])

即:“a[2, :3]匹配b的一些行以及a[3, :3]

最后,将其用作a中的掩码,并取第3列

性能说明

上述技术将a行与b行的乘积相等。如果ab都有许多行,那么这可能会很慢,并且会占用大量内存

或者,您可以在纯Python中使用set成员身份(不需要调用者可以完成的列子集):

def py_rows_in(a, b):
    z = set(map(tuple, b))
    return [row in z for row in map(tuple, a)]

b的行数超过50~100行时,与上面作为函数编写的np版本相比,这可能更快:

def np_rows_in(a, b):
    return (a == b[:, None]).all(axis=-1).any(axis=0)
import perfplot

fig, axes = plt.subplots(ncols=2, figsize=(16, 5))
plt.subplots_adjust(wspace=.5)
for ax, alen in zip(axes, [100, 10_000]):
    a = np.random.randint(0, 20, (alen, 4))
    plt.sca(ax)
    ax.set_title(f'a: {a.shape[0]:_} rows')
    perfplot.show(
        setup=lambda n: np.random.randint(0, 20, (n, 3)),
        kernels=[
            lambda b: np_rows_in(a[:, :3], b),
            lambda b: py_rows_in(a[:, :3], b),
        ],
        labels=['np_rows_in', 'py_rows_in'],
        n_range=[2 ** k for k in range(10)],
        xlabel='len(b)',
    )
plt.show()

comparative performance

相关问题 更多 >