生成无相邻相等元素的列表所有排列

88 投票
12 回答
12921 浏览
提问于 2025-04-18 17:08

当我们对一个列表进行排序时,比如说

a = [1,2,3,3,2,2,1]
sorted(a) => [1, 1, 2, 2, 2, 3, 3]

相同的元素在结果列表中总是相邻的。

那么,我该如何做相反的事情——打乱这个列表,使得相同的元素尽量不相邻(或者尽可能少相邻)呢?

例如,对于上面的列表,可能的解决方案之一是

p = [1,3,2,3,2,1,2]

更正式地说,给定一个列表 a,生成一个排列 p,使得相邻的相同元素对 p[i]==p[i+1] 的数量最小。

由于列表很大,生成和过滤所有的排列是不现实的。

附加问题:如何高效地生成所有这样的排列?

这是我用来测试解决方案的代码:https://gist.github.com/gebrkn/9f550094b3d24a35aebd

更新:在这里选择一个最佳答案很困难,因为很多人都提供了很好的答案。@VincentvanderWeele@David Eisenstat@Coady@enrico.bacis@srgerg 提供了生成最佳排列的函数,效果非常好。@tobias_k 和 David 也回答了附加问题(生成所有排列)。David 在正确性证明方面也得到了额外的分数。

@VincentvanderWeele 的代码似乎是最快的。

12 个回答

5

在Python中,你可以这样做。

假设你有一个已经排好序的列表 l,你可以执行以下操作:

length = len(l)
odd_ind = length%2
odd_half = (length - odd_ind)/2
for i in range(odd_half)[::2]:
    my_list[i], my_list[odd_half+odd_ind+i] = my_list[odd_half+odd_ind+i], my_list[i]

这些操作都是在原地进行的,所以应该会比较快(时间复杂度是O(N))。注意,你会从 l[i] == l[i+1] 转变为 l[i] == l[i+2],所以最后得到的顺序并不是随机的,但根据我对问题的理解,你并不需要随机性。

这个方法的思路是把排好序的列表从中间分开,然后交换两个部分中的每个元素。

比如说,对于 l= [1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5],这样处理后会变成 l = [3, 1, 4, 2, 5, 1, 3, 1, 4, 2, 5]

这个方法无法完全消除所有的 l[i] == l[i + 1],当某个元素的数量大于或等于列表长度的一半时。

虽然上述方法在最常见的元素数量小于列表长度一半时效果很好,下面这个函数也能处理极限情况(著名的越界问题),在这种情况下,从第一个元素开始的每个其他元素必须是数量最多的那个:

def no_adjacent(my_list):
    my_list.sort()
    length = len(my_list)
    odd_ind = length%2
    odd_half = (length - odd_ind)/2
    for i in range(odd_half)[::2]:
        my_list[i], my_list[odd_half+odd_ind+i] = my_list[odd_half+odd_ind+i], my_list[i]

    #this is just for the limit case where the abundance of the most frequent is half of the list length
    if max([my_list.count(val) for val in set(my_list)]) + 1 - odd_ind > odd_half:
        max_val = my_list[0]
        max_count = my_list.count(max_val)
        for val in set(my_list):
            if my_list.count(val) > max_count:
               max_val = val
               max_count = my_list.count(max_val)
        while max_val in my_list:
            my_list.remove(max_val)
        out = [max_val]
        max_count -= 1
        for val in my_list:
            out.append(val)
            if max_count:
                out.append(max_val)
                max_count -= 1
        if max_count:
            print 'this is not working'
            return my_list
            #raise Exception('not possible')
        return out
    else:
        return my_list
8

你可以通过一种叫做递归回溯的算法来生成所有的“完全无序”排列(也就是没有两个相同的元素挨在一起)。其实,生成所有排列的唯一不同之处在于,你需要记录下最后一个数字,并相应地排除一些解。

def unsort(lst, last=None):
    if lst:
        for i, e in enumerate(lst):
            if e != last:
                for perm in unsort(lst[:i] + lst[i+1:], e):
                    yield [e] + perm
    else:
        yield []

需要注意的是,这种方式效率不是很高,因为它会创建很多子列表。此外,我们可以通过优先考虑那些限制条件最严格的数字(也就是出现次数最多的数字)来加快速度。下面是一个更高效的版本,只使用数字的counts

