获取ndarray中N个最高值的索引
假设我们有一个大小为100x100x100的直方图,我想找到两个最高的值a和b,以及它们的位置(a1, a2, a3)和(b1, b2, b3),就像这样:
hist[a1][a2][a3] = a
hist[b1][b2][b3] = b
我们可以很容易地用hist.max()来找到最高的值,但我该如何在一个ndarray中找到前X个最高的值呢?
我知道通常使用np.argmax来获取值的位置,但在这种情况下:
hist.argmax().shape = () # single value
for i in range(3):
hist.argmax(i).shape = (100, 100)
我该如何得到一个形状为(3)的元组,每个维度对应一个值呢?
3 个回答
0
我想你可以这样做:
(伪代码)
#work on a copy
working_hist = copy(hist)
greatest = []
min_value = hist.argmin().shape
#while searching for the N greatest values, do N times
for i in range(N):
#get the current max value
max_value = hist.argmax().shape
#save it
greatest.append(max_value)
#and then replace it by the minimum value
hist(max_value.shape)= min_value
我已经很多年没用过numpy了,所以不太确定具体的语法。这个代码只是为了给你一个类似伪代码的答案。
如果你还保留了提取值的位置,就可以避免在一个副本上操作,最后用提取的信息来恢复原来的矩阵。
3
你可以使用 where 函数:
a=np.random.random((100,100,100))
np.where(a==a.max())
(array([46]), array([62]), array([61]))
这样就能把结果放到一个数组里:
np.hstack(np.where(a==a.max()))
array([46, 62, 61])
而且,正如提问者所要求的,我们可以得到一个元组:
tuple(np.hstack(np.where(a==a.max())))
(46, 62, 61)
编辑:
如果你想找到最大的 N
个数的索引,可以使用 heapq
模块里的 nlargest
函数:
N=3
np.where(a>=heapq.nlargest(3,a.flatten())[-1])
(array([46, 62, 61]), array([95, 85, 97]), array([70, 35, 2]))
17
你可以先对数组进行扁平化处理,使用 numpy.argpartition
来获取前 k
个元素的索引。然后,你可以利用 numpy.unravel_index
将这些一维的索引转换成数组的原始形状。
>>> arr = np.arange(100*100*100).reshape(100, 100, 100)
>>> np.random.shuffle(arr)
>>> indices = np.argpartition(arr.flatten(), -2)[-2:]
>>> np.vstack(np.unravel_index(indices, arr.shape)).T
array([[97, 99, 98],
[97, 99, 99]])
)
>>> arr[97][99][98]
999998
>>> arr[97][99][99]
999999