Python-2D Numpy数组的交集

2024-04-26 03:33:29 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在拼命寻找一种有效的方法来检查两个2D numpy数组是否相交。

所以我有两个任意数量的二维数组,比如:

A=np.array([[2,3,4],[5,6,7],[8,9,10]])
B=np.array([[5,6,7],[1,3,4]])
C=np.array([[1,2,3],[6,6,7],[10,8,9]])

如果至少有一个向量与另一个数组中的另一个向量相交,我所需要的就是一个True,否则就是一个false。所以它应该给出这样的结果:

f(A,B)  -> True
f(A,C)  -> False

我对python有点陌生,一开始我用python列表编写程序,这很有用,但当然效率很低。这个程序需要几天才能完成,所以我现在正在研究一个numpy.array解决方案,但是这些数组确实不太容易处理。

下面是关于我的程序和Python列表解决方案的一些上下文:

我所做的是一种自我回避的三维随机行走。http://en.wikipedia.org/wiki/Self-avoiding_walk。但是,与其做一次随机的散步,希望它能达到一个理想的长度(例如,我希望链子能由1000颗珠子组成),而不要到达一个死胡同,我要做的是:

我创建了一个具有所需长度N的“扁平”链:

X=[]
for i in range(0,N+1):
    X.append((i,0,0))

现在我把这条扁链折起来:

  1. 随机选择其中一个元素(“pivotelement”)
  2. 随机选择一个方向(数据透视表左侧或右侧的所有元素)
  3. 从9个可能的空间旋转中随机选择一个(3个轴*3个可能的90°、180°、270°)
  4. 用选定的旋转旋转选定方向的所有元素
  5. 检查所选方向的新元素是否与其他方向相交
  6. 无交叉点->;接受新配置,否则->;保留旧链。

步骤1.-6。必须进行大量的操作(例如,对于长度为1000的链,~5000次),因此必须有效地执行这些步骤。我基于列表的解决方案如下:

def PivotFold(chain):
randPiv=random.randint(1,N)  #Chooses a random pivotelement, N is the Chainlength
Pivot=chain[randPiv]  #get that pivotelement
C=[]  #C is going to be a shifted copy of the chain
intersect=False
for j in range (0,N+1):   # Here i shift the hole chain to get the pivotelement to the origin, so i can use simple rotations around the origin
    C.append((chain[j][0]-Pivot[0],chain[j][1]-Pivot[1],chain[j][2]-Pivot[2]))
rotRand=random.randint(1,18)  # rotRand is used to choose a direction and a Rotation (2 possible direction * 9 rotations = 18 possibilitys)
#Rotations around Z-Axis
if rotRand==1:
    for j in range (randPiv,N+1):
        C[j]=(-C[j][1],C[j][0],C[j][2])
        if C[0:randPiv].__contains__(C[j])==True:
            intersect=True
            break
elif rotRand==2:
    for j in range (randPiv,N+1):
        C[j]=(C[j][1],-C[j][0],C[j][2])
        if C[0:randPiv].__contains__(C[j])==True:
            intersect=True
            break
...etc
if intersect==False: # return C if there was no intersection in C
    Shizz=C
else:
    Shizz=chain
return Shizz

PivotFold(chain)函数将在最初的平链X上使用很多次。这本书写得很幼稚,所以也许你有一些改进的地方^我认为Numpyarray会很好,因为我可以有效地移动和旋转整个链,而不必在所有元素上循环。。。


Tags: thetointrue元素chainforif
3条回答

这应该做到:

In [11]:

def f(arrA, arrB):
    return not set(map(tuple, arrA)).isdisjoint(map(tuple, arrB))
In [12]:

f(A, B)
Out[12]:
True
In [13]:

f(A, C)
Out[13]:
False
In [14]:

f(B, C)
Out[14]:
False

找到交叉点?好吧,set听起来是个合乎逻辑的选择。 但是numpy.arraylist不可哈希?好,把它们转换成tuple。 就是这个主意。

一种numpy的方式涉及到非常不可读的板卡:

In [34]:

(A[...,np.newaxis]==B[...,np.newaxis].T).all(1)
Out[34]:
array([[False, False],
       [ True, False],
       [False, False]], dtype=bool)
In [36]:

(A[...,np.newaxis]==B[...,np.newaxis].T).all(1).any()
Out[36]:
True

一些时间结果:

In [38]:
#Dan's method
%timeit set_comp(A,B)
10000 loops, best of 3: 34.1 µs per loop
In [39]:
#Avoiding lambda will speed things up
%timeit f(A,B)
10000 loops, best of 3: 23.8 µs per loop
In [40]:
#numpy way probably will be slow, unless the size of the array is very big (my guess)
%timeit (A[...,np.newaxis]==B[...,np.newaxis].T).all(1).any()
10000 loops, best of 3: 49.8 µs per loop

另外,numpy方法将占用内存,因为A[...,np.newaxis]==B[...,np.newaxis].T步骤将创建一个3D数组。

使用here概述的相同想法,您可以执行以下操作:

def make_1d_view(a):
    a = np.ascontiguousarray(a)
    dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
    return a.view(dt).ravel()

def f(a, b):
    return len(np.intersect1d(make_1d_view(A), make_1d_view(b))) != 0

>>> f(A, B)
True
>>> f(A, C)
False

这对浮点类型不起作用(它不会将+0.0和-0.0视为相同的值),并且np.intersect1d使用排序,因此它具有线性而不是线性的性能。通过复制代码中^{}的源代码,并调用布尔索引数组上的np.any,而不是检查返回数组的长度,您可能可以压缩一些性能。

你也可以通过一些np.tilenp.swapaxes业务来完成工作!

def intersect2d(X, Y):
        """
        Function to find intersection of two 2D arrays.
        Returns index of rows in X that are common to Y.
        """
        X = np.tile(X[:,:,None], (1, 1, Y.shape[0]) )
        Y = np.swapaxes(Y[:,:,None], 0, 2)
        Y = np.tile(Y, (X.shape[0], 1, 1))
        eq = np.all(np.equal(X, Y), axis = 1)
        eq = np.any(eq, axis = 1)
        return np.nonzero(eq)[0]

为了更具体地回答这个问题,您只需要检查返回的数组是否为空。

相关问题 更多 >