def unsort_generator(lst, sort=False):
    counts = collections.Counter(lst)
    def unsort_inner(remaining, last=None):
        if remaining > 0:
            # most-constrained first, or sorted for pretty-printing?
            items = sorted(counts.items()) if sort else counts.most_common()
            for n, c in items:
                if n != last and c > 0:
                    counts[n] -= 1   # update counts
                    for perm in unsort_inner(remaining - 1, n):
                        yield [n] + perm
                    counts[n] += 1   # revert counts
        else:
            yield []
    return unsort_inner(len(lst))

你可以用这个方法生成下一个完美排列,或者生成一个包含所有排列的list。不过要注意,如果没有“完全无序”的排列,那么这个生成器就会返回“没有”结果。

>>> lst = [1,2,3,3,2,2,1]
>>> next(unsort_generator(lst))
[2, 1, 2, 3, 1, 2, 3]
>>> list(unsort_generator(lst, sort=True))
[[1, 2, 1, 2, 3, 2, 3], 
 ... 36 more ...
 [3, 2, 3, 2, 1, 2, 1]]
>>> next(unsort_generator([1,1,1]))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
StopIteration

为了解决这个问题,你可以把这个方法和其他答案中提到的算法结合使用,作为备用方案。如果存在一个完美无序的排列,这样做可以确保返回它;如果没有,就会返回一个不错的近似结果。

def unsort_safe(lst):
    try:
        return next(unsort_generator(lst))
    except StopIteration:
        return unsort_fallback(lst)
10

这个算法的思路是:每次选择剩下的最常见的物品,但不能和上一个选择的物品相同,这个方法是对的。下面是一个简单的实现方式,它巧妙地使用了一个堆来跟踪最常见的物品。

import collections, heapq
def nonadjacent(keys):
    heap = [(-count, key) for key, count in collections.Counter(a).items()]
    heapq.heapify(heap)
    count, key = 0, None
    while heap:
        count, key = heapq.heapreplace(heap, (count, key)) if count else heapq.heappop(heap)
        yield key
        count += 1
    for index in xrange(-count):
        yield key

>>> a = [1,2,3,3,2,2,1]
>>> list(nonadjacent(a))
[2, 1, 2, 3, 1, 2, 3]
23

伪代码:

  1. 对列表进行排序
  2. 遍历排序后列表的前半部分,把所有偶数位置填入结果列表
  3. 遍历排序后列表的后半部分,把所有奇数位置填入结果列表

只有当输入中超过一半的元素都是同一个时,p[i]==p[i+1]才会成立。在这种情况下,没办法避免把同一个元素放在相邻的位置(这可以用鸽巢原理来解释)。


正如评论中提到的,这种方法在某些情况下可能会出现一个问题:如果某个元素出现的次数至少是总数的 n/2(或者对于奇数 nn/2+1),那么可能会有太多冲突。最多会有两个这样的元素,如果有两个,算法就能正常工作。唯一的问题是当有一个元素出现的次数至少占了一半。我们可以通过先找到这个元素并优先处理它来解决这个问题。

我对Python不太了解,所以我参考了原作者在GitHub上之前版本的实现:

# Sort the list
a = sorted(lst)

# Put the element occurring more than half of the times in front (if needed)
n = len(a)
m = (n + 1) // 2
for i in range(n - m + 1):
    if a[i] == a[i + m - 1]:
        a = a[i:] + a[:i]
        break

result = [None] * n

# Loop over the first half of the sorted list and fill all even indices of the result list
for i, elt in enumerate(a[:m]):
    result[2*i] = elt

# Loop over the second half of the sorted list and fill all odd indices of the result list
for i, elt in enumerate(a[m:]):
    result[2*i+1] = elt

return result
31

这段内容是关于一个算法的,主要是说如何选择剩下的物品类型中最常见的,除非刚刚选过的那种。你可以参考一下Coady的实现,它展示了这个算法的具体做法。

import collections
import heapq


class Sentinel:
    pass


