Python中快速获取列表中N个最小或最大元素的方法

6 投票
2 回答
6254 浏览
提问于 2025-04-15 19:26

我现在有一个很长的列表,我用一个叫做lambda函数f的东西来对它进行排序。然后,我从前五个元素中随机选择一个。大概是这样的:

f = lambda x: some_function_of(x, local_variable)
my_list.sort(key=f)
foo = choice(my_list[:4])

根据性能分析工具的反馈,这个过程在我的程序中是个瓶颈。我该怎么加快速度呢?有没有什么快速的方法可以直接获取我想要的元素(理论上不需要对整个列表进行排序)。谢谢。

2 个回答

1

其实在平均情况下,这个操作可以在线性时间内完成,也就是O(N)。

你需要一个叫做“划分算法”的东西:

def partition(seq, pred, start=0, end=-1):
    if end == -1: end = len(seq)
    while True:
        while True:
            if start == end: return start
            if not pred(seq[start]): break
            start += 1
        while True:
            if pred(seq[end-1]): break
            end -= 1
            if start == end: return start
        seq[start], seq[end-1] = seq[end-1], seq[start]
        start += 1
        end -= 1

这个算法可以被一个叫做“nth_element”的算法使用:

def nth_element(seq_in, n, key=lambda x:x):
    start, end = 0, len(seq_in)
    seq = [(x, key(x)) for x in seq_in]

    def partition_pred(x): return x[1] < seq[end-1][1]

    while start != end:
        pivot = (end + start) // 2
        seq[pivot], seq[end - 1] = seq[end - 1], seq[pivot]
        pivot = partition(seq, partition_pred, start, end)
        seq[pivot], seq[end - 1] = seq[end - 1], seq[pivot]
        if pivot == n: break
        if pivot < n: start = pivot + 1
        else: end = pivot

    seq_in[:] = (x for x, k in seq)

有了这些,你只需要把你第二行的(排序)代码替换成:

nth_element(my_list, 4, key=f)
11

可以使用 heapq.nlargest 或者 heapq.nsmallest 这个工具。

举个例子:

import heapq

elements = heapq.nsmallest(4, my_list, key=f)
foo = choice(elements)

这个方法的运行时间是 O(N+KlogN)(这里的 K 是你想要返回的元素数量,N 是列表的大小),当 K 相对于 N 较小时,这个速度比普通排序的 O(NlogN) 要快。

撰写回答