高效跟踪字典中基于值的前k个键

4 投票
3 回答
2894 浏览
提问于 2025-04-17 19:11

如何高效地跟踪字典中值最大的前k个键,同时字典的键会不断更新呢?

我试过一种简单的方法,就是在每次更新后从字典中创建一个排序列表(就像在获取字典中最大值的键?中描述的那样),但这样做非常耗时,并且不适合大规模使用

现实世界的例子:

比如说,我们需要统计来自无限数据流的单词频率。在任何时刻,程序可能会被要求报告某个单词是否在当前最常见的前k个单词中。我们该如何高效地完成这个任务呢?

collections.Counter太慢了

>>> from itertools import permutations
>>> from collections import Counter
>>> from timeit import timeit
>>> c = Counter()
>>> for x in permutations(xrange(10), 10):
    c[x] += 1


>>> timeit('c.most_common(1)', 'from __main__ import c', number=1)
0.7442058258093311
>>> sum(c.values())
3628800

计算这个值几乎需要一秒钟!

我希望most_common()函数的时间复杂度是O(1)。这应该可以通过使用另一种数据结构来实现,这种结构只存储当前的前k个项目,并跟踪当前的最小值。

3 个回答

1

可以使用 collections.Counter,它已经可以处理这个实际的例子了。你还有其他的使用场景吗?

2

collections.Counter.most_common 这个方法会遍历所有的值,找出第 N 大的值,它是通过把这些值放进一个堆里来实现的(我想,它的时间复杂度是 O(M log N),其中 M 是字典中元素的总数)。

正如评论中 Wei Yen 提到的,heapq 可能也能用:你可以在字典旁边维护一个包含 N 个最大值的 heapq,每当你修改字典时,就检查这个值是否在堆里,或者现在是否应该在堆里。问题是,正如你所说的,接口并没有提供修改已经存在元素的“优先级”的方法(在你的情况下,就是计数的负数,因为这是一个最小堆)。

你可以直接修改相关的元素,然后运行 heapq.heapify 来恢复堆的结构。这个过程需要线性时间来找到相关的元素(N 的大小),除非你做额外的记录来关联元素和位置;这样做可能不太值得。然后还需要另一个线性时间来重新调整堆的结构。如果一个元素之前不在列表中,现在加入了,你需要通过替换最小的元素来将其添加到堆中(这个过程也是线性的,除非有其他结构的帮助)。

不过,heapq 的私有接口中有一个函数 _siftdown,它的注释是:

# 'heap' is a heap at all indices >= startpos, except possibly for pos.  pos
# is the index of a leaf with a possibly out-of-order value.  Restore the
# heap invariant.

这听起来不错!调用 heapq._siftdown(heap, 0, pos_of_relevant_idx) 可以在 O(log N) 的时间内修复堆。当然,你得先找到你要增加的索引的位置,这个过程是线性的。你可以维护一个元素到索引的字典来避免这个问题(同时保持指向最小元素位置的指针),但这样你要么得复制 _siftdown 的源代码并修改它以在交换时更新字典,要么在之后做一次线性时间的遍历来重建字典(但你本来就是想避免线性遍历……)。

小心的话,这样做应该能达到 O(log N) 的时间复杂度。不过,实际上有一种叫做 Fibonacci 堆 的数据结构,它支持你需要的所有操作,并且在(摊销)常数时间内完成。不幸的是,这种情况下大 O 复杂度并不能完全说明问题;Fibonacci 堆的复杂性意味着在实际应用中,除了非常大的堆之外,它们的速度并不比二叉堆快。此外(也许“因此”),我在快速搜索中没有找到标准的 Python 实现,虽然 Boost C++ 库中确实包含了一个。

我建议你先尝试使用 heapq,对你要修改的元素进行线性搜索,然后调用 _siftdown;这个过程是 O(N) 的,相比之下,Counter 的方法是 O(M log N)。如果这样做还是太慢,你可以维护一个额外的索引字典,并自己实现一个更新字典的 _siftdown,这样应该能达到 O(log N) 的时间复杂度。如果这仍然太慢(我对此表示怀疑),你可以寻找一个 Python 的包装器来使用 Boost 的 Fibonacci 堆(或其他实现),但我真的怀疑这样做是否值得。

0