def david_eisenstat(lst):
    counts = collections.Counter(lst)
    heap = [(-count, key) for key, count in counts.items()]
    heapq.heapify(heap)
    output = []
    last = Sentinel()
    while heap:
        minuscount1, key1 = heapq.heappop(heap)
        if key1 != last or not heap:
            last = key1
            minuscount1 += 1
        else:
            minuscount2, key2 = heapq.heappop(heap)
            last = key2
            minuscount2 += 1
            if minuscount2 != 0:
                heapq.heappush(heap, (minuscount2, key2))
        output.append(last)
        if minuscount1 != 0:
            heapq.heappush(heap, (minuscount1, key1))
    return output

正确性证明

对于两种物品类型,假设它们的数量分别是k1和k2,最优解的缺陷数量如下:如果k1小于k2,缺陷数量是k2 - k1 - 1;如果k1等于k2,缺陷数量是0;如果k1大于k2,缺陷数量是k1 - k2 - 1。显然,当k1等于k2时,情况很明显。其他情况是对称的;少数元素的每个实例最多会导致两个缺陷,而总的可能缺陷数量是k1 + k2 - 1。

这个贪心算法能返回最优解,原因如下。我们称一个前缀(部分解)为安全,如果它能扩展成一个最优解。显然,空前缀是安全的,如果一个安全的前缀已经是完整的解,那么这个解就是最优的。我们只需要通过归纳法证明每一步贪心选择都能保持安全。

贪心选择引入缺陷的唯一情况是只剩下一种物品类型,这样就只有一种继续的方式,而这个方式是安全的。否则,设P是考虑的步骤之前的(安全)前缀,P'是步骤之后的前缀,S是扩展P的最优解。如果S也扩展了P',那我们就完成了。否则,设P' = Px,S = PQ,Q = yQ',其中x和y是物品,Q和Q'是序列。

首先假设P不以y结尾。根据算法的选择,x在Q中的出现频率至少和y一样多。考虑Q中只包含x和y的最大子串。如果第一个子串中x的数量至少和y一样多,那么可以重写这个子串,使其以x开始而不引入额外的缺陷。如果第一个子串中y的数量多于x,那么其他子串中x的数量一定多于y,我们可以重写这些子串,使x在前。无论哪种情况,我们都能找到一个扩展P'的最优解T。

现在假设P以y结尾。我们可以通过将x的第一次出现移动到前面来修改Q。这样做最多会引入一个缺陷(在x原来的位置),同时消除一个缺陷(yy)。

生成所有解

这是tobias_k的回答,加上高效的测试来检测当前选择是否在某种情况下受到全局限制。这个算法的渐进运行时间是最优的,因为生成的开销大约是输出长度的数量级。不幸的是,最坏情况下的延迟是平方级的;如果使用更好的数据结构,可以将其减少到线性(最优)。

from collections import Counter
from itertools import permutations
from operator import itemgetter
from random import randrange


def get_mode(count):
    return max(count.items(), key=itemgetter(1))[0]


def enum2(prefix, x, count, total, mode):
    prefix.append(x)
    count_x = count[x]
    if count_x == 1:
        del count[x]
    else:
        count[x] = count_x - 1
    yield from enum1(prefix, count, total - 1, mode)
    count[x] = count_x
    del prefix[-1]


def enum1(prefix, count, total, mode):
    if total == 0:
        yield tuple(prefix)
        return
    if count[mode] * 2 - 1 >= total and [mode] != prefix[-1:]:
        yield from enum2(prefix, mode, count, total, mode)
    else:
        defect_okay = not prefix or count[prefix[-1]] * 2 > total
        mode = get_mode(count)
        for x in list(count.keys()):
            if defect_okay or [x] != prefix[-1:]:
                yield from enum2(prefix, x, count, total, mode)


def enum(seq):
    count = Counter(seq)
    if count:
        yield from enum1([], count, sum(count.values()), get_mode(count))
    else:
        yield ()


def defects(lst):
    return sum(lst[i - 1] == lst[i] for i in range(1, len(lst)))


def test(lst):
    perms = set(permutations(lst))
    opt = min(map(defects, perms))
    slow = {perm for perm in perms if defects(perm) == opt}
    fast = set(enum(lst))
    print(lst, fast, slow)
    assert slow == fast


for r in range(10000):
    test([randrange(3) for i in range(randrange(6))])

撰写回答