如何按和的顺序遍历大量整数元组?

3 投票
2 回答
1737 浏览
提问于 2025-04-17 15:49

我正在使用 itertools.combinations() 来处理一些整数的组合。

我想找到一个满足特定条件的元组,这个元组的 和是最小的

def findLowestNiceTuple:
    for tup in itertools.combinations(range(1, 6), 2):
        if niceTuple(tup):
            return tup

这个生成器默认的顺序并不是按照元素的和来排列的。例如:

>>> itertools.combinations(range(1, 6), 2)

生成的结果会是这样的:

[(1, 2), (1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5), (3, 4), (3, 5), (4, 5)]

你可以看到,(1, 5) 的和比 (2, 3) 的和要大。为了能够提前结束,我需要这些元组按照 ..., (1, 4), (2, 3), (1, 5), ... 这样的顺序排列。

对于组合数量不多的情况,你可以通过使用 sorted() 来解决这个问题:

>>> sorted(itertools.combinations(range(1, 6), 2), key=sum)
[(1, 2), (1, 3), (1, 4), (2, 3), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5), (4, 5)]

但是,sorted() 会把生成器转换成一个完整的列表,这样会占用很多内存。这意味着当组合数量很大时,它的表现就不好了。比如 itertools.combinations(range(1, 600), 400) 这样的操作就很可能会导致 MemoryError

有没有更节省内存的方法来达到想要的结果呢?

附注:我知道完全遍历我提到的最后一个序列会花费很长时间,但我正在寻找的那个元组应该离开头很近。如果我能依靠这个顺序,我就可以像第一个代码片段那样提前结束。

2 个回答

1

你可以通过遍历可能的和的范围来获得一个纯粹的迭代器,这样只需要 O(1) 的空间,而不是生成所有组合。根据和的值,输出产生这个和的数字对的子范围:

def sumPairs(minVal,maxVal):
    for total in range(minVal*2,maxVal*2+1):
        for a in range(max(total-maxVal,minVal),min(total-minVal,maxVal)+1):
            if a >= total-a: continue # to skip permutations
            yield (a,total-a)

输出:

for pair in sumPairs(1,6):
    print(pair,sum(pair))

(1, 2) 3
(1, 3) 4
(1, 4) 5
(2, 3) 5
(1, 5) 6
(2, 4) 6
(1, 6) 7
(2, 5) 7
(3, 4) 7
(2, 6) 8
(3, 5) 8
(3, 6) 9
(4, 5) 9
(4, 6) 10
(5, 6) 11

[编辑] n元组的推广

如果想要超越数字对,同时保持 O(1) 的空间复杂度,就有点复杂了。我成功地将空间复杂度控制在 O(S),其中 S 是元组的大小。所以这个解决方案的内存消耗和数字的范围无关。

同样的策略用于遍历基于期望和的组合,但对于每个和值,可能有多个组合可以得到相同的和。这些组合可以通过从一个基础组合开始生成,这个基础组合是从最小的数字开始,到能够达到这个总和的最大的数字为止。这是给定和的可能值的最广泛分布。所有其他的值组合都是通过逐渐降低最后一个值,同时将较小的值向上调整来补偿最后一个值的减少。

# generate offsets in increasing order (>=)
# to produce a total value
def getOffsets(size,total,maxValue):
    #print(size,total,maxValue)
    if not total: yield [0]*size; return
    if size == 1 and total==maxValue: yield [maxValue]; return
    while total>=0 and size*maxValue>=total:
        for prefix in getOffsets(size-1,total-maxValue,maxValue):
            yield prefix + [maxValue]
        maxValue -= 1

