Python中快速获取列表中N个最小或最大元素的方法
我现在有一个很长的列表,我用一个叫做lambda函数f的东西来对它进行排序。然后,我从前五个元素中随机选择一个。大概是这样的:
f = lambda x: some_function_of(x, local_variable)
my_list.sort(key=f)
foo = choice(my_list[:4])
根据性能分析工具的反馈,这个过程在我的程序中是个瓶颈。我该怎么加快速度呢?有没有什么快速的方法可以直接获取我想要的元素(理论上不需要对整个列表进行排序)。谢谢。
2 个回答
1
其实在平均情况下,这个操作可以在线性时间内完成,也就是O(N)。
你需要一个叫做“划分算法”的东西:
def partition(seq, pred, start=0, end=-1):
if end == -1: end = len(seq)
while True:
while True:
if start == end: return start
if not pred(seq[start]): break
start += 1
while True:
if pred(seq[end-1]): break
end -= 1
if start == end: return start
seq[start], seq[end-1] = seq[end-1], seq[start]
start += 1
end -= 1
这个算法可以被一个叫做“nth_element”的算法使用:
def nth_element(seq_in, n, key=lambda x:x):
start, end = 0, len(seq_in)
seq = [(x, key(x)) for x in seq_in]
def partition_pred(x): return x[1] < seq[end-1][1]
while start != end:
pivot = (end + start) // 2
seq[pivot], seq[end - 1] = seq[end - 1], seq[pivot]
pivot = partition(seq, partition_pred, start, end)
seq[pivot], seq[end - 1] = seq[end - 1], seq[pivot]
if pivot == n: break
if pivot < n: start = pivot + 1
else: end = pivot
seq_in[:] = (x for x, k in seq)
有了这些,你只需要把你第二行的(排序)代码替换成:
nth_element(my_list, 4, key=f)
11
可以使用 heapq.nlargest
或者 heapq.nsmallest
这个工具。
举个例子:
import heapq
elements = heapq.nsmallest(4, my_list, key=f)
foo = choice(elements)
这个方法的运行时间是 O(N+KlogN)(这里的 K 是你想要返回的元素数量,N 是列表的大小),当 K 相对于 N 较小时,这个速度比普通排序的 O(NlogN) 要快。