通过meshgrid寻找成对numy数组的所有组合

2024-04-29 06:21:48 发布

您现在位置:Python中文网/ 问答频道 /正文

假设我有一个由值对组成的numpy数组。我想在不把它们分开的情况下找到所有的组合。特别是,我希望有一个numpy.meshgrid的解决方案。在

假设一个数组的构造如下:

ab = np.array([[1,10], [2,20], [3,30], [4,40]])

那么我想要的输出是

^{pr2}$

输出可以是np.array,也可以是tuple(我可以在以后进行相应的转换)。请注意,在我的结果中,重复项是如何被省略的,忽略了我的伴侣的顺序(如果[[1,10], [2,20]]已经存在,我不希望在输出中出现{})。对于实际情况,ab的大小为30000,所以速度是另一个问题。在

所以我一开始就试着用网状网格。 对于单个值的简单情况,这很容易做到(但是,仍然有重复项):

a = np.array([1,2,3,4])
mesh = np.array(np.meshgrid(a,a)).T.reshape(-1,2)
>>> out: [[1 1]
          [1 2]
          [1 3]
          [1 4]
          [2 1]
          [...]
          [4 4]]

但对于我的搭档,我的尝试

mesh = np.array(np.meshgrid(ab,ab)).T

给了我

[[[ 1  1]
  [ 1 10]
  [ 1  2]
  [ 1 20]
  [ 1  3]
  [ 1 30]
  [ 1  4]
  [ 1 40]]

 [[10  1]
  [10 10]
  [10  2]
  [10 20]
...    
  [40  3]
  [40 30]
  [40  4]
  [40 40]]]

换句话说:meshgrid分解了我的对。我想这个解决方案很快就要解决了,但我自己想不出来。感谢任何帮助,谢谢!在


Tags: numpyab顺序np情况数组解决方案array
1条回答
网友
1楼 · 发布于 2024-04-29 06:21:48

不要认为meshgrid会工作,因为它会创建所有可能的组合(稍后不过滤掉)。为了解决这个问题,可以提出两种方法。在

方法1

我们可以得到那些没有重复的成对组合的行索引,然后简单地索引成行以获得所需的输出,如下-

In [99]: r,c = np.triu_indices(len(ab),1)

In [100]: np.hstack(( ab[r], ab[c] ))
Out[100]: 
array([[ 1, 10,  2, 20],
       [ 1, 10,  3, 30],
       [ 1, 10,  4, 40],
       [ 2, 20,  3, 30],
       [ 2, 20,  4, 40],
       [ 3, 30,  4, 40]])

要以3D数组的形式获得所需的输出,请沿第二个轴堆叠-

^{pr2}$

作为函数:

def pairwise_combs1(ab):
    r,c = np.triu_indices(len(ab),1)
    return np.stack(( ab[r], ab[c] ), axis=1)

方法2另一种以slicing和{}为目标的内存效率和性能-

def pairwise_combs2(ab):
    n = len(ab)
    N = n*(n-1)//2
    out = np.empty((N,2,2),dtype=ab.dtype)
    idx = np.concatenate(( [0], np.arange(n-1,0,-1).cumsum() ))
    start, stop = idx[:-1], idx[1:]
    for j,i in enumerate(range(n-1)):
        out[start[j]:stop[j],0] = ab[j]
        out[start[j]:stop[j],1] = ab[j+1:]
    return out

运行时测试

In [166]: ab = np.random.randint(0,9,(1000,2))

In [167]: %timeit pairwise_combs1(ab)
10 loops, best of 3: 20 ms per loop

In [168]: %timeit pairwise_combs2(ab)
100 loops, best of 3: 6.25 ms per loop

In [169]: np.allclose(pairwise_combs1(ab), pairwise_combs2(ab))
Out[169]: True

相关问题 更多 >