比较不同大小的numpy数组

2024-05-29 03:15:52 发布

您现在位置:Python中文网/ 问答频道 /正文

我有两个xy坐标点阵列:

basic_pts = np.array([[0, 0], [1, 0], [2, 0], [0, 1], [1, 1], [0, 2]])
new_pts = np.array([[2, 2], [2, 1], [0.5, 0.5], [1.5, 0.5]])

因此,我只想从数组new_pts中得到那些满足条件的点,即basic_pts中没有具有更大x和y值的点。所以结果是

^{pr2}$

我有一个可行的解决方案,但由于列表理解的工作,它不适合更大的数据量。在

x_cond = ([basic_pts[:, 0] > x for x in new_pts[:, 1]])
y_cond = ([basic_pts[:, 1] > y for y in new_pts[:, 1]])
xy_cond_ = np.logical_and(x_cond, y_cond)
xy_cond = np.swapaxes(xy_cond_, 0, 1)
mask = np.invert(np.logical_or.reduce(xy_cond))
res_pts = new_pts[mask]

有没有更好的方法来解决这个问题,只使用numpy而不理解列表?在


Tags: in列表newforbasicnpmask数组
2条回答

您可以使用^{}-

# Get xy_cond equivalent after extending dimensions of basic_pts to a 2D array
# version by "pushing" separately col-0 and col-1 to axis=0 and creating a 
# singleton dimension at axis=1. 
# Compare these two extended versions with col-1 of new_pts. 
xyc = (basic_pts[:,0,None] > new_pts[:,1]) & (basic_pts[:,1,None] > new_pts[:,1])

# Create mask equivalent and index into new_pts to get selective rows from it
mask = ~(xyc).any(0)
res_pts_out = new_pts[mask]

正如val所指出的,创建中间len(basic_pts)×len(new_pts)数组的解决方案可能过于占用内存。另一方面,在一个循环中测试new_pts中的每个点的解决方案可能过于耗时。我们可以通过选择一个批量大小k并使用Divakar的解决方案对new_pts进行批量测试来弥补这一差距:

basic_pts = np.array([[0, 0], [1, 0], [2, 0], [0, 1], [1, 1], [0, 2]])
new_pts = np.array([[2, 2], [2, 1], [0.5, 0.5], [1.5, 0.5]])
k = 2
subresults = []
for i in range(0, len(new_pts), k):
    j = min(i + k, len(new_pts))
    # Process new_pts[i:j] using Divakar's solution
    xyc = np.logical_and(
        basic_pts[:, np.newaxis, 0] > new_pts[np.newaxis, i:j, 0],
        basic_pts[:, np.newaxis, 1] > new_pts[np.newaxis, i:j, 1])
    mask = ~(xyc).any(axis=0)
    # mask indicates which points among new_pts[i:j] to use
    subresults.append(new_pts[i:j][mask])
# Concatenate subresult lists
res = np.concatenate(subresults)
print(res)
# Prints:
array([[ 2. ,  2. ],
       [ 2. ,  1. ],
       [ 1.5,  0.5]])

相关问题 更多 >

    热门问题