Python - 2D Numpy数组的交集

3 投票
6 回答
15375 浏览
提问于 2025-04-18 11:29

我正在拼命寻找一种高效的方法来检查两个二维的numpy数组是否有交集。

我有两个数组,每个数组里面有任意数量的二维数组,像这样:

A=np.array([[2,3,4],[5,6,7],[8,9,10]])
B=np.array([[5,6,7],[1,3,4]])
C=np.array([[1,2,3],[6,6,7],[10,8,9]])

我只需要判断如果有至少一个向量和另一个数组中的向量相交,就返回True,否则返回False。所以结果应该是这样的:

f(A,B)  -> True
f(A,C)  -> False

我对Python还不是很熟悉,最开始我用Python的列表写了我的程序,虽然能用,但效率非常低。程序运行起来要花几天时间,所以我现在在尝试用numpy.array来解决这个问题,但这些数组真的不太好处理。

这里是我程序的一些背景和用Python列表的解决方案:

我做的事情有点像在三维空间中进行自避随机游走。http://en.wikipedia.org/wiki/Self-avoiding_walk。但我不是随便走动,希望能达到一个理想的长度(比如我想要由1000个珠子组成的链),而是这样做:

我创建一个“平坦”的链,长度为N:

X=[]
for i in range(0,N+1):
    X.append((i,0,0))

然后我把这个平坦的链折叠起来:

  1. 随机选择一个元素(“枢轴元素”)
  2. 随机选择一个方向(要么是枢轴元素左边的所有元素,要么是右边的)
  3. 随机选择一个空间中的旋转方式(3个轴 * 3种可能的旋转:90°、180°、270°)
  4. 用选择的旋转方式旋转所选方向的所有元素
  5. 检查新选择的方向的元素是否与另一方向相交
  6. 如果没有交集 -> 接受新的配置,否则 -> 保留旧的链。

步骤1到6需要重复很多次(例如,对于长度为1000的链,大约需要5000次),所以这些步骤必须高效完成。我基于列表的解决方案如下:

def PivotFold(chain):
randPiv=random.randint(1,N)  #Chooses a random pivotelement, N is the Chainlength
Pivot=chain[randPiv]  #get that pivotelement
C=[]  #C is going to be a shifted copy of the chain
intersect=False
for j in range (0,N+1):   # Here i shift the hole chain to get the pivotelement to the origin, so i can use simple rotations around the origin
    C.append((chain[j][0]-Pivot[0],chain[j][1]-Pivot[1],chain[j][2]-Pivot[2]))
rotRand=random.randint(1,18)  # rotRand is used to choose a direction and a Rotation (2 possible direction * 9 rotations = 18 possibilitys)
#Rotations around Z-Axis
if rotRand==1:
    for j in range (randPiv,N+1):
        C[j]=(-C[j][1],C[j][0],C[j][2])
        if C[0:randPiv].__contains__(C[j])==True:
            intersect=True
            break
elif rotRand==2:
    for j in range (randPiv,N+1):
        C[j]=(C[j][1],-C[j][0],C[j][2])
        if C[0:randPiv].__contains__(C[j])==True:
            intersect=True
            break
...etc
if intersect==False: # return C if there was no intersection in C
    Shizz=C
else:
    Shizz=chain
return Shizz

函数PivotFold(chain)会在最初的平坦链X上使用很多次。这个函数写得比较简单,所以也许你有一些改进的建议^^ 我觉得使用numpy数组会很好,因为我可以高效地移动和旋转整个链,而不需要遍历所有元素……

6 个回答

0

我觉得你想要的是,如果两个数组有相同的子数组,就返回真!你可以使用这个:

def(A,B):
 for i in A:
  for j in B:
   if i==j
   return True
 return False 
1

这个方法应该会快很多,因为它的复杂度不是O(n^2),像用for循环那种做法。但它也不是完全符合numpy的使用方式。我不太确定在这里怎么更好地利用numpy。

def set_comp(a, b):
   sets_a = set(map(lambda x: frozenset(tuple(x)), a))
   sets_b = set(map(lambda x: frozenset(tuple(x)), b))
   return not sets_a.isdisjoint(sets_b)
3

你也可以通过一些 np.tilenp.swapaxes 的操作来完成这个任务!

def intersect2d(X, Y):
        """
        Function to find intersection of two 2D arrays.
        Returns index of rows in X that are common to Y.
        """
        X = np.tile(X[:,:,None], (1, 1, Y.shape[0]) )
        Y = np.swapaxes(Y[:,:,None], 0, 2)
        Y = np.tile(Y, (X.shape[0], 1, 1))
        eq = np.all(np.equal(X, Y), axis = 1)
        eq = np.any(eq, axis = 1)
        return np.nonzero(eq)[0]

更具体地说,你只需要检查一下返回的数组是否是空的就可以了。

3

根据这里提到的相同思路,你可以这样做:

def make_1d_view(a):
    a = np.ascontiguousarray(a)
    dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
    return a.view(dt).ravel()

def f(a, b):
    return len(np.intersect1d(make_1d_view(A), make_1d_view(b))) != 0

>>> f(A, B)
True
>>> f(A, C)
False

不过,这种方法对浮点数类型不太适用(它不会把 +0.0 和 -0.0 视为相同的值),而且 np.intersect1d 使用了排序,所以它的性能是线性对数级别的,而不是线性级别的。如果你想提高性能,可以尝试在你的代码中复制 np.intersect1d 的源代码,然后不检查返回数组的长度,而是对布尔索引数组使用 np.any

4

这样做就可以了:

In [11]:

def f(arrA, arrB):
    return not set(map(tuple, arrA)).isdisjoint(map(tuple, arrB))
In [12]:

f(A, B)
Out[12]:
True
In [13]:

f(A, C)
Out[13]:
False
In [14]:

f(B, C)
Out[14]:
False

要找交集?好吧,set听起来是个不错的选择。
不过numpy.arraylist是不可哈希的?没问题,把它们转换成tuple就行。
这就是思路。

numpy的方法会涉及到一些很难理解的广播操作:

In [34]:

(A[...,np.newaxis]==B[...,np.newaxis].T).all(1)
Out[34]:
array([[False, False],
       [ True, False],
       [False, False]], dtype=bool)
In [36]:

(A[...,np.newaxis]==B[...,np.newaxis].T).all(1).any()
Out[36]:
True

这里有一些时间测试的结果:

In [38]:
#Dan's method
%timeit set_comp(A,B)
10000 loops, best of 3: 34.1 µs per loop
In [39]:
#Avoiding lambda will speed things up
%timeit f(A,B)
10000 loops, best of 3: 23.8 µs per loop
In [40]:
#numpy way probably will be slow, unless the size of the array is very big (my guess)
%timeit (A[...,np.newaxis]==B[...,np.newaxis].T).all(1).any()
10000 loops, best of 3: 49.8 µs per loop

另外,numpy的方法会比较占用内存,因为A[...,np.newaxis]==B[...,np.newaxis].T这一步会创建一个三维数组。

撰写回答