# generate all combinations of a range of values
# that produce a given total
def comboOfSum(total,size,minValue,maxValue):
    if size == 1: yield (total,); return
    base        = list(range(minValue,minValue+size)) # start with smallest(s)
    base[-1]    = min(total-sum(base[:-1]),maxValue)  # end with largest
    maxOffset   = base[-1]-base[-2]-1 # freedom of moving smaller values
    totalOffset = total-sum(base)     # compensate decreasing last
    minLast     = (total + size*(size-1)//2)//size # minimum to reach total
    while base[-1]>base[-2] and base[-1] >= minLast:
        for offsets in getOffsets(size-1,totalOffset,maxOffset):
            yield tuple(b+o for b,o in zip(base,offsets+[0])) # apply offsets
        base[-1]    -= 1 # decrease last value
        totalOffset += 1 # increase total to compensate for decrease
        maxOffset   -= 1 # decrease small values' freedom of movement

# generate combinations in order of target sum  
def comboBySum(size,minValue,maxValue):
    minTotal = minValue*size + size*(size-1)//2
    maxTotal = maxValue*size - size*(size-1)//2
    for total in range(minTotal,maxTotal+1):
        yield from comboOfSum(total,size,minValue,maxValue)

验证:(与排序后的组合进行比较)

size   = 4
minVal = 10
maxVal = 80

from itertools import combinations

A = list(comboBySum(size,minVal,maxVal))
B = list(sorted(combinations(range(minVal,maxVal+1),size),key=sum))

print("same content:",set(A)==set(B))               # True
print("order by sum:",[*map(sum,A)]==[*map(sum,B)]) # True

输出(小规模):

for combo in comboBySum(2,1,6):print(combo,sum(combo))

(1, 2) 3
(1, 3) 4
(1, 4) 5
(2, 3) 5
(1, 5) 6
(2, 4) 6
(1, 6) 7
(2, 5) 7
(3, 4) 7
(2, 6) 8
(3, 5) 8
(3, 6) 9
(4, 5) 9
(4, 6) 10
(5, 6) 11

输出(大规模):

for i,combo in enumerate(comboBySum(400,1,800)):
    print(*combo[:5],"...",*combo[-5:],"sum =",sum(combo))
    if i>20: break

1 2 3 4 5 ... 396 397 398 399 400 sum = 80200
1 2 3 4 5 ... 396 397 398 399 401 sum = 80201
1 2 3 4 5 ... 396 397 398 399 402 sum = 80202
1 2 3 4 5 ... 396 397 398 400 401 sum = 80202
1 2 3 4 5 ... 396 397 398 399 403 sum = 80203
1 2 3 4 5 ... 396 397 398 400 402 sum = 80203
1 2 3 4 5 ... 396 397 399 400 401 sum = 80203
1 2 3 4 5 ... 396 397 398 399 404 sum = 80204
1 2 3 4 5 ... 396 397 398 400 403 sum = 80204
1 2 3 4 5 ... 396 397 398 401 402 sum = 80204
1 2 3 4 5 ... 396 397 399 400 402 sum = 80204
1 2 3 4 5 ... 396 398 399 400 401 sum = 80204
1 2 3 4 5 ... 396 397 398 399 405 sum = 80205
1 2 3 4 5 ... 396 397 398 400 404 sum = 80205
1 2 3 4 5 ... 396 397 398 401 403 sum = 80205
1 2 3 4 5 ... 396 397 399 400 403 sum = 80205
1 2 3 4 5 ... 396 397 399 401 402 sum = 80205
1 2 3 4 5 ... 396 398 399 400 402 sum = 80205
1 2 3 4 5 ... 397 398 399 400 401 sum = 80205
1 2 3 4 5 ... 396 397 398 399 406 sum = 80206
1 2 3 4 5 ... 396 397 398 400 405 sum = 80206
1 2 3 4 5 ... 396 397 398 401 404 sum = 80206

输出(大数字范围):

for i,combo in enumerate(comboBySum(20,12345,1000000)):
    print(*combo[:5],"...",*combo[-5:],"sum =",sum(combo))
    if i>20: break

12345 12346 12347 12348 12349 ... 12360 12361 12362 12363 12364 sum = 247090
12345 12346 12347 12348 12349 ... 12360 12361 12362 12363 12365 sum = 247091
12345 12346 12347 12348 12349 ... 12360 12361 12362 12363 12366 sum = 247092
12345 12346 12347 12348 12349 ... 12360 12361 12362 12364 12365 sum = 247092
12345 12346 12347 12348 12349 ... 12360 12361 12362 12363 12367 sum = 247093
12345 12346 12347 12348 12349 ... 12360 12361 12362 12364 12366 sum = 247093
12345 12346 12347 12348 12349 ... 12360 12361 12363 12364 12365 sum = 247093
12345 12346 12347 12348 12349 ... 12360 12361 12362 12363 12368 sum = 247094
12345 12346 12347 12348 12349 ... 12360 12361 12362 12364 12367 sum = 247094
12345 12346 12347 12348 12349 ... 12360 12361 12362 12365 12366 sum = 247094
12345 12346 12347 12348 12349 ... 12360 12361 12363 12364 12366 sum = 247094
12345 12346 12347 12348 12349 ... 12360 12362 12363 12364 12365 sum = 247094
12345 12346 12347 12348 12349 ... 12360 12361 12362 12363 12369 sum = 247095
12345 12346 12347 12348 12349 ... 12360 12361 12362 12364 12368 sum = 247095
12345 12346 12347 12348 12349 ... 12360 12361 12362 12365 12367 sum = 247095
12345 12346 12347 12348 12349 ... 12360 12361 12363 12364 12367 sum = 247095
12345 12346 12347 12348 12349 ... 12360 12361 12363 12365 12366 sum = 247095
12345 12346 12347 12348 12349 ... 12360 12362 12363 12364 12366 sum = 247095
12345 12346 12347 12348 12349 ... 12361 12362 12363 12364 12365 sum = 247095
12345 12346 12347 12348 12349 ... 12360 12361 12362 12363 12370 sum = 247096
12345 12346 12347 12348 12349 ... 12360 12361 12362 12364 12369 sum = 247096
12345 12346 12347 12348 12349 ... 12360 12361 12362 12365 12368 sum = 247096
3

下面是我解决这个问题的方法,使用一个递归函数来找到所有加起来等于给定值的组合:

def ordered_combinations(pop, n):
    pop = sorted(pop)

    for s in range(sum(pop[:n]), sum(pop[-n:])+1):
        yield from get_sums(pop, s, n)

def get_sums(pop, s, n):
    if n == 1:
        if s in pop:
            yield [s]
        return

    for i, v in enumerate(pop):
        if sum(pop[i:i+n]) > s:
            return
        for rest in get_sums(pop[i+1:], s-v, n-1):
            rest.append(v)
            yield rest

这是它输出的一个例子:

>>> for c in ordered_combinations(range(1, 8), 4):
    print(c, sum(c))


[4, 3, 2, 1] 10
[5, 3, 2, 1] 11
[6, 3, 2, 1] 12
[5, 4, 2, 1] 12
[7, 3, 2, 1] 13
[6, 4, 2, 1] 13
[5, 4, 3, 1] 13
[7, 4, 2, 1] 14
[6, 5, 2, 1] 14
[6, 4, 3, 1] 14
[5, 4, 3, 2] 14
[7, 5, 2, 1] 15
[7, 4, 3, 1] 15
[6, 5, 3, 1] 15
[6, 4, 3, 2] 15
[7, 6, 2, 1] 16
[7, 5, 3, 1] 16
[6, 5, 4, 1] 16
[7, 4, 3, 2] 16
[6, 5, 3, 2] 16
[7, 6, 3, 1] 17
[7, 5, 4, 1] 17
[7, 5, 3, 2] 17
[6, 5, 4, 2] 17
[7, 6, 4, 1] 18
[7, 6, 3, 2] 18
[7, 5, 4, 2] 18
[6, 5, 4, 3] 18
[7, 6, 5, 1] 19
[7, 6, 4, 2] 19
[7, 5, 4, 3] 19
[7, 6, 5, 2] 20
[7, 6, 4, 3] 20
[7, 6, 5, 3] 21
[7, 6, 5, 4] 22

这些组合总是先给出最大的值,这是因为我在构建列表时是把小的值加到最后,而不是放到前面。如果你想让它们从小到大排列,可以把 rest.append(v); yield rest 这两行改成 yield [v]+rest

这段代码使用了 yield from 这种语法,这是在 Python 3.3 中引入的。如果你使用的是早期版本,不支持这个语法,可以用下面的等效代码:

for v in get_sums(pop, s, n):
    yield v

这段代码甚至可以处理你提到的极端情况,比如从800个成员中取400个组合。以下是计算的前二十个结果(只显示它们最大的10个值,因为其余的都是390到1),以及它们的和:

>>> for i, v in enumerate(ordered_combinations(range(1, 800), 400)):
    if i >= 20:
        break
    print(v[:10], sum(v))


[400, 399, 398, 397, 396, 395, 394, 393, 392, 391] 80200
[401, 399, 398, 397, 396, 395, 394, 393, 392, 391] 80201
[402, 399, 398, 397, 396, 395, 394, 393, 392, 391] 80202
[401, 400, 398, 397, 396, 395, 394, 393, 392, 391] 80202
[403, 399, 398, 397, 396, 395, 394, 393, 392, 391] 80203
[402, 400, 398, 397, 396, 395, 394, 393, 392, 391] 80203
[401, 400, 399, 397, 396, 395, 394, 393, 392, 391] 80203
[404, 399, 398, 397, 396, 395, 394, 393, 392, 391] 80204
[403, 400, 398, 397, 396, 395, 394, 393, 392, 391] 80204
[402, 401, 398, 397, 396, 395, 394, 393, 392, 391] 80204
[402, 400, 399, 397, 396, 395, 394, 393, 392, 391] 80204
[401, 400, 399, 398, 396, 395, 394, 393, 392, 391] 80204
[405, 399, 398, 397, 396, 395, 394, 393, 392, 391] 80205
[404, 400, 398, 397, 396, 395, 394, 393, 392, 391] 80205
[403, 401, 398, 397, 396, 395, 394, 393, 392, 391] 80205
[403, 400, 399, 397, 396, 395, 394, 393, 392, 391] 80205
[402, 401, 399, 397, 396, 395, 394, 393, 392, 391] 80205
[402, 400, 399, 398, 396, 395, 394, 393, 392, 391] 80205
[401, 400, 399, 398, 397, 395, 394, 393, 392, 391] 80205
[406, 399, 398, 397, 396, 395, 394, 393, 392, 391] 80206

因为这是递归的,所以如果你请求1000个组合,这段代码可能会失败(这是因为 Python 默认的递归限制)。如果需要的话,你可以用 sys.setrecursionlimit 来修改这个限制。

如果你处理的范围非常大,代码可能还会有内存问题,因为 get_sums 在递归步骤中会切片(也就是复制)这个范围。如果你只打算使用 range,可以通过从 ordered_combinations 中删除 pop = sorted(pop) 这一行来解决内存问题,因为 Python 3 的 range 对象可以高效地切片(比如 range(1,100)[10:] 就是 range(11,100))。

撰写回答