扁平化嵌套循环 / 降低复杂度 - 互补对计数算法
我最近在用Python解决一个任务,发现了一个看起来复杂度是O(n log n)的解决方案,但我觉得对于某些输入来说,这个方法效率很低(比如第一个参数是0
,而pairs
是一个很长的零列表)。
这个方法还有三层for
循环。我觉得可以优化一下,但目前我还没找到更好的办法,可能只是忽略了一些明显的东西;)
基本上,问题是这样的:
给定一个整数列表(
values
),这个函数需要返回满足以下条件的索引对的数量:
- 假设一个索引对是一个像
(index1, index2)
的元组,- 那么
values[index1] == complementary_diff - values[index2]
这个条件成立。举个例子: 如果给定一个列表
[1, 3, -4, 0, -3, 5]
作为values
,并且1
作为complementary_diff
,那么这个函数应该返回4
(这就是以下索引对列表的长度:[(0, 3), (2, 5), (3, 0), (5, 2)]
)。
这是我目前的代码,它在大多数情况下应该能正常工作,但正如我所说的,在某些情况下它可能会运行得很慢,尽管它的复杂度大致是O(n log n)(看起来最坏的情况复杂度是O(n^2))。
def complementary_pairs_number (complementary_diff, values):
value_key = {} # dictionary storing indexes indexed by values
for index, item in enumerate(values):
try:
value_key[item].append(index)
except (KeyError,): # the item has not been found in value_key's keys
value_key[item] = [index]
key_pairs = set() # key pairs are unique by nature
for pos_value in value_key: # iterate through keys of value_key dictionary
sym_value = complementary_diff - pos_value
if sym_value in value_key: # checks if the symmetric value has been found
for i1 in value_key[pos_value]: # iterate through pos_values' indexes
for i2 in value_key[sym_value]: # as above, through sym_values
# add indexes' pairs or ignore if already added to the set
key_pairs.add((i1, i2))
key_pairs.add((i2, i1))
return len(key_pairs)
对于给定的例子,它的表现是这样的:
>>> complementary_pairs_number(1, [1, 3, -4, 0, -3, 5])
4
如果你看到代码可以怎么“简化”或“扁平化”,请告诉我。
我不确定仅仅检查complementary_diff == 0
等是否是最佳方法——如果你觉得是,请告诉我。
编辑:我已经修正了例子(谢谢,unutbu!)。
5 个回答
修改了@unutbu提供的解决方案:
这个问题可以简化为比较这两个字典:
values
一个预先计算好的字典,用于存储(补充差值 - values[i])
def complementary_pairs_number(complementary_diff, values): value_key = {} # dictionary storing indexes indexed by values for index, item in enumerate(values): value_key.setdefault(item,[]).append(index) answer_key = {} # dictionary storing indexes indexed by (complementary_diff - values) for index, item in enumerate(values): answer_key.setdefault((complementary_diff-item),[]).append(index) num_pairs = 0 print(value_key) print(answer_key) for pos_value in value_key: if pos_value in answer_key: num_pairs+=len(value_key[pos_value])*len(answer_key[pos_value]) return num_pairs
你可以看看一些函数式编程的写法,比如说 reduce 之类的。
很多时候,处理嵌套数组的逻辑可以通过使用像 reduce、map、reject 这样的函数来简化。
如果想看个例子(用 JavaScript 的话),可以看看 underscore js。我对 Python 不是特别了解,所以不太清楚他们有哪些库可以用。
我觉得这样可以把复杂度提高到 O(n)
:
value_key.setdefault(item,[]).append(index)
比用try..except
语句快。而且它也比用collections.defaultdict(list)
快。(我用 ipython 的 %timeit 测试过。)- 原来的代码每个解都访问了两次。对于
value_key
中的每个pos_value
,都有一个唯一的sym_value
和它关联。当sym_value
也在value_key
中时,会有解。但是当我们遍历value_key
的键时,pos_value
最终会被赋值为sym_value
的值,这样就导致代码重复计算已经做过的事情。所以如果能阻止pos_value
等于之前的sym_value
,就可以把工作量减半。我用seen = set()
来跟踪已经见过的sym_value
。 这段代码只关心
len(key_pairs)
,而不是key_pairs
本身。因此,我们可以简单地跟踪计数(用num_pairs
),而不是跟踪这些对(用set
)。所以我们可以把两个内部的 for 循环替换为num_pairs += 2*len(value_key[pos_value])*len(value_key[sym_value])
在“唯一对角线”的情况下,即
pos_value == sym_value
,可以减少一半的计算。
def complementary_pairs_number(complementary_diff, values):
value_key = {} # dictionary storing indexes indexed by values
for index, item in enumerate(values):
value_key.setdefault(item,[]).append(index)
# print(value_key)
num_pairs = 0
seen = set()
for pos_value in value_key:
if pos_value in seen: continue
sym_value = complementary_diff - pos_value
seen.add(sym_value)
if sym_value in value_key:
# print(pos_value, sym_value, value_key[pos_value],value_key[sym_value])
n = len(value_key[pos_value])*len(value_key[sym_value])
if pos_value == sym_value:
num_pairs += n
else:
num_pairs += 2*n
return num_pairs