扁平化嵌套循环 / 降低复杂度 - 互补对计数算法

6 投票
5 回答
2716 浏览
提问于 2025-04-17 10:10

我最近在用Python解决一个任务,发现了一个看起来复杂度是O(n log n)的解决方案,但我觉得对于某些输入来说,这个方法效率很低(比如第一个参数是0,而pairs是一个很长的零列表)。

这个方法还有三层for循环。我觉得可以优化一下,但目前我还没找到更好的办法,可能只是忽略了一些明显的东西;)

基本上,问题是这样的:

给定一个整数列表(values),这个函数需要返回满足以下条件的索引对的数量:

  • 假设一个索引对是一个像(index1, index2)的元组,
  • 那么values[index1] == complementary_diff - values[index2]这个条件成立。

举个例子: 如果给定一个列表[1, 3, -4, 0, -3, 5]作为values,并且1作为complementary_diff,那么这个函数应该返回4(这就是以下索引对列表的长度:[(0, 3), (2, 5), (3, 0), (5, 2)])。

这是我目前的代码,它在大多数情况下应该能正常工作,但正如我所说的,在某些情况下它可能会运行得很慢,尽管它的复杂度大致是O(n log n)(看起来最坏的情况复杂度是O(n^2))。

def complementary_pairs_number (complementary_diff, values):
    value_key = {} # dictionary storing indexes indexed by values
    for index, item in enumerate(values):
        try:
            value_key[item].append(index)
        except (KeyError,): # the item has not been found in value_key's keys
            value_key[item] = [index]
    key_pairs = set() # key pairs are unique by nature
    for pos_value in value_key: # iterate through keys of value_key dictionary
        sym_value = complementary_diff - pos_value
        if sym_value in value_key: # checks if the symmetric value has been found
            for i1 in value_key[pos_value]: # iterate through pos_values' indexes
                for i2 in value_key[sym_value]: # as above, through sym_values
                    # add indexes' pairs or ignore if already added to the set
                    key_pairs.add((i1, i2))
                    key_pairs.add((i2, i1))
    return len(key_pairs)

对于给定的例子,它的表现是这样的:

>>> complementary_pairs_number(1, [1, 3, -4, 0, -3, 5])
4

如果你看到代码可以怎么“简化”或“扁平化”,请告诉我。

我不确定仅仅检查complementary_diff == 0等是否是最佳方法——如果你觉得是,请告诉我。

编辑:我已经修正了例子(谢谢,unutbu!)。

5 个回答

0

修改了@unutbu提供的解决方案:

这个问题可以简化为比较这两个字典:

  1. values

  2. 一个预先计算好的字典,用于存储(补充差值 - values[i])

    def complementary_pairs_number(complementary_diff, values):
        value_key = {} # dictionary storing indexes indexed by values
        for index, item in enumerate(values):
            value_key.setdefault(item,[]).append(index)
    
        answer_key = {} # dictionary storing indexes indexed by (complementary_diff - values)
        for index, item in enumerate(values):
            answer_key.setdefault((complementary_diff-item),[]).append(index)
    
        num_pairs = 0
        print(value_key)
        print(answer_key)
        for pos_value in value_key: 
            if pos_value in answer_key: 
                num_pairs+=len(value_key[pos_value])*len(answer_key[pos_value])
        return num_pairs
    
2

你可以看看一些函数式编程的写法,比如说 reduce 之类的。

很多时候,处理嵌套数组的逻辑可以通过使用像 reduce、map、reject 这样的函数来简化。

如果想看个例子(用 JavaScript 的话),可以看看 underscore js。我对 Python 不是特别了解,所以不太清楚他们有哪些库可以用。

4

我觉得这样可以把复杂度提高到 O(n)

  • value_key.setdefault(item,[]).append(index) 比用 try..except 语句快。而且它也比用 collections.defaultdict(list) 快。(我用 ipython 的 %timeit 测试过。)
  • 原来的代码每个解都访问了两次。对于 value_key 中的每个 pos_value,都有一个唯一的 sym_value 和它关联。当 sym_value 也在 value_key 中时,会有解。但是当我们遍历 value_key 的键时,pos_value 最终会被赋值为 sym_value 的值,这样就导致代码重复计算已经做过的事情。所以如果能阻止 pos_value 等于之前的 sym_value,就可以把工作量减半。我用 seen = set() 来跟踪已经见过的 sym_value
  • 这段代码只关心 len(key_pairs),而不是 key_pairs 本身。因此,我们可以简单地跟踪计数(用 num_pairs),而不是跟踪这些对(用 set)。所以我们可以把两个内部的 for 循环替换为

    num_pairs += 2*len(value_key[pos_value])*len(value_key[sym_value])
    

    在“唯一对角线”的情况下,即 pos_value == sym_value,可以减少一半的计算。


def complementary_pairs_number(complementary_diff, values):
    value_key = {} # dictionary storing indexes indexed by values
    for index, item in enumerate(values):
        value_key.setdefault(item,[]).append(index)
    # print(value_key)
    num_pairs = 0
    seen = set()
    for pos_value in value_key: 
        if pos_value in seen: continue
        sym_value = complementary_diff - pos_value
        seen.add(sym_value)
        if sym_value in value_key: 
            # print(pos_value, sym_value, value_key[pos_value],value_key[sym_value])
            n = len(value_key[pos_value])*len(value_key[sym_value])
            if pos_value == sym_value:
                num_pairs += n
            else:
                num_pairs += 2*n
    return num_pairs

撰写回答