Python - 2D Numpy数组的交集
我正在拼命寻找一种高效的方法来检查两个二维的numpy数组是否有交集。
我有两个数组,每个数组里面有任意数量的二维数组,像这样:
A=np.array([[2,3,4],[5,6,7],[8,9,10]])
B=np.array([[5,6,7],[1,3,4]])
C=np.array([[1,2,3],[6,6,7],[10,8,9]])
我只需要判断如果有至少一个向量和另一个数组中的向量相交,就返回True,否则返回False。所以结果应该是这样的:
f(A,B) -> True
f(A,C) -> False
我对Python还不是很熟悉,最开始我用Python的列表写了我的程序,虽然能用,但效率非常低。程序运行起来要花几天时间,所以我现在在尝试用numpy.array
来解决这个问题,但这些数组真的不太好处理。
这里是我程序的一些背景和用Python列表的解决方案:
我做的事情有点像在三维空间中进行自避随机游走。http://en.wikipedia.org/wiki/Self-avoiding_walk。但我不是随便走动,希望能达到一个理想的长度(比如我想要由1000个珠子组成的链),而是这样做:
我创建一个“平坦”的链,长度为N:
X=[]
for i in range(0,N+1):
X.append((i,0,0))
然后我把这个平坦的链折叠起来:
- 随机选择一个元素(“枢轴元素”)
- 随机选择一个方向(要么是枢轴元素左边的所有元素,要么是右边的)
- 随机选择一个空间中的旋转方式(3个轴 * 3种可能的旋转:90°、180°、270°)
- 用选择的旋转方式旋转所选方向的所有元素
- 检查新选择的方向的元素是否与另一方向相交
- 如果没有交集 -> 接受新的配置,否则 -> 保留旧的链。
步骤1到6需要重复很多次(例如,对于长度为1000的链,大约需要5000次),所以这些步骤必须高效完成。我基于列表的解决方案如下:
def PivotFold(chain):
randPiv=random.randint(1,N) #Chooses a random pivotelement, N is the Chainlength
Pivot=chain[randPiv] #get that pivotelement
C=[] #C is going to be a shifted copy of the chain
intersect=False
for j in range (0,N+1): # Here i shift the hole chain to get the pivotelement to the origin, so i can use simple rotations around the origin
C.append((chain[j][0]-Pivot[0],chain[j][1]-Pivot[1],chain[j][2]-Pivot[2]))
rotRand=random.randint(1,18) # rotRand is used to choose a direction and a Rotation (2 possible direction * 9 rotations = 18 possibilitys)
#Rotations around Z-Axis
if rotRand==1:
for j in range (randPiv,N+1):
C[j]=(-C[j][1],C[j][0],C[j][2])
if C[0:randPiv].__contains__(C[j])==True:
intersect=True
break
elif rotRand==2:
for j in range (randPiv,N+1):
C[j]=(C[j][1],-C[j][0],C[j][2])
if C[0:randPiv].__contains__(C[j])==True:
intersect=True
break
...etc
if intersect==False: # return C if there was no intersection in C
Shizz=C
else:
Shizz=chain
return Shizz
函数PivotFold(chain)会在最初的平坦链X上使用很多次。这个函数写得比较简单,所以也许你有一些改进的建议^^ 我觉得使用numpy数组会很好,因为我可以高效地移动和旋转整个链,而不需要遍历所有元素……
6 个回答
我觉得你想要的是,如果两个数组有相同的子数组,就返回真!你可以使用这个:
def(A,B):
for i in A:
for j in B:
if i==j
return True
return False
这个方法应该会快很多,因为它的复杂度不是O(n^2),像用for循环那种做法。但它也不是完全符合numpy的使用方式。我不太确定在这里怎么更好地利用numpy。
def set_comp(a, b):
sets_a = set(map(lambda x: frozenset(tuple(x)), a))
sets_b = set(map(lambda x: frozenset(tuple(x)), b))
return not sets_a.isdisjoint(sets_b)
你也可以通过一些 np.tile
和 np.swapaxes
的操作来完成这个任务!
def intersect2d(X, Y):
"""
Function to find intersection of two 2D arrays.
Returns index of rows in X that are common to Y.
"""
X = np.tile(X[:,:,None], (1, 1, Y.shape[0]) )
Y = np.swapaxes(Y[:,:,None], 0, 2)
Y = np.tile(Y, (X.shape[0], 1, 1))
eq = np.all(np.equal(X, Y), axis = 1)
eq = np.any(eq, axis = 1)
return np.nonzero(eq)[0]
更具体地说,你只需要检查一下返回的数组是否是空的就可以了。
根据这里提到的相同思路,你可以这样做:
def make_1d_view(a):
a = np.ascontiguousarray(a)
dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
return a.view(dt).ravel()
def f(a, b):
return len(np.intersect1d(make_1d_view(A), make_1d_view(b))) != 0
>>> f(A, B)
True
>>> f(A, C)
False
不过,这种方法对浮点数类型不太适用(它不会把 +0.0 和 -0.0 视为相同的值),而且 np.intersect1d
使用了排序,所以它的性能是线性对数级别的,而不是线性级别的。如果你想提高性能,可以尝试在你的代码中复制 np.intersect1d
的源代码,然后不检查返回数组的长度,而是对布尔索引数组使用 np.any
。
这样做就可以了:
In [11]:
def f(arrA, arrB):
return not set(map(tuple, arrA)).isdisjoint(map(tuple, arrB))
In [12]:
f(A, B)
Out[12]:
True
In [13]:
f(A, C)
Out[13]:
False
In [14]:
f(B, C)
Out[14]:
False
要找交集?好吧,set
听起来是个不错的选择。
不过numpy.array
或list
是不可哈希的?没问题,把它们转换成tuple
就行。
这就是思路。
用numpy
的方法会涉及到一些很难理解的广播操作:
In [34]:
(A[...,np.newaxis]==B[...,np.newaxis].T).all(1)
Out[34]:
array([[False, False],
[ True, False],
[False, False]], dtype=bool)
In [36]:
(A[...,np.newaxis]==B[...,np.newaxis].T).all(1).any()
Out[36]:
True
这里有一些时间测试的结果:
In [38]:
#Dan's method
%timeit set_comp(A,B)
10000 loops, best of 3: 34.1 µs per loop
In [39]:
#Avoiding lambda will speed things up
%timeit f(A,B)
10000 loops, best of 3: 23.8 µs per loop
In [40]:
#numpy way probably will be slow, unless the size of the array is very big (my guess)
%timeit (A[...,np.newaxis]==B[...,np.newaxis].T).all(1).any()
10000 loops, best of 3: 49.8 µs per loop
另外,numpy
的方法会比较占用内存,因为A[...,np.newaxis]==B[...,np.newaxis].T
这一步会创建一个三维数组。