寻找“最佳”完全子图

3 投票
5 回答
524 浏览
提问于 2025-04-16 00:31

在优化我一个应用的性能时,我发现了几行(Python)代码中存在一个巨大的性能瓶颈。

我有N个代币,每个代币都有一个值。有些代币之间是矛盾的(比如代币8和代币12不能“共存”)。我的任务是找到k个最佳的代币组合。一个代币组合的值就是组合中所有代币值的总和。

我用的简单算法(我已经实现了...):

  1. 找出所有2^N个代币组合的排列
  2. 剔除那些有矛盾的代币组合
  3. 计算所有剩下的代币组合的值
  4. 根据值对代币组合进行排序
  5. 选择前K个代币组合

在实际情况中,我需要从20个代币中找出前10个代币组合(我计算了1,000,000种排列!),最终缩小到3500个没有矛盾的代币组合。这在我的笔记本上花了5秒钟...

我相信我可以通过只生成没有矛盾的代币组合来优化第1步和第2步。

我也很确定我可以用某种方法在一次搜索中找到最佳的代币组合,并找到一种按值递减遍历代币组合的方法,从而找到我想要的10个最佳组合....

我实际的代码:

all_possibilities = sum((list(itertools.combinations(token_list, i)) for i in xrange(len(token_list)+1)), [])
all_possibilities = [list(option) for option in all_possibilities if self._no_contradiction(option)] 
all_possibilities = [(option, self._probability(option)) for option in all_possibilities]
all_possibilities.sort(key = lambda result: -result[1]) # sort by descending probability

请帮帮我?

Tal.

5 个回答

3

一个 O(n (log n))O(n + m) 的解决方案,适用于 n 个标记和字符串长度 m

你的问题和 NP 完全的团问题的不同之处在于,你的“冲突”图有一定的 结构 - 也就是说,它可以被投影到一维(可以排序)。

这意味着你可以使用分治法;毕竟,不重叠的范围互不影响,所以不需要探索所有可能的状态。特别地,动态规划的方法是可行的。

算法的基本思路

  1. 假设一个标记的位置用 [start, end) 表示(即包含开始,不包含结束)。先按标记的结束位置对标记列表进行排序,我们会逐个遍历它们。
  2. 你将扩展这些标记的子集。这些标记的集合会有一个结束位置(如果某个标记的开始位置在子集的结束位置之前,就不能加入这个子集),还有一个累积值。子集的结束位置是所有标记结束位置中的最大值。
  3. 你需要维护一个映射(比如通过哈希表或数组),将已处理的标记索引映射到当前最佳的非冲突标记子集。这意味着,存储在映射中索引 J 的最佳子集只能包含索引小于或等于 J 的标记。
  4. 在每一步中,你将计算某个位置 J 的最佳子集,然后可能会发生三种情况:你可能已经在映射中缓存了这个计算(简单),或者最佳子集包含标记 J,或者最佳子集不包含标记 J。如果没有缓存,你只能通过尝试这两种选项来判断最佳子集是否包含或不包含 J。

现在,关键在于缓存 - 你需要尝试这两种选项,这看起来像是递归(指数级)搜索,但其实不一定是。

  • 如果索引 J 的最佳子集 包含 token[J],那么它就不能包含与这个标记重叠的任何标记 - 特别是,由于我们是按 token.end 排序的,列表中会有一个最后的标记 K,满足 K < Jtoken[K].end <= token[J].start:对于这个标记 K,我们也可以计算最佳子集(或者我们可能已经缓存了它)。
  • 另一方面,它可能 不包含 token[J],那么最佳子集就是 token[J-1]
  • 无论哪种情况,一个特殊的标记 token[-1],其 token[-1].end = 0 和子集值 0 可以作为基本情况。

因为你只需要对每个标记索引进行一次计算,所以这一部分实际上是线性的。然而,简单地对标记进行排序(我推荐这样做)是 O(n log(n)),而根据字符串位置找到最后一个标记索引是 O(log(n)) - 重复 n 次;所以总体运行时间是 O(n log(n))。你可以通过观察到不需要对任意列表进行排序来将其减少到 O(n) - 最大的字符串位置是有限且较小的,因此可以通过在字符串中索引来进行排序,但这几乎肯定不值得。同样,虽然通过二分查找找到一个标记是 log n,你可以通过对齐两个列表来实现 - 一个按标记结束排序,另一个按标记开始排序 - 从而允许 O(n + m) 的实现。除非 n 真的很大,否则这样做不值得。

如果你从字符串的前面遍历到后面,由于所有查找都是“向后”的,你可以完全去掉递归,直接查找给定索引的结果,因为它一定已经被计算过了。

这个比较模糊的解释有帮助吗?这其实是动态规划的基本应用,动态规划就是缓存的一个花哨说法;所以如果你感到困惑,可以去了解一下这个概念。

扩展到前 K 个最佳解决方案

如果你想找到前 K 个最佳解决方案,你需要一个复杂但可行的扩展,将标记的索引映射到最佳 K 个子集,而不是单个最佳子集 - 显然这会增加计算成本和一些额外的代码。基本上,不是选择 包含 还是不包含 token[J],而是取集合的并集,并在每个标记索引处修剪到 K 个最佳选项。如果直接实现,这将是 O(n log(n) + n k log(k))

3

一个简单的方法在第一步和第二步可以这样进行:首先,定义一个包含所有标记的列表和一个包含矛盾关系的字典(字典的每个键是一个标记,每个值是一个标记集合,表示与这个标记矛盾的标记)。然后,对于每个标记,执行两个操作:

  • 如果这个标记没有矛盾,就把它加入到 result 中,并且把与当前添加的标记矛盾的标记加入到 conflicting 集合中。
  • 如果这个标记有矛盾,就选择不把它加入到 result 中,直接跳到下一个标记。

下面是一个示例代码:

token_list = ['a', 'b', 'c']

contradictions = {
    'a': set(['b']),
    'b': set(['a']),
    'c': set()
}

class Generator(object):
    def __init__(self, token_list, contradictions):
        self.list = token_list
        self.contradictions = contradictions
        self.max_start = len(self.list) - 1

    def add_no(self, start, result, conflicting):
        if start < self.max_start:
            for g in self.gen(start + 1, result, conflicting):
                yield g
        else:
            yield result[:]

    def add_yes(self, token, start, result, conflicting):
        result.append(token)
        new_conflicting = conflicting | self.contradictions[token]
        for g in self.add_no(start, result, new_conflicting):
            yield g
        result.pop()

    def gen(self, start, result, conflicting):
        token = self.list[start]
        if token not in conflicting:
            for g in self.add_yes(token, start, result, conflicting):
                yield g
        for g in self.add_no(start, result, conflicting):
            yield g

    def go(self):
        return self.gen(0, [], set())

示例用法:

g = Generator(token_list, contradictions)
for x in g.go():
    print x

这是一个递归算法,所以它处理的标记数量不能超过几千个(因为Python的栈限制),但你可以很容易地创建一个非递归的版本。

2

获取所有不矛盾的令牌组的一个非常简单的方法:

#!/usr/bin/env python

token_list = ['a', 'b', 'c']

contradictions = {
    'a': set(['b']),
    'b': set(['a']),
    'c': set()
}

result = []

while token_list:
    token = token_list.pop()
    new = [set([token])]
    for r in result:
        if token not in contradictions or not r & contradictions[token]:
            new.append(r | set([token]))
    result.extend(new)

print result

撰写回答