根据一列中的公共值从两个或更多二维numpy数组创建交集
我有三个numpy的记录数组,它们的结构如下。第一列是一些位置(整数),第二列是分数(浮点数)。
输入:
a = [[1, 5.41],
[2, 5.42],
[3, 12.32],
dtype=[('position', '<i4'), ('score', '<f4')])
]
b = [[3, 8.41],
[6, 7.42],
[4, 6.32],
dtype=[('position', '<i4'), ('score', '<f4')])
]
c = [[3, 7.41],
[7, 6.42],
[1, 5.32],
dtype=[('position', '<i4'), ('score', '<f4')])
]
这三个数组的元素数量是一样的。
我想找一个高效的方法,把这三个二维数组根据位置列合并成一个数组。
上面例子的输出数组应该是这样的:
输出:
output = [[3, 12.32, 8.41, 7.41],
dtype=[('position', '<i4'), ('score1', '<f4'),('score2', '<f4'),('score3', '<f4')])]
输出数组中只有位置为3的那一行,因为这个位置在所有三个输入数组中都有出现。
更新:我简单的做法会是以下几个步骤:
- 创建一个包含我三个输入数组第一列的向量。
- 使用intersect1D来找出这三个向量的交集。
- 以某种方式获取这三个输入数组中对应的索引。
- 从这三个输入数组中筛选出行,创建一个新的数组。
更新2:每个位置值可能出现在一个、两个或三个输入数组中。在我的输出数组中,我只想包含那些在所有三个输入数组中都出现的行。
1 个回答
3
这里有一种方法,我觉得应该比较快。首先,你需要做的就是统计每个位置出现的次数。这个函数可以帮你完成这个任务:
def count_positions(positions):
positions = np.sort(positions)
diff = np.ones(len(positions), 'bool')
diff[:-1] = positions[1:] != positions[:-1]
count = diff.nonzero()[0]
count[1:] = count[1:] - count[:-1]
count[0] += 1
uniqPositions = positions[diff]
return uniqPositions, count
接下来,使用上面的函数,你只需要找出出现3次的位置:
positions = np.concatenate((a['position'], b['position'], c['position']))
uinqPos, count = count_positions(positions)
uinqPos = uinqPos[count == 3]
我们将使用搜索排序的方法,所以我们需要对a、b和c进行排序:
a.sort(order='position')
b.sort(order='position')
c.sort(order='position')
现在我们可以使用搜索排序来找到每个数组中我们想要的uniqPos的位置:
new_array = np.empty((len(uinqPos), 4))
new_array[:, 0] = uinqPos
index = a['position'].searchsorted(uinqPos)
new_array[:, 1] = a['score'][index]
index = b['position'].searchsorted(uinqPos)
new_array[:, 2] = b['score'][index]
index = c['position'].searchsorted(uinqPos)
new_array[:, 3] = c['score'][index]
可能还有更优雅的解决方案,比如使用字典,但我首先想到的是这个方法,所以就留给其他人去探索吧。