在`itertools`模块中找到给定组合的索引
给定从前 n
个自然数中选出的 k
个数字的组合,我需要找出这个组合在通过 itertools.combination(range(1,n),k)
返回的所有组合中的位置。这样做的原因是,我可以用一个 list
来代替 dict
,通过组合来访问与每个组合相关的值。
由于 itertools
生成的组合是有规律的,所以我可以找到这个位置(我也找到了一个不错的算法),但我在寻找一种更快或更自然的方法,可能我还不知道。
顺便说一下,这是我的解决方案:
def find_idx(comb,n):
k=len(comb)
idx=0
last_c=0
for c in comb:
#idx+=sum(nck(n-2-x,k-1) for x in range(c-last_c-1)) # a little faster without nck caching
idx+=nck(n-1,k)-nck(n-c+last_c,k) # more elegant (thanks to Ray), faster with nck caching
n-=c-last_c
k-=1
last_c=c
return idx
其中 nck
返回 n和k的二项式系数。
举个例子:
comb=list(itertools.combinations(range(1,14),6))[654] #pick the 654th combination
find_idx(comb,14) # -> 654
这里还有一个等效但可能更简单的版本(实际上我从下面的版本推导出了前一个)。我把组合 c
中的整数看作二进制数字中1的位置,我在解析0和1时构建了一个二叉树,并发现了解析过程中索引增加的规律:
def find_idx(comb,n):
k=len(comb)
b=bin(sum(1<<(x-1) for x in comb))[2:]
idx=0
for s in b[::-1]:
if s=='0':
idx+=nck(n-2,k-1)
else:
k-=1
n-=1
return idx
3 个回答
看起来你需要更清楚地说明你的任务,或者我理解错了。对我来说,似乎在使用 itertools.combination
进行迭代时,你可以把需要的索引保存到合适的数据结构里。如果你需要所有的索引,我建议使用 dict
(一个 dict
就能满足你所有的需求):
combinationToIdx = {}
for (idx, comb) in enumerate(itertools.combinations(range(1,14),6)):
combinationToIdx[comb] = idx
def findIdx(comb):
return combinationToIdx[comb]
我找到了一些旧代码(虽然已经转换成Python 3的语法),里面有一个叫做 combination_index
的函数,它正好可以满足你的需求:
def fact(n, _f=[1, 1, 2, 6, 24, 120, 720]):
"""Return n!
The “hidden” list _f acts as a cache"""
try:
return _f[n]
except IndexError:
while len(_f) <= n:
_f.append(_f[-1] * len(_f))
return _f[n]
def indexed_combination(n: int, k: int, index: int) -> tuple:
"""Select the 'index'th combination of k over n
Result is a tuple (i | i∈{0…n-1}) of length k
Note that if index ≥ binomial_coefficient(n,k)
then the result is almost always invalid"""
result= []
for item, n in enumerate(range(n, -1, -1)):
pivot= fact(n-1)//fact(k-1)//fact(n-k)
if index < pivot:
result.append(item)
k-= 1
if k <= 0: break
else:
index-= pivot
return tuple(result)
def combination_index(combination: tuple, n: int) -> int:
"""Return the index of combination (length == k)
The combination argument should be a sorted sequence (i | i∈{0…n-1})"""
k= len(combination)
index= 0
item_in_check= 0
n-= 1 # to simplify subsequent calculations
for offset, item in enumerate(combination, 1):
while item_in_check < item:
index+= fact(n-item_in_check)//fact(k-offset)//fact(n+offset-item_in_check-k)
item_in_check+= 1
item_in_check+= 1
return index
def test():
for n in range(1, 11):
for k in range(1, n+1):
max_index= fact(n)//fact(k)//fact(n-k)
for i in range(max_index):
comb= indexed_combination(n, k, i)
i2= combination_index(comb, n)
if i2 != i:
raise RuntimeError("mismatching n:%d k:%d i:%d≠%d" % (n, k, i, i2))
而 indexed_combination
则是做反向操作的。
顺便说一下,我记得我曾经尝试去掉所有那些 fact
的调用(通过适当的增量乘法和除法来替代),但结果代码变得复杂得多,而且实际上也没有更快。如果我用一个预先计算好的阶乘列表来替代 fact
函数,确实可以提高速度,但对于我的使用场景来说,速度差异微乎其微,所以我还是保留了这个版本。
你的解决方案看起来很快。在 find_idx
函数里,你用了两个循环,里面的循环可以通过一个公式来优化:
C(n, k) + C(n-1, k) + ... + C(n-r, k) = C(n+1, k+1) - C(n-r, k+1)
所以,你可以把 sum(nck(n-2-x,k-1) for x in range(c-last_c-1))
替换成 nck(n-1, k) - nck(n-c+last_c, k)
。
我不知道你是怎么实现你的 nck(n, k)
函数的,但它的时间复杂度应该是 O(k)。这里我提供我的实现方式:
from operator import mul
from functools import reduce # In python 3
def nck_safe(n, k):
if k < 0 or n < k: return 0
return reduce(mul, range(n, n-k, -1), 1) // reduce(mul, range(1, k+1), 1)
最后,你的解决方案变成了 O(k^2),而且没有使用递归。这还是挺快的,因为 k
不会太大。
更新
我注意到 nck
的参数是 (n, k)
。这两个值都不会太大。我们可以通过缓存来加速程序。
def nck(n, k, _cache={}):
if (n, k) in _cache: return _cache[n, k]
....
# before returning the result
_cache[n, k] = result
return result
在 Python3 中,可以使用 functools.lru_cache
装饰器来实现:
@functools.lru_cache(maxsize=500)
def nck(n, k):
...