使用NumPy高效获取矩阵中的最小/最大n个值及其索引

18 投票
3 回答
21177 浏览
提问于 2025-04-16 16:31

给定一个NumPy矩阵(二维数组),有什么高效的方法可以返回这个数组中最小或最大的n个值(以及它们的索引)呢?

目前我有:

def n_max(arr, n):
    res = [(0,(0,0))]*n
    for y in xrange(len(arr)):
        for x in xrange(len(arr[y])):
            val = float(arr[y,x])
            el = (val,(y,x))
            i = bisect.bisect(res, el)
            if i > 0:
                res.insert(i, el)
                del res[0]
    return res

这个方法的运行时间是我用pyopencv生成我想要处理的数组的图像模板匹配算法的三倍,我觉得这样太浪费时间了。

3 个回答

0

我刚遇到了一模一样的问题,并且解决了它。
这里是我的解决方案,主要是用到了np.argpartition:

  • 可以应用于任意的轴。
  • 当K远小于数组在这个轴上的大小时,速度很快,复杂度是o(N)。
  • 返回排序后的结果和原始矩阵中对应的索引。
def get_sorted_smallest_K(array, K, axis=-1):
    # Find the least K values of array along the given axis. 
    # Only efficient when K << array.shape[axis].
    # Return:
    #   top_sorted_scores: np.array. The least K values.
    #   top_sorted_indexs: np.array. The least K indexs of original input array.
    
    partition_index = np.take(np.argpartition(array, K, axis), range(0, K), axis)
    top_scores = np.take_along_axis(array, partition_index, axis)
    sorted_index = np.argsort(top_scores, axis=axis)
    top_sorted_scores = np.take_along_axis(top_scores, sorted_index, axis)
    top_sorted_indexs = np.take_along_axis(partition_index, sorted_index, axis)
    return top_sorted_scores, top_sorted_indexs
9

因为NumPy里没有堆这种结构,所以你可以考虑把整个数组排序,然后取最后的 n 个元素。

def n_max(arr, n):
    indices = arr.ravel().argsort()[-n:]
    indices = (numpy.unravel_index(i, arr.shape) for i in indices)
    return [(arr[i], i) for i in indices]

(这样做可能会返回一个和你实现的顺序相反的列表——我没有确认过。)

如果你使用的是更新版本的NumPy,还有一个更有效的解决方案,可以参考 这个回答

26

自从之前的回答以来,NumPy 增加了 numpy.partitionnumpy.argpartition 这两个函数,用于部分排序。这样你可以在 O(arr.size) 的时间内完成,或者如果你需要元素按顺序排列,则是 O(arr.size+n*log(n))

numpy.partition(arr, n) 会返回一个和 arr 大小相同的数组,其中第 n 个元素就像如果这个数组被排序后那样。所有比这个元素小的都会在它前面,所有比它大的都会在后面。

numpy.argpartitionnumpy.partition 的关系,就像 numpy.argsortnumpy.sort 的关系。

下面是如何使用这些函数来找到一个二维数组 arr 中最小的 n 个元素的索引:

flat_indices = numpy.argpartition(arr.ravel(), n-1)[:n]
row_indices, col_indices = numpy.unravel_index(flat_indices, arr.shape)

如果你需要按顺序得到这些索引,比如说 row_indices[0] 是最小元素所在的行,而不仅仅是 n 个最小元素中的一个:

min_elements = arr[row_indices, col_indices]
min_elements_order = numpy.argsort(min_elements)
row_indices, col_indices = row_indices[min_elements_order], col_indices[min_elements_order]

一维的情况要简单得多:

# Unordered:
indices = numpy.argpartition(arr, n-1)[:n]

# Extra code if you need the indices in order:
min_elements = arr[indices]
min_elements_order = numpy.argsort(min_elements)
ordered_indices = indices[min_elements_order]

撰写回答