我们可以实现一个类,用来跟踪前k个值,因为我觉得标准库里没有这个功能。这个类会和主要的字典对象(可能是一个Counter)保持同步更新。你也可以把它作为主要字典对象子类的一个属性。

实现方式

class MostCommon(object):
    """Keep track the top-k key-value pairs.

    Attributes:
        top: Integer representing the top-k items to keep track of.
        store: Dictionary of the top-k items.
        min: The current minimum of any top-k item.
        min_set: Set where keys are counts, and values are the set of
            keys with that count.
    """
    def __init__(self, top):
        """Create a new MostCommon object to track key-value paris.

        Args:
            top: Integer representing the top-k values to keep track of.
        """
        self.top = top
        self.store = dict()
        self.min = None
        self.min_set = defaultdict(set)

    def _update_existing(self, key, value):
        """Update an item that is already one of the top-k values."""
        # Currently handle values that are non-decreasing.
        assert value > self.store[key]
        self.min_set[self.store[key]].remove(key)
        if self.store[key] == self.min:  # Previously was the minimum.
            if not self.min_set[self.store[key]]:  # No more minimums.
                del self.min_set[self.store[key]]
                self.min_set[value].add(key)
                self.min = min(self.min_set.keys())
        self.min_set[value].add(key)
        self.store[key] = value

    def __contains__(self, key):
        """Boolean if the key is one of the top-k items."""
        return key in self.store

    def __setitem__(self, key, value):
        """Assign a value to a key.

        The item won't be stored if it is less than the minimum (and
        the store is already full). If the item is already in the store,
        the value will be updated along with the `min` if necessary.
        """
        # Store it if we aren't full yet.
        if len(self.store) < self.top:
            if key in self.store:  # We already have this item.
                self._update_existing(key, value)
            else:  # Brand new item.
                self.store[key] = value
                self.min_set[value].add(key)
                if value < self.min or self.min is None:
                    self.min = value
        else:  # We're full. The value must be greater minimum to be added.
            if value > self.min:  # New item must be larger than current min.
                if key in self.store:  # We already have this item.
                    self._update_existing(key, value)
                else:  # Brand new item.
                    # Make room by removing one of the current minimums.
                    old = self.min_set[self.min].pop()
                    del self.store[old]
                    # Delete the set if there are no old minimums left.
                    if not self.min_set[self.min]:
                        del self.min_set[self.min]
                    # Add the new item.
                    self.min_set[value].add(key)
                    self.store[key] = value
                    self.min = min(self.min_set.keys())

    def __repr__(self):
        if len(self.store) < 10:
            store = repr(self.store)
        else:
            length = len(self.store)
            largest = max(self.store.itervalues())
            store = '<len={length}, max={largest}>'.format(length=length,
                                                           largest=largest)
        return ('{self.__class__.__name__}(top={self.top}, min={self.min}, '
                'store={store})'.format(self=self, store=store))

示例用法

>>> common = MostCommon(2)
>>> common
MostCommon(top=2, min=None, store={})
>>> common['a'] = 1
>>> common
MostCommon(top=2, min=1, store={'a': 1})
>>> 'a' in common
True
>>> common['b'] = 2
>>> common
MostCommon(top=2, min=1, store={'a': 1, 'b': 2})
>>> common['c'] = 3
>>> common
MostCommon(top=2, min=2, store={'c': 3, 'b': 2})
>>> 'a' in common
False
>>> common['b'] = 4
>>> common
MostCommon(top=2, min=3, store={'c': 3, 'b': 4})

更新值后访问确实是O(1)

>>> counter = Counter()
>>> for x in permutations(xrange(10), 10):
        counter[x] += 1

>>> common = MostCommon(1)
>>> for key, value in counter.iteritems():
    common[key] = value

>>> common
MostCommon(top=1, min=1, store={(9, 7, 8, 0, 2, 6, 5, 4, 3, 1): 1})
>>> timeit('repr(common)', 'from __main__ import common', number=1)
1.3251570635475218e-05

访问的时间复杂度是O(1),但是当最小值在设置某个值时发生变化,这个操作的时间复杂度是O(n),其中n是前k个值的数量。不过这仍然比Counter要好,因为Counter在每次访问时的时间复杂度是O(n),而这里的n是整个词汇表的大小!

撰写回答