对Numpy数组求和并去除重复元素

2 投票
2 回答
1561 浏览
提问于 2025-04-21 02:22

我有四个长度相同的一维Numpy数组。

前面三个数组用作ID,唯一标识第四个数组。

这些ID数组中有重复的组合,我需要对第四个数组进行求和,并且要把所有四个数组中重复的元素去掉。

x = np.array([1, 2, 4, 1])
y = np.array([1, 1, 4, 1])
z = np.array([1, 2, 2, 1])
data = np.array([4, 7, 3, 2])

在这种情况下,我需要:

x = [1, 2, 4]
y = [1, 1, 4]
z = [1, 2, 2]
data = [6, 7, 3]

这些数组比较长,所以用循环的方法真的不太行。我相信有一种相对简单的方法可以做到这一点,但我就是想不出来。

2 个回答

2

你可以像reptilicus建议的那样,使用unique和sum来完成以下操作。

from itertools import izip
import numpy as np

x = np.array([1, 2, 4, 1])
y = np.array([1, 1, 4, 1])
z = np.array([1, 2, 2, 1])
data = np.array([4, 7, 3, 2])

# N = len(x)
# ids = x + y*N + z*(N**2)
ids = np.array([hash((a, b, c)) for a, b, c in izip(x, y, z)]) # creates flat ids

_, idx, idx_rep = np.unique(ids, return_index=True, return_inverse=True)

x_out = x[idx]
y_out = y[idx]
z_out = z[idx]
# data_out = np.array([np.sum(data[idx_rep == i]) for i in idx])
data_out = np.bincount(idx_rep, weights=data)

print x_out
print y_out
print z_out
print data_out
4

首先,我们可以把ID向量堆叠成一个矩阵,每个ID作为三列的行:

XYZ = np.vstack((x,y,z)).T

接下来,我们需要找到重复行的索引。不过,np.unique这个函数不能直接处理行,所以我们需要做一些小技巧

order = np.lexsort(XYZ.T)
diff = np.diff(XYZ[order], axis=0)
uniq_mask = np.append(True, (diff != 0).any(axis=1))

这一部分是从np.unique的源代码借来的,它可以找到唯一的索引以及“反向索引”映射:

uniq_inds = order[uniq_mask]
inv_idx = np.zeros_like(order)
inv_idx[order] = np.cumsum(uniq_mask) - 1

最后,对唯一的索引进行求和:

data = np.bincount(inv_idx, weights=data)
x,y,z = XYZ[uniq_inds].T

撰写回答