序列中的n个最大元素(需保留重复项)
我需要在一个包含元组的列表中找到最大的n个元素。这里有一个找出前三个元素的例子。
# I have a list of tuples of the form (category-1, category-2, value)
# For each category-1, ***values are already sorted descending by default***
# The list can potentially be approximately a million elements long.
lot = [('a', 'x1', 10), ('a', 'x2', 9), ('a', 'x3', 9),
('a', 'x4', 8), ('a', 'x5', 8), ('a', 'x6', 7),
('b', 'x1', 10), ('b', 'x2', 9), ('b', 'x3', 8),
('b', 'x4', 7), ('b', 'x5', 6), ('b', 'x6', 5)]
# This is what I need.
# A list of tuple with top-3 largest values for each category-1
ans = [('a', 'x1', 10), ('a', 'x2', 9), ('a', 'x3', 9),
('a', 'x4', 8), ('a', 'x5', 8),
('b', 'x1', 10), ('b', 'x2', 9), ('b', 'x3', 8)]
我试过用 heapq.nlargest
这个方法。不过它只会返回前三个最大的元素,而且不会返回重复的元素。例如,
heapq.nlargest(3, [10, 10, 10, 9, 8, 8, 7, 6])
# returns
[10, 10, 10]
# I need
[10, 10, 10, 9, 8, 8]
我只能想到一种暴力破解的方法。这是我现在的做法,它是有效的。
res, prev_t, count = [lot[0]], lot[0], 1
for t in lot[1:]:
if t[0] == prev_t[0]:
count = count + 1 if t[2] != prev_t[2] else count
if count <= 3:
res.append(t)
else:
count = 1
res.append(t)
prev_t = t
print res
有没有其他的想法可以实现这个功能呢?
补充说明:对于一个包含100万个元素的列表,使用 timeit
测试的结果显示,mhyfritz的解决方案 运行时间是暴力破解方法的三分之一。我不想让问题变得太长,所以在我的回答中添加了更多细节。
6 个回答
1
一些额外的细节……我对比了两种方法的运行时间,一种是mhyfritz的优秀方案,它使用了itertools
,另一种是我自己的代码(暴力破解法)。
下面是timeit
的测试结果,测试的条件是n = 10
,并且列表中有100万个元素。
# Here's how I built the sample list of 1 million entries.
lot = []
for i in range(1001):
for j in reversed(range(333)):
for k in range(3):
lot.append((i, 'x', j))
# timeit Results for n = 10
brute_force = 6.55s
itertools = 2.07s
# clearly the itertools solution provided by mhyfritz is much faster.
如果有人感兴趣,这里有他代码的运行过程跟踪。
+ Outer loop - x, g1
| a [('a', 'x1', 10), ('a', 'x2', 9), ('a', 'x3', 9), ('a', 'x4', 8), ('a', 'x5', 8), ('a', 'x6', 7)]
+-- Inner loop - y, g2
|- 10 [('a', 'x1', 10)]
|- 9 [('a', 'x2', 9), ('a', 'x3', 9)]
|- 8 [('a', 'x4', 8), ('a', 'x5', 8)]
+ Outer loop - x, g1
| b [('b', 'x1', 10), ('b', 'x2', 9), ('b', 'x3', 8), ('b', 'x4', 7), ('b', 'x5', 6), ('b', 'x6', 5)]
+-- Inner loop - y, g2
|- 10 [('b', 'x1', 10)]
|- 9 [('b', 'x2', 9)]
|- 8 [('b', 'x3', 8)]
2
如果你已经把输入的数据按那种方式排好了,那么你的解决方案可能会比基于heapq的方案要好一些。
你的算法复杂度是O(n),而基于heapq的方案在概念上是O(n * log(3)),而且它可能需要对数据进行更多次的处理才能把它整理好。
7
从你的代码片段来看,lot
是根据 category-1 来分组的。那么下面的代码应该可以正常工作:
from itertools import groupby, islice
from operator import itemgetter
ans = []
for x, g1 in groupby(lot, itemgetter(0)):
for y, g2 in islice(groupby(g1, itemgetter(2)), 0, 3):
ans.extend(list(g2))
print ans
# [('a', 'x1', 10), ('a', 'x2', 9), ('a', 'x3', 9), ('a', 'x4', 8), ('a', 'x5', 8),
# ('b', 'x1', 10), ('b', 'x2', 9), ('b', 'x3', 8)]