Python itertools.permutations 的算法

34 投票
3 回答
9854 浏览
提问于 2025-04-15 21:09

有人能解释一下Python标准库2.6中itertools.permutations这个函数的算法吗?我不太明白它是怎么工作的。

代码如下:

def permutations(iterable, r=None):
    # permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC
    # permutations(range(3)) --> 012 021 102 120 201 210
    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    if r > n:
        return
    indices = range(n)
    cycles = range(n, n-r, -1)
    yield tuple(pool[i] for i in indices[:r])
    while n:
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
                indices[i:] = indices[i+1:] + indices[i:i+1]
                cycles[i] = n - i
            else:
                j = cycles[i]
                indices[i], indices[-j] = indices[-j], indices[i]
                yield tuple(pool[i] for i in indices[:r])
                break
        else:
            return

3 个回答

1

用结果的模式来回答问题比用文字更简单(当然,如果你想了解理论中的数学部分,那另当别论),所以打印输出是最好的解释方式。
最微妙的地方在于,循环到最后后,它会重置到最后一轮的第一轮,然后开始下一次循环,或者不断重置到最后一轮的第一轮,甚至是更大的轮次,就像时钟一样。

负责重置的代码部分:

         if cycles[i] == 0:
             indices[i:] = indices[i+1:] + indices[i:i+1]
             cycles[i] = n - i

整体代码:

In [54]: def permutations(iterable, r=None):
    ...:     # permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC
    ...:     # permutations(range(3)) --> 012 021 102 120 201 210
    ...:     pool = tuple(iterable)
    ...:     n = len(pool)
    ...:     r = n if r is None else r
    ...:     if r > n:
    ...:         return
    ...:     indices = range(n)
    ...:     cycles = range(n, n-r, -1)
    ...:     yield tuple(pool[i] for i in indices[:r])
    ...:     print(indices, cycles)
    ...:     while n:
    ...:         for i in reversed(range(r)):
    ...:             cycles[i] -= 1
    ...:             if cycles[i] == 0:
    ...:                 indices[i:] = indices[i+1:] + indices[i:i+1]
    ...:                 cycles[i] = n - i
    ...:                 print("reset------------------")
    ...:                 print(indices, cycles)
    ...:                 print("------------------")
    ...:             else:
    ...:                 j = cycles[i]
    ...:                 indices[i], indices[-j] = indices[-j], indices[i]
    ...:                 print(indices, cycles, i, n-j)
    ...:                 yield tuple(pool[i] for i in indices[:r])
    ...:                 break
    ...:         else:
    ...:             return

结果的一部分:

In [54]: list(','.join(i) for i in permutations('ABCDE', 3))
([0, 1, 2, 3, 4], [5, 4, 3])
([0, 1, 3, 2, 4], [5, 4, 2], 2, 3)
([0, 1, 4, 2, 3], [5, 4, 1], 2, 4)
reset------------------
([0, 1, 2, 3, 4], [5, 4, 3])
------------------
([0, 2, 1, 3, 4], [5, 3, 3], 1, 2)
([0, 2, 3, 1, 4], [5, 3, 2], 2, 3)
([0, 2, 4, 1, 3], [5, 3, 1], 2, 4)
reset------------------
([0, 2, 1, 3, 4], [5, 3, 3])
------------------
([0, 3, 1, 2, 4], [5, 2, 3], 1, 3)
([0, 3, 2, 1, 4], [5, 2, 2], 2, 3)
([0, 3, 4, 1, 2], [5, 2, 1], 2, 4)
reset------------------
([0, 3, 1, 2, 4], [5, 2, 3])
------------------
([0, 4, 1, 2, 3], [5, 1, 3], 1, 4)
([0, 4, 2, 1, 3], [5, 1, 2], 2, 3)
([0, 4, 3, 1, 2], [5, 1, 1], 2, 4)
reset------------------
([0, 4, 1, 2, 3], [5, 1, 3])
------------------
reset------------------(bigger reset)
([0, 1, 2, 3, 4], [5, 4, 3])
------------------
([1, 0, 2, 3, 4], [4, 4, 3], 0, 1)
([1, 0, 3, 2, 4], [4, 4, 2], 2, 3)
([1, 0, 4, 2, 3], [4, 4, 1], 2, 4)
reset------------------
([1, 0, 2, 3, 4], [4, 4, 3])
------------------
([1, 2, 0, 3, 4], [4, 3, 3], 1, 2)
([1, 2, 3, 0, 4], [4, 3, 2], 2, 3)
([1, 2, 4, 0, 3], [4, 3, 1], 2, 4)
2

