根据一列中的公共值从两个或更多二维numpy数组创建交集

4 投票
1 回答
1486 浏览
提问于 2025-04-17 10:54

我有三个numpy的记录数组,它们的结构如下。第一列是一些位置(整数),第二列是分数(浮点数)。

输入:

a = [[1, 5.41],
     [2, 5.42],
     [3, 12.32],
     dtype=[('position', '<i4'), ('score', '<f4')])
     ]

b = [[3, 8.41],
     [6, 7.42],
     [4, 6.32],
     dtype=[('position', '<i4'), ('score', '<f4')])
     ]

c = [[3, 7.41],
     [7, 6.42],
     [1, 5.32],
     dtype=[('position', '<i4'), ('score', '<f4')])
     ]

这三个数组的元素数量是一样的。
我想找一个高效的方法,把这三个二维数组根据位置列合并成一个数组。

上面例子的输出数组应该是这样的:

输出:

output = [[3, 12.32, 8.41, 7.41],
          dtype=[('position', '<i4'), ('score1', '<f4'),('score2', '<f4'),('score3', '<f4')])]

输出数组中只有位置为3的那一行,因为这个位置在所有三个输入数组中都有出现。

更新:我简单的做法会是以下几个步骤:

  1. 创建一个包含我三个输入数组第一列的向量。
  2. 使用intersect1D来找出这三个向量的交集。
  3. 以某种方式获取这三个输入数组中对应的索引。
  4. 从这三个输入数组中筛选出行,创建一个新的数组。

更新2:每个位置值可能出现在一个、两个或三个输入数组中。在我的输出数组中,我只想包含那些在所有三个输入数组中都出现的行。

1 个回答

3

这里有一种方法,我觉得应该比较快。首先,你需要做的就是统计每个位置出现的次数。这个函数可以帮你完成这个任务:

def count_positions(positions):
    positions = np.sort(positions)
    diff = np.ones(len(positions), 'bool')
    diff[:-1] = positions[1:] != positions[:-1]
    count = diff.nonzero()[0]
    count[1:] = count[1:] - count[:-1]
    count[0] += 1
    uniqPositions = positions[diff]
    return uniqPositions, count

接下来,使用上面的函数,你只需要找出出现3次的位置:

positions = np.concatenate((a['position'], b['position'], c['position']))
uinqPos, count = count_positions(positions)
uinqPos = uinqPos[count == 3]

我们将使用搜索排序的方法,所以我们需要对a、b和c进行排序:

a.sort(order='position')
b.sort(order='position')
c.sort(order='position')

现在我们可以使用搜索排序来找到每个数组中我们想要的uniqPos的位置:

new_array = np.empty((len(uinqPos), 4))
new_array[:, 0] = uinqPos
index = a['position'].searchsorted(uinqPos)
new_array[:, 1] = a['score'][index]
index = b['position'].searchsorted(uinqPos)
new_array[:, 2] = b['score'][index]
index = c['position'].searchsorted(uinqPos)
new_array[:, 3] = c['score'][index]

可能还有更优雅的解决方案,比如使用字典,但我首先想到的是这个方法,所以就留给其他人去探索吧。

撰写回答