序列中的n个最大元素(需保留重复项)

8 投票
6 回答
1535 浏览
提问于 2025-04-16 21:22

我需要在一个包含元组的列表中找到最大的n个元素。这里有一个找出前三个元素的例子。

# I have a list of tuples of the form (category-1, category-2, value)
# For each category-1, ***values are already sorted descending by default***
# The list can potentially be approximately a million elements long.
lot = [('a', 'x1', 10), ('a', 'x2', 9), ('a', 'x3', 9), 
       ('a', 'x4',  8), ('a', 'x5', 8), ('a', 'x6', 7),
       ('b', 'x1', 10), ('b', 'x2', 9), ('b', 'x3', 8), 
       ('b', 'x4',  7), ('b', 'x5', 6), ('b', 'x6', 5)]

# This is what I need. 
# A list of tuple with top-3 largest values for each category-1
ans = [('a', 'x1', 10), ('a', 'x2', 9), ('a', 'x3', 9), 
       ('a', 'x4', 8), ('a', 'x5', 8),
       ('b', 'x1', 10), ('b', 'x2', 9), ('b', 'x3', 8)]

我试过用 heapq.nlargest 这个方法。不过它只会返回前三个最大的元素,而且不会返回重复的元素。例如,

heapq.nlargest(3, [10, 10, 10, 9, 8, 8, 7, 6])
# returns
[10, 10, 10]
# I need
[10, 10, 10, 9, 8, 8]

我只能想到一种暴力破解的方法。这是我现在的做法,它是有效的。

res, prev_t, count = [lot[0]], lot[0], 1
for t in lot[1:]:
    if t[0] == prev_t[0]:
        count = count + 1 if t[2] != prev_t[2] else count
        if count <= 3:
            res.append(t)   
    else:
        count = 1
        res.append(t)
    prev_t = t

print res

有没有其他的想法可以实现这个功能呢?

补充说明:对于一个包含100万个元素的列表,使用 timeit 测试的结果显示,mhyfritz的解决方案 运行时间是暴力破解方法的三分之一。我不想让问题变得太长,所以在我的回答中添加了更多细节。

6 个回答

1

一些额外的细节……我对比了两种方法的运行时间,一种是mhyfritz的优秀方案,它使用了itertools,另一种是我自己的代码(暴力破解法)。

下面是timeit的测试结果,测试的条件是n = 10,并且列表中有100万个元素。

# Here's how I built the sample list of 1 million entries.
lot = []
for i in range(1001):
    for j in reversed(range(333)):
        for k in range(3):
            lot.append((i, 'x', j))

# timeit Results for n = 10
brute_force = 6.55s
itertools = 2.07s
# clearly the itertools solution provided by mhyfritz is much faster.

如果有人感兴趣,这里有他代码的运行过程跟踪。

+ Outer loop - x, g1
| a [('a', 'x1', 10), ('a', 'x2', 9), ('a', 'x3', 9), ('a', 'x4', 8), ('a', 'x5', 8), ('a', 'x6', 7)]
+-- Inner loop - y, g2
  |- 10 [('a', 'x1', 10)]
  |- 9 [('a', 'x2', 9), ('a', 'x3', 9)]
  |- 8 [('a', 'x4', 8), ('a', 'x5', 8)]
+ Outer loop - x, g1
| b [('b', 'x1', 10), ('b', 'x2', 9), ('b', 'x3', 8), ('b', 'x4', 7), ('b', 'x5', 6), ('b', 'x6', 5)]
+-- Inner loop - y, g2
  |- 10 [('b', 'x1', 10)]
  |- 9 [('b', 'x2', 9)]
  |- 8 [('b', 'x3', 8)]
2

如果你已经把输入的数据按那种方式排好了,那么你的解决方案可能会比基于heapq的方案要好一些。

你的算法复杂度是O(n),而基于heapq的方案在概念上是O(n * log(3)),而且它可能需要对数据进行更多次的处理才能把它整理好。

7

从你的代码片段来看,lot 是根据 category-1 来分组的。那么下面的代码应该可以正常工作:

from itertools import groupby, islice
from operator import itemgetter

ans = []
for x, g1 in groupby(lot, itemgetter(0)):
    for y, g2 in islice(groupby(g1, itemgetter(2)), 0, 3):
        ans.extend(list(g2))

print ans
# [('a', 'x1', 10), ('a', 'x2', 9), ('a', 'x3', 9), ('a', 'x4', 8), ('a', 'x5', 8),
#  ('b', 'x1', 10), ('b', 'x2', 9), ('b', 'x3', 8)]

撰写回答