获取字典中与其他键重叠的所有键

8 投票
2 回答
1864 浏览
提问于 2025-04-17 18:23

我有一个列表推导式,长这样:

cart = [ ((p,pp),(q,qq)) for ((p,pp),(q,qq))\
         in itertools.product(C.items(), repeat=2)\
         if p[1:] == q[:-1] ]

这里的C是一个字典,字典里的键是一些任意整数的元组。所有的元组长度都是一样的。在最坏的情况下,新的列表里应该包含所有的组合。这种情况发生得还挺频繁的。

举个例子,我有一个这样的字典:

C = { (0,1):'b',
      (2,0):'c',
      (0,0):'d' }

我想要的结果是:

cart = [ (((2, 0), 'c'), ((0, 1), 'b'))
         (((2, 0), 'c'), ((0, 0), 'd'))
         (((0, 0), 'd'), ((0, 1), 'b'))
         (((0, 0), 'd'), ((0, 0), 'd')) ]

这里提到的“重叠”是指,比如说元组(1,2,3,4)(2,3,4,5)之间有重叠部分(2,3,4)。重叠的部分必须在元组的“边缘”上。我只想要那些长度比元组长度少一个的重叠部分。因此(1,2,3,4)(3,4,5,6)之间没有重叠。另外,注意当我们去掉元组的第一个或最后一个元素时,可能会得到不唯一的元组,这些元组都需要和其他元素进行比较。这一点在我第一个例子中没有强调。

我的代码大部分执行时间都花在这个列表推导式上。我总是需要cart里的所有元素,所以用生成器似乎没有加快速度。

我的问题是:有没有更快的方法来做到这一点?

我想到的一个办法是,我可以尝试创建两个新的字典,像这样:

aa = defaultdict(list)
bb = defaultdict(list)
[aa[p[1:]].append(p) for p in C.keys()]
[bb[p[:-1]].append(p) for p in C.keys()]

然后以某种方式将aa[i]列表中的所有元素组合与bb[i]列表中的元素合并,但我似乎也无法完全理解这个想法。

更新

tobias_k和shx2提供的解决方案在复杂度上比我原来的代码要好(就我所知)。我的代码是O(n^2),而另外两个解决方案是O(n)。不过对于我的问题规模和组成,三种解决方案的运行时间似乎差不多。我想这和函数调用的开销以及我处理的数据的性质有关。特别是不同键的数量,以及键的实际组成,似乎对运行时间有很大影响。我知道这一点是因为当键完全随机时,代码运行得要慢得多。我接受了tobias_k的答案,因为他的代码最容易理解。不过,我仍然非常欢迎其他关于如何完成这个任务的建议。

2 个回答

1

你把数据处理成字典的想法真不错。接下来就是这个:

from itertools import groupby
C = { (0,1): 'b', (2,0): 'c', (0,0): 'd' }

def my_groupby(seq, key):
    """
     >>> group_by(range(10), lambda x: 'mod=%d' % (x % 3))
     {'mod=2': [2, 5, 8], 'mod=0': [0, 3, 6, 9], 'mod=1': [1, 4, 7]}
    """
    groups = dict()
    for x in seq:
        y = key(x)
        groups.setdefault(y, []).append(x)
    return groups

def get_overlapping_items(C):
    prefixes = my_groupby(C.iteritems(), key = lambda (k,v): k[:-1])
    for k1, v1 in C.iteritems():
        prefix = k1[1:]
        for k2, v2 in prefixes.get(prefix, []):
            yield (k1, v1), (k2, v2)

for x in get_overlapping_items(C):
    print x     

(((2, 0), 'c'), ((0, 1), 'b'))
(((2, 0), 'c'), ((0, 0), 'd'))
(((0, 0), 'd'), ((0, 1), 'b'))
(((0, 0), 'd'), ((0, 0), 'd'))

顺便说一下,不要用:

itertools.product(*[C.items()]*2)

而是用:

itertools.product(C.items(), repeat=2)
2

你其实走在正确的路上,使用字典来存储所有键的前缀。不过,要记住(根据我对问题的理解),两个键也可以重叠,只要它们的重叠部分少于 len-1。比如,键 (1,2,3,4)(3,4,5,6) 也是会重叠的。因此,我们需要创建一个映射,保存所有键的前缀。(如果我理解错了这个问题,可以直接去掉里面的两个 for 循环。)一旦我们有了这个映射,我们就可以再遍历一次所有的键,检查它们的后缀中是否有与 prefixes 映射中匹配的键。(更新:由于键可以在多个前缀/后缀上重叠,我们将重叠的对存储在一个集合中。)

def get_overlap(keys):
    # create map: prefix -> set(keys with that prefix)
    prefixes = defaultdict(set)
    for key in keys:
        for prefix in [key[:i] for i in range(len(key))]:
            prefixes[prefix].add(key)
    # get keys with matching prefixes for all suffixes
    overlap = set()
    for key in keys:
        for suffix in [key[i:] for i in range(len(key))]:
            overlap.update([(key, other) for other in prefixes[suffix]
                                                      if other != key])
    return overlap

(注意,为了简单起见,我只关心字典中的键,而不关心值。扩展这个功能以返回值,或者将其作为后处理步骤,应该很简单。)

总体运行时间应该只有 2*n*k,其中 n 是键的数量,k 是键的长度。如果有很多键具有相同的前缀,空间复杂度(即 prefixes 映射的大小)应该在 n*kn^2*k 之间。

注意:上面的回答是针对更一般的情况,即重叠区域可以有 任何 长度。对于更简单的情况,如果你只考虑比原始元组短一个的重叠,下面的内容应该就足够了,并能得到你示例中描述的结果:

def get_overlap_simple(keys):
    prefixes = defaultdict(list)
    for key in keys:
        prefixes[key[:-1]].append(key)
    return [(key, other) for key in keys for other in prefixes[key[1:]]]

撰写回答