最近我在重新实现排列算法的过程中,碰到了同样的问题,想和大家分享一下我对这个有趣算法的理解。

简而言之: 这个算法基于一个递归生成排列的算法(使用回溯和元素交换),并将其转化(或优化)为迭代形式。(这样做可能是为了提高效率,防止栈溢出)

基础知识

在开始之前,我需要确保我们使用的符号和原始算法一致。

  • n 表示可迭代对象的长度
  • r 表示一个输出排列元组的长度

还有一个简单的观察(正如 Alex 所讨论的):

  • 每当算法 yield 一个输出时,它只取 indices 列表的前 r 个元素。

cycles

首先,让我们讨论一下变量 cycles,并建立一些直观的理解。通过一些调试打印,我们可以看到 cycles 像是一个倒计时(类似于时间或时钟,比如 01:00:00 -> 00:59:59 -> 00:59:58):

  • 每个项目初始化为 range(n, n-r, -1),结果是 cycles[0]=n, cycles[1]=n-1...cycles[i]=n-i
  • 通常,只有最后一个元素会减少,每次减少(在 cycles[r-1] !=0 之后)都会产生一个输出(一个排列元组)。我们可以直观地称这种情况为 tick
  • 每当某个元素(假设是 cycles[i])减少到 0 时,会触发前一个元素(cycles[i-1])的减少。然后触发的元素(cycles[i])会恢复到它的初始值(n-i)。这种行为类似于借位减法,或者在倒计时中秒数到达 0 时分钟的重置。我们可以直观地称这个过程为 reset

为了进一步确认我们的直觉,可以在算法中添加一些打印语句,并用参数 iterable="ABCD", r=2 运行它。我们可以看到 cycles 变量的变化。注意方括号表示发生了一个“tick”,产生了输出,而花括号表示发生了一个“reset”,没有产生输出。

[4,3] -> [4,2] -> [4,1] -> {4,0} -> {4,3} -> 
[3,3] -> [3,2] -> [3,1] -> {3,0} -> {3,3} -> 
[2,3] -> [2,2] -> [2,1] -> {2,0} -> {2,3} -> 
[1,3] -> [1,2] -> [1,1] -> {1,0} -> {1,3} -> {0,3} -> {4,3}

利用 cycles 的初始值和变化模式,我们可以对 cycles 的 含义 做一个可能的解释:每个索引处剩余排列(输出)的数量。当初始化时,cycles[0]=n 表示在索引 0 处最初有 n 种可能选择,cycles[1]=n-1 表示在索引 1 处最初有 n-1 种可能选择,一直到 cycles[r-1]=n-r+1。这种对 cycles 的解释与数学相符,通过一些简单的组合数学计算,我们可以确认确实如此。另一个支持证据是,每当算法结束时,我们会有 P(n,r)(P(n,r)=n*(n-1)*...*(n-r+1))个 ticks(将进入 while 之前的初始 yield 也算作一个 tick)。

indices

