快速找到集合中的numpy向量

2024-04-27 03:06:23 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个numpy数组,例如:

a = np.array([[1,2],
              [3,4],
              [6,4],
              [5,3],
              [3,5]])

我也有一套

^{pr2}$

我想找到集合b中向量的索引,这里是

[0, 2]

但我使用for循环来实现这一点,有没有一种简便的方法来避免for循环呢? 我使用的for循环方法:

record = []
for i in range(a.shape[0]):
    if (a[i, 0], a[i, 1]) in b:
        record.append(i)

Tags: 方法innumpyforifnprange数组
3条回答

首先,将集合转换为NumPy数组-

b_arr = np.array(list(b))

然后,基于^{},您将有三种方法。让我们用第二种方法来提高效率-

^{pr2}$

样本运行-

In [89]: a
Out[89]: 
array([[1, 2],
       [3, 4],
       [6, 4],
       [5, 3],
       [3, 5]])

In [90]: b
Out[90]: {(1, 2), (6, 4), (9, 9)}

In [91]: b_arr = np.array(list(b))

In [92]: dims = np.maximum(a.max(0),b_arr.max(0)) + 1
    ...: a1D = np.ravel_multi_index(a.T,dims)
    ...: b1D = np.ravel_multi_index(b_arr.T,dims)    
    ...: out = np.flatnonzero(np.in1d(a1D,b1D))
    ...: 

In [93]: out
Out[93]: array([0, 2])

您可以使用过滤器:

In [8]: a = np.array([[1,2],
              [3,4],
              [6,4],
              [5,3],
              [3,5]])

In [9]: b = {(1,2),(6,4)}

In [10]: filter(lambda x: tuple(a[x]) in b, range(len(a)))
Out[10]: [0, 2]

作为参考,直接列表理解(循环)答案:

In [108]: [i for i,v in enumerate(a) if tuple(v) in b]
Out[108]: [0, 2]

基本上与filter方法相同:

^{pr2}$

但这是一个玩具的例子,所以计时没有意义。在

如果a还不是数组,由于创建数组的开销,这些列表方法将比数组方法快。在

有一些numpy集操作,但它们适用于1d阵列。我们可以通过将二维阵列转换为一维结构来解决这个问题。在

In [117]: a.view('i,i')
Out[117]: 
array([[(1, 2)],
       [(3, 4)],
       [(6, 4)],
       [(5, 3)],
       [(3, 5)]], 
      dtype=[('f0', '<i4'), ('f1', '<i4')])
In [119]: np.array(list(b),'i,i')
Out[119]: 
array([(1, 2), (6, 4), (9, 9)], 
      dtype=[('f0', '<i4'), ('f1', '<i4')])

有一个使用np.void的版本,但是它更容易记住和使用这个'i,i'数据类型。在

所以这是可行的:

^{4}$

但它比迭代要慢得多:

In [124]: timeit np.nonzero(np.in1d(a.view('i,i'),np.array(list(b),'i,i')))[0]
10000 loops, best of 3: 153 µs per loop

正如其他最近的union问题所讨论的,np.in1d使用了几种策略。一种是基于广播和where。另一个使用uniqueconcatenationsorting和区别。在

一个广播解决方案(是的,很混乱),但比in1d快。在

In [150]: timeit np.nonzero((a[:,:,None,None]==np.array(list(b))[:,:]).any(axis=-1).any(axis=-1).all(axis=-1))[0]
10000 loops, best of 3: 52.2 µs per loop

相关问题 更多 >