获取ndarray中N个最高值的索引

8 投票
3 回答
5088 浏览
提问于 2025-04-30 02:48

假设我们有一个大小为100x100x100的直方图,我想找到两个最高的值a和b,以及它们的位置(a1, a2, a3)和(b1, b2, b3),就像这样:

hist[a1][a2][a3] = a
hist[b1][b2][b3] = b

我们可以很容易地用hist.max()来找到最高的值,但我该如何在一个ndarray中找到前X个最高的值呢?

我知道通常使用np.argmax来获取值的位置,但在这种情况下:

hist.argmax().shape = ()  # single value
for i in range(3):
    hist.argmax(i).shape = (100, 100)

我该如何得到一个形状为(3)的元组,每个维度对应一个值呢?

暂无标签

3 个回答

0

我想你可以这样做:

(伪代码)

#work on a copy
working_hist = copy(hist)
greatest = []

min_value = hist.argmin().shape

#while searching for the N greatest values, do N times
for i in range(N):
    #get the current max value
    max_value = hist.argmax().shape
    #save it
    greatest.append(max_value)
    #and then replace it by the minimum value
    hist(max_value.shape)= min_value

我已经很多年没用过numpy了,所以不太确定具体的语法。这个代码只是为了给你一个类似伪代码的答案。

如果你还保留了提取值的位置,就可以避免在一个副本上操作,最后用提取的信息来恢复原来的矩阵。

3

你可以使用 where 函数:

a=np.random.random((100,100,100))
np.where(a==a.max())
(array([46]), array([62]), array([61]))

这样就能把结果放到一个数组里:

np.hstack(np.where(a==a.max()))
array([46, 62, 61])

而且,正如提问者所要求的,我们可以得到一个元组:

tuple(np.hstack(np.where(a==a.max())))
(46, 62, 61)

编辑:

如果你想找到最大的 N 个数的索引,可以使用 heapq 模块里的 nlargest 函数:

N=3
np.where(a>=heapq.nlargest(3,a.flatten())[-1])
(array([46, 62, 61]), array([95, 85, 97]), array([70, 35,  2]))
17

你可以先对数组进行扁平化处理,使用 numpy.argpartition 来获取前 k 个元素的索引。然后,你可以利用 numpy.unravel_index 将这些一维的索引转换成数组的原始形状。

>>> arr = np.arange(100*100*100).reshape(100, 100, 100)
>>> np.random.shuffle(arr)
>>> indices =  np.argpartition(arr.flatten(), -2)[-2:]
>>> np.vstack(np.unravel_index(indices, arr.shape)).T
array([[97, 99, 98],
       [97, 99, 99]])
)
>>> arr[97][99][98]
999998
>>> arr[97][99][99]
999999

撰写回答