现在我们来讨论更复杂的部分,indices 列表。由于这本质上是一个递归算法(更准确地说是回溯),我想从一个子问题开始(i=r-1):当 indices 中从索引 0 到索引 r-2(包括) 的值固定时,只有索引 r-1(换句话说,就是 indices 中的最后一个元素)的值在变化。同时,我会引入一个具体的例子(iterable="ABCDE", r=3),我们将专注于它如何生成前 3 个输出:ABC、ABD、ABE。

  • 根据这个子问题,我们将 indices 列表分成 3 部分,并给它们命名,
    • 固定部分 : indices[0:r-2](包括)
    • 变化部分: indices[r-1](只有一个值)
    • 待处理部分: indices[r:n-1](除了前两部分的剩余部分)
  • 由于这是一个回溯算法,我们需要保持一个不变的条件在执行前后都不变。这个不变的条件是
    • 子列表包含变化部分和待处理部分(indices[r-1:n-1]),在执行过程中会被修改,但结束时会恢复。
  • 现在我们可以转向 cyclesindices 在神秘的 while 循环中的交互。一些操作已经被 Alex 概述,我会进一步详细说明。
    • 在每个 tick 中,变化部分的元素与待处理部分的某个元素交换,并且待处理部分的相对顺序保持不变。
      • 用字符来可视化 indices,花括号突出显示待处理部分:
      • ABC{DE} -> ABD{CE} -> ABE{CD}
    • 当发生 reset 时,变化部分的元素被移动到待处理部分的后面,从而恢复子列表的初始布局(包含变化部分和待处理部分)
      • 用字符来可视化 indices,花括号突出显示变化部分:
      • AB{E}CD -> ABCD{E}
  • 在这个执行过程中(i=r-1),只有 tick 阶段可以产生输出,并且总共会产生 n-r+1 个输出,这与 cycles[i] 的初始值相匹配。这也是因为在固定部分固定的情况下,我们只能有 n-r+1 种排列选择。
  • cycles[i] 减少到 0 时,reset 阶段开始,重置 cycles[i]n-r+1 并恢复不变的子列表。这个阶段标志着这个执行的结束,并表明所有可能的排列选择都已经输出。
  • 因此,我们已经证明,在这个子问题(i=r-1)中,这个算法确实是一个有效的回溯算法,因为它
    • 在给定前提(固定前缀部分)的情况下输出所有可能的值
    • 保持不变的条件不变(在 reset 阶段恢复)
  • 这个证明(?) 也可以推广到其他 i 的值,从而证明(?) 这个排列生成算法的正确性。

重新实现

呼!这真是一段长长的阅读,你可能还想对算法进行更多的调整(更多的 print)以完全信服。实际上,我们可以将算法的基本原理简化为以下伪代码:

// precondition: the fixed part (or prefix) is fixed
OUTPUT initial_permutation // also invokes the next level
WHILE remaining_permutation_count > 0
    // tick
    swap the changing element with an element in backlog
    OUTPUT current_permutation // also invokes the next level
// reset
move the changing element behind the backlog

这里是一个使用简单回溯的 Python 实现:

# helpers
def swap(list, i, j):
    list[i], list[j] = list[j], list[i]

def move_to_last(list, i):
    list[i:] = list[i+1:] + [list[i]]

def print_first_n_element(list, n):
    print("".join(list[:n]))

# backtracking dfs
def permutations(list, r, changing_index):
    if changing_index == r:
        # we've reached the deepest level
        print_first_n_element(list, r)
        return
    
    # a pseudo `tick`
    # process initial permutation
    # which is just doing nothing (using the initial value)
    permutations(list, r, changing_index + 1)

    # note: initial permutaion has been outputed, thus the minus 1
    remaining_choices = len(list) - 1 - changing_index
    # for (i=1;i<=remaining_choices;i++)
    for i in range(1, remaining_choices+1):
        # `tick` phases
        
        # make one swap
        swap_idx = changing_index + i
        swap(list, changing_index, swap_idx)
        # finished one move at current level, now go deeper
        permutations(list, r, changing_index + 1)
    
    # `reset` phase
    move_to_last(list, changing_index)

# wrapper
def permutations_wrapper(list, r):
    permutations(list, r, 0)

# main
if __name__ == "__main__":
    my_list = ["A", "B", "C", "D"]
    permutations_wrapper(my_list, 2)

现在剩下的步骤就是展示回溯版本与 itertools 源代码中的迭代版本是等价的。一旦你理解了这个算法的工作原理,这应该是相当简单的。按照各种计算机科学教科书的传统,这留给读者作为练习。

39

你需要了解一下数学理论中的排列循环,也叫做“轨道”。了解这两个术语很重要,因为这个数学主题,组合数学的核心,比较复杂,你可能需要查阅一些研究论文,这些论文可能会用到这两个术语中的一个或两个。

如果你想更简单地了解排列的理论,可以去维基百科看看。上面提到的每个链接都有不错的参考资料,如果你对组合数学产生了兴趣,想深入了解的话,可以去看看(我个人就是这样,后来变成了我的一个爱好;-)。

一旦你理解了数学理论,代码的细节依然很有趣,可以“逆向工程”。显然,indices就是当前排列的索引,给出的项目总是由以下内容决定:

yield tuple(pool[i] for i in indices[:r])

所以,这个有趣的机制的核心是cycles,它表示排列的轨道,并导致indices的更新,主要通过以下语句:

j = cycles[i]
indices[i], indices[-j] = indices[-j], indices[i]

