如何按和的顺序遍历大量整数元组?
我正在使用 itertools.combinations()
来处理一些整数的组合。
我想找到一个满足特定条件的元组,这个元组的 和是最小的:
def findLowestNiceTuple:
for tup in itertools.combinations(range(1, 6), 2):
if niceTuple(tup):
return tup
这个生成器默认的顺序并不是按照元素的和来排列的。例如:
>>> itertools.combinations(range(1, 6), 2)
生成的结果会是这样的:
[(1, 2), (1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5), (3, 4), (3, 5), (4, 5)]
你可以看到,(1, 5) 的和比 (2, 3) 的和要大。为了能够提前结束,我需要这些元组按照 ..., (1, 4), (2, 3), (1, 5), ...
这样的顺序排列。
对于组合数量不多的情况,你可以通过使用 sorted()
来解决这个问题:
>>> sorted(itertools.combinations(range(1, 6), 2), key=sum)
[(1, 2), (1, 3), (1, 4), (2, 3), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5), (4, 5)]
但是,sorted()
会把生成器转换成一个完整的列表,这样会占用很多内存。这意味着当组合数量很大时,它的表现就不好了。比如 itertools.combinations(range(1, 600), 400)
这样的操作就很可能会导致 MemoryError
。
有没有更节省内存的方法来达到想要的结果呢?
附注:我知道完全遍历我提到的最后一个序列会花费很长时间,但我正在寻找的那个元组应该离开头很近。如果我能依靠这个顺序,我就可以像第一个代码片段那样提前结束。
2 个回答
你可以通过遍历可能的和的范围来获得一个纯粹的迭代器,这样只需要 O(1) 的空间,而不是生成所有组合。根据和的值,输出产生这个和的数字对的子范围:
def sumPairs(minVal,maxVal):
for total in range(minVal*2,maxVal*2+1):
for a in range(max(total-maxVal,minVal),min(total-minVal,maxVal)+1):
if a >= total-a: continue # to skip permutations
yield (a,total-a)
输出:
for pair in sumPairs(1,6):
print(pair,sum(pair))
(1, 2) 3
(1, 3) 4
(1, 4) 5
(2, 3) 5
(1, 5) 6
(2, 4) 6
(1, 6) 7
(2, 5) 7
(3, 4) 7
(2, 6) 8
(3, 5) 8
(3, 6) 9
(4, 5) 9
(4, 6) 10
(5, 6) 11
[编辑] n元组的推广
如果想要超越数字对,同时保持 O(1) 的空间复杂度,就有点复杂了。我成功地将空间复杂度控制在 O(S),其中 S 是元组的大小。所以这个解决方案的内存消耗和数字的范围无关。
同样的策略用于遍历基于期望和的组合,但对于每个和值,可能有多个组合可以得到相同的和。这些组合可以通过从一个基础组合开始生成,这个基础组合是从最小的数字开始,到能够达到这个总和的最大的数字为止。这是给定和的可能值的最广泛分布。所有其他的值组合都是通过逐渐降低最后一个值,同时将较小的值向上调整来补偿最后一个值的减少。
# generate offsets in increasing order (>=)
# to produce a total value
def getOffsets(size,total,maxValue):
#print(size,total,maxValue)
if not total: yield [0]*size; return
if size == 1 and total==maxValue: yield [maxValue]; return
while total>=0 and size*maxValue>=total:
for prefix in getOffsets(size-1,total-maxValue,maxValue):
yield prefix + [maxValue]
maxValue -= 1
# generate all combinations of a range of values
# that produce a given total
def comboOfSum(total,size,minValue,maxValue):
if size == 1: yield (total,); return
base = list(range(minValue,minValue+size)) # start with smallest(s)
base[-1] = min(total-sum(base[:-1]),maxValue) # end with largest
maxOffset = base[-1]-base[-2]-1 # freedom of moving smaller values
totalOffset = total-sum(base) # compensate decreasing last
minLast = (total + size*(size-1)//2)//size # minimum to reach total
while base[-1]>base[-2] and base[-1] >= minLast:
for offsets in getOffsets(size-1,totalOffset,maxOffset):
yield tuple(b+o for b,o in zip(base,offsets+[0])) # apply offsets
base[-1] -= 1 # decrease last value
totalOffset += 1 # increase total to compensate for decrease
maxOffset -= 1 # decrease small values' freedom of movement
# generate combinations in order of target sum
def comboBySum(size,minValue,maxValue):
minTotal = minValue*size + size*(size-1)//2
maxTotal = maxValue*size - size*(size-1)//2
for total in range(minTotal,maxTotal+1):
yield from comboOfSum(total,size,minValue,maxValue)
验证:(与排序后的组合进行比较)
size = 4
minVal = 10
maxVal = 80
from itertools import combinations
A = list(comboBySum(size,minVal,maxVal))
B = list(sorted(combinations(range(minVal,maxVal+1),size),key=sum))
print("same content:",set(A)==set(B)) # True
print("order by sum:",[*map(sum,A)]==[*map(sum,B)]) # True
输出(小规模):
for combo in comboBySum(2,1,6):print(combo,sum(combo))
(1, 2) 3
(1, 3) 4
(1, 4) 5
(2, 3) 5
(1, 5) 6
(2, 4) 6
(1, 6) 7
(2, 5) 7
(3, 4) 7
(2, 6) 8
(3, 5) 8
(3, 6) 9
(4, 5) 9
(4, 6) 10
(5, 6) 11
输出(大规模):
for i,combo in enumerate(comboBySum(400,1,800)):
print(*combo[:5],"...",*combo[-5:],"sum =",sum(combo))
if i>20: break
1 2 3 4 5 ... 396 397 398 399 400 sum = 80200
1 2 3 4 5 ... 396 397 398 399 401 sum = 80201
1 2 3 4 5 ... 396 397 398 399 402 sum = 80202
1 2 3 4 5 ... 396 397 398 400 401 sum = 80202
1 2 3 4 5 ... 396 397 398 399 403 sum = 80203
1 2 3 4 5 ... 396 397 398 400 402 sum = 80203
1 2 3 4 5 ... 396 397 399 400 401 sum = 80203
1 2 3 4 5 ... 396 397 398 399 404 sum = 80204
1 2 3 4 5 ... 396 397 398 400 403 sum = 80204
1 2 3 4 5 ... 396 397 398 401 402 sum = 80204
1 2 3 4 5 ... 396 397 399 400 402 sum = 80204
1 2 3 4 5 ... 396 398 399 400 401 sum = 80204
1 2 3 4 5 ... 396 397 398 399 405 sum = 80205
1 2 3 4 5 ... 396 397 398 400 404 sum = 80205
1 2 3 4 5 ... 396 397 398 401 403 sum = 80205
1 2 3 4 5 ... 396 397 399 400 403 sum = 80205
1 2 3 4 5 ... 396 397 399 401 402 sum = 80205
1 2 3 4 5 ... 396 398 399 400 402 sum = 80205
1 2 3 4 5 ... 397 398 399 400 401 sum = 80205
1 2 3 4 5 ... 396 397 398 399 406 sum = 80206
1 2 3 4 5 ... 396 397 398 400 405 sum = 80206
1 2 3 4 5 ... 396 397 398 401 404 sum = 80206
输出(大数字范围):
for i,combo in enumerate(comboBySum(20,12345,1000000)):
print(*combo[:5],"...",*combo[-5:],"sum =",sum(combo))
if i>20: break
12345 12346 12347 12348 12349 ... 12360 12361 12362 12363 12364 sum = 247090
12345 12346 12347 12348 12349 ... 12360 12361 12362 12363 12365 sum = 247091
12345 12346 12347 12348 12349 ... 12360 12361 12362 12363 12366 sum = 247092
12345 12346 12347 12348 12349 ... 12360 12361 12362 12364 12365 sum = 247092
12345 12346 12347 12348 12349 ... 12360 12361 12362 12363 12367 sum = 247093
12345 12346 12347 12348 12349 ... 12360 12361 12362 12364 12366 sum = 247093
12345 12346 12347 12348 12349 ... 12360 12361 12363 12364 12365 sum = 247093
12345 12346 12347 12348 12349 ... 12360 12361 12362 12363 12368 sum = 247094
12345 12346 12347 12348 12349 ... 12360 12361 12362 12364 12367 sum = 247094
12345 12346 12347 12348 12349 ... 12360 12361 12362 12365 12366 sum = 247094
12345 12346 12347 12348 12349 ... 12360 12361 12363 12364 12366 sum = 247094
12345 12346 12347 12348 12349 ... 12360 12362 12363 12364 12365 sum = 247094
12345 12346 12347 12348 12349 ... 12360 12361 12362 12363 12369 sum = 247095
12345 12346 12347 12348 12349 ... 12360 12361 12362 12364 12368 sum = 247095
12345 12346 12347 12348 12349 ... 12360 12361 12362 12365 12367 sum = 247095
12345 12346 12347 12348 12349 ... 12360 12361 12363 12364 12367 sum = 247095
12345 12346 12347 12348 12349 ... 12360 12361 12363 12365 12366 sum = 247095
12345 12346 12347 12348 12349 ... 12360 12362 12363 12364 12366 sum = 247095
12345 12346 12347 12348 12349 ... 12361 12362 12363 12364 12365 sum = 247095
12345 12346 12347 12348 12349 ... 12360 12361 12362 12363 12370 sum = 247096
12345 12346 12347 12348 12349 ... 12360 12361 12362 12364 12369 sum = 247096
12345 12346 12347 12348 12349 ... 12360 12361 12362 12365 12368 sum = 247096
下面是我解决这个问题的方法,使用一个递归函数来找到所有加起来等于给定值的组合:
def ordered_combinations(pop, n):
pop = sorted(pop)
for s in range(sum(pop[:n]), sum(pop[-n:])+1):
yield from get_sums(pop, s, n)
def get_sums(pop, s, n):
if n == 1:
if s in pop:
yield [s]
return
for i, v in enumerate(pop):
if sum(pop[i:i+n]) > s:
return
for rest in get_sums(pop[i+1:], s-v, n-1):
rest.append(v)
yield rest
这是它输出的一个例子:
>>> for c in ordered_combinations(range(1, 8), 4):
print(c, sum(c))
[4, 3, 2, 1] 10
[5, 3, 2, 1] 11
[6, 3, 2, 1] 12
[5, 4, 2, 1] 12
[7, 3, 2, 1] 13
[6, 4, 2, 1] 13
[5, 4, 3, 1] 13
[7, 4, 2, 1] 14
[6, 5, 2, 1] 14
[6, 4, 3, 1] 14
[5, 4, 3, 2] 14
[7, 5, 2, 1] 15
[7, 4, 3, 1] 15
[6, 5, 3, 1] 15
[6, 4, 3, 2] 15
[7, 6, 2, 1] 16
[7, 5, 3, 1] 16
[6, 5, 4, 1] 16
[7, 4, 3, 2] 16
[6, 5, 3, 2] 16
[7, 6, 3, 1] 17
[7, 5, 4, 1] 17
[7, 5, 3, 2] 17
[6, 5, 4, 2] 17
[7, 6, 4, 1] 18
[7, 6, 3, 2] 18
[7, 5, 4, 2] 18
[6, 5, 4, 3] 18
[7, 6, 5, 1] 19
[7, 6, 4, 2] 19
[7, 5, 4, 3] 19
[7, 6, 5, 2] 20
[7, 6, 4, 3] 20
[7, 6, 5, 3] 21
[7, 6, 5, 4] 22
这些组合总是先给出最大的值,这是因为我在构建列表时是把小的值加到最后,而不是放到前面。如果你想让它们从小到大排列,可以把 rest.append(v); yield rest
这两行改成 yield [v]+rest
。
这段代码使用了 yield from
这种语法,这是在 Python 3.3 中引入的。如果你使用的是早期版本,不支持这个语法,可以用下面的等效代码:
for v in get_sums(pop, s, n):
yield v
这段代码甚至可以处理你提到的极端情况,比如从800个成员中取400个组合。以下是计算的前二十个结果(只显示它们最大的10个值,因为其余的都是390到1),以及它们的和:
>>> for i, v in enumerate(ordered_combinations(range(1, 800), 400)):
if i >= 20:
break
print(v[:10], sum(v))
[400, 399, 398, 397, 396, 395, 394, 393, 392, 391] 80200
[401, 399, 398, 397, 396, 395, 394, 393, 392, 391] 80201
[402, 399, 398, 397, 396, 395, 394, 393, 392, 391] 80202
[401, 400, 398, 397, 396, 395, 394, 393, 392, 391] 80202
[403, 399, 398, 397, 396, 395, 394, 393, 392, 391] 80203
[402, 400, 398, 397, 396, 395, 394, 393, 392, 391] 80203
[401, 400, 399, 397, 396, 395, 394, 393, 392, 391] 80203
[404, 399, 398, 397, 396, 395, 394, 393, 392, 391] 80204
[403, 400, 398, 397, 396, 395, 394, 393, 392, 391] 80204
[402, 401, 398, 397, 396, 395, 394, 393, 392, 391] 80204
[402, 400, 399, 397, 396, 395, 394, 393, 392, 391] 80204
[401, 400, 399, 398, 396, 395, 394, 393, 392, 391] 80204
[405, 399, 398, 397, 396, 395, 394, 393, 392, 391] 80205
[404, 400, 398, 397, 396, 395, 394, 393, 392, 391] 80205
[403, 401, 398, 397, 396, 395, 394, 393, 392, 391] 80205
[403, 400, 399, 397, 396, 395, 394, 393, 392, 391] 80205
[402, 401, 399, 397, 396, 395, 394, 393, 392, 391] 80205
[402, 400, 399, 398, 396, 395, 394, 393, 392, 391] 80205
[401, 400, 399, 398, 397, 395, 394, 393, 392, 391] 80205
[406, 399, 398, 397, 396, 395, 394, 393, 392, 391] 80206
因为这是递归的,所以如果你请求1000个组合,这段代码可能会失败(这是因为 Python 默认的递归限制)。如果需要的话,你可以用 sys.setrecursionlimit
来修改这个限制。
如果你处理的范围非常大,代码可能还会有内存问题,因为 get_sums
在递归步骤中会切片(也就是复制)这个范围。如果你只打算使用 range
,可以通过从 ordered_combinations
中删除 pop = sorted(pop)
这一行来解决内存问题,因为 Python 3 的 range
对象可以高效地切片(比如 range(1,100)[10:]
就是 range(11,100)
)。