也就是说,如果cycles[i]j,这意味着接下来更新索引时,要把第(从左数)和第j(从右数)的交换(例如,如果j是1,那么就是交换indices的最后一个元素——indices[-1])。然后还有一种不太常见的“批量更新”,当cycles中的某个项在递减时变为0:

indices[i:] = indices[i+1:] + indices[i:i+1]
cycles[i] = n - i

这会把indices的第项放到最后,所有后面的项都向左移动一位,并表示下次我们再遇到这个cycles的项时,会把新的第(从左数)和第n - i个(从右数)交换——这又是第,当然,除了在我们下次查看它之前会有一个

cycles[i] -= 1

的操作;-)。

最难的部分当然是证明这个方法是有效的——也就是说,所有的排列都被彻底生成,没有重叠,并且正确“定时”退出。我觉得,与其证明,不如看看这个机制在简单情况下是如何工作的——把yield语句注释掉,换成print语句(Python 2.*),我们可以得到:

def permutations(iterable, r=None):
    # permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC
    # permutations(range(3)) --> 012 021 102 120 201 210
    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    if r > n:
        return
    indices = range(n)
    cycles = range(n, n-r, -1)
    print 'I', 0, cycles, indices
    # yield tuple(pool[i] for i in indices[:r])
    print indices[:r]
    while n:
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
        print 'B', i, cycles, indices
                indices[i:] = indices[i+1:] + indices[i:i+1]
                cycles[i] = n - i
        print 'A', i, cycles, indices
            else:
        print 'b', i, cycles, indices
                j = cycles[i]
                indices[i], indices[-j] = indices[-j], indices[i]
        print 'a', i, cycles, indices
                # yield tuple(pool[i] for i in indices[:r])
            print indices[:r]
                break
        else:
            return

permutations('ABC', 2)

运行这个会显示:

I 0 [3, 2] [0, 1, 2]
[0, 1]
b 1 [3, 1] [0, 1, 2]
a 1 [3, 1] [0, 2, 1]
[0, 2]
B 1 [3, 0] [0, 2, 1]
A 1 [3, 2] [0, 1, 2]
b 0 [2, 2] [0, 1, 2]
a 0 [2, 2] [1, 0, 2]
[1, 0]
b 1 [2, 1] [1, 0, 2]
a 1 [2, 1] [1, 2, 0]
[1, 2]
B 1 [2, 0] [1, 2, 0]
A 1 [2, 2] [1, 0, 2]
b 0 [1, 2] [1, 0, 2]
a 0 [1, 2] [2, 0, 1]
[2, 0]
b 1 [1, 1] [2, 0, 1]
a 1 [1, 1] [2, 1, 0]
[2, 1]
B 1 [1, 0] [2, 1, 0]
A 1 [1, 2] [2, 0, 1]
B 0 [0, 2] [2, 0, 1]
A 0 [3, 2] [0, 1, 2]

关注一下cycles:它们开始时是3, 2——然后最后一个递减,所以变成3, 1——最后一个还没变成零,所以我们有一个“小”事件(在索引中交换一次),然后中断内层循环。接着我们再次进入,这次最后一个递减变成3, 0——最后一个现在是零,所以这是一个“大”事件——在索引中进行“批量交换”(虽然这里没有太多的批量,但可能会有;-),然后cycles又回到了3, 2。但现在我们没有中断外层循环,所以继续递减倒数第二个(在这个例子中是第一个)——这又是一个小事件,在索引中交换一次,然后再次中断内层循环。回到循环中,最后一个再次递减,这次变成2, 1——小事件,等等。最终会出现一个完整的外层循环,只有大事件,没有小事件——那时cycles的所有项都是1,所以递减会让每个都变成零(大事件),在最后一个循环中不会发生yield

由于在那个循环中从未执行过break,我们就进入forelse分支,返回。注意while n可能有点误导:它实际上是作为一个while True来运行——n从未改变,while循环只会因为那个return语句而退出;它也可以写成if not n: return,然后是while True:,因为当n0(空“池”)时,第一次,简单的空yield之后就没有更多的内容可以返回。作者只是决定通过把if not n:的检查和while合并来节省几行代码;-)。

我建议你继续检查几个具体的案例——最终你应该能感受到这个“机械装置”的运作。开始时只关注cycles(也许可以相应地编辑print语句,去掉indices),因为它们在轨道上的机械般进展是这个微妙而深刻的算法的关键;一旦你理解了这一点indices如何根据cycles的顺序正确更新几乎就成了小事一桩!-)

撰写回答