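How can I better implement all combinations of a jagged list in lexicographic order?
Today I ran into a situation where I needed to list every possible combination drawn from a jagged list (a list of lists of differing lengths), taking one item from each sublist. A naive approach would be, for example: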
for a in [1,2,3]:
    for b in [4,5,6,7,8,9]:
        for c in [1,2]:
            yield (a,b,c)
This works, but it does not generalize: the number of lists is hard-coded into the nesting depth. Here is a more general approach:
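from numpy import zeros, array, nonzero, max

# Pick one element from each sublist: x[i][y[i]] for every position i.
make_subset = lambda x,y: [x[i][j] for i,j in enumerate(y)]

def combinations(items):
    # Largest valid index for each sublist.
    num_items = [len(i) - 1 for i in items]
    # state[i] is the current index into items[i]; it advances like an odometer.
    state = zeros(len(items), dtype=int)
    finished = array(num_items, dtype=int)
    yield make_subset(items, state)
    while True:
        if state[-1] != num_items[-1]:
            state[-1] += 1
            yield make_subset(items, state)
        else:
            # Rightmost position is exhausted: find positions that can still advance.
            incrementable = nonzero(state != finished)[0]
            if not len(incrementable):
                # Every position is at its last index, so the generator ends here.
                return
            rightmost = max(incrementable)
            state[rightmost] += 1
            state[rightmost+1:] = 0
            yield make_subset(items, state)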
Are there better approaches you would recommend, or any objections to the approach above?
1 Answer
The naive approach can be written more compactly as a generator expression:
((a,b,c) for a in [1,2,3] for b in [4,5,6,7,8,9] for c in [1,2])
The general approach can be written more simply as a recursive function:
def combinations(*seqs):
    if not seqs: return (item for item in ())
    first, rest = seqs[0], seqs[1:]
    if not rest: return ((item,) for item in first)
    return ((item,) + items for item in first for items in combinations(*rest))
Here is a sample usage:
>>> for pair in combinations('abc', [1,2,3]):
...     print(pair)
...
('a', 1)
('a', 2)
('a', 3)
('b', 1)
('b', 2)
('b', 3)
('c', 1)
('c', 2)
('c', 3)
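As an alternative to the recursive function above, the standard library's `itertools.product` produces the same tuples in the same order (rightmost sequence varying fastest). The following is only a sketch: the wrapper name `all_combinations` and the variable `ragged` are hypothetical and introduced here for illustration, with the sample lists taken from the question.

```python
from itertools import product

def all_combinations(*seqs):
    """Hypothetical wrapper: yield one tuple per combination,
    taking one item from each sequence, rightmost varying fastest."""
    return product(*seqs)

# The jagged list from the question (sublists of different lengths).
ragged = [[1, 2, 3], [4, 5, 6, 7, 8, 9], [1, 2]]

combos = list(all_combinations(*ragged))
print(len(combos))   # 3 * 6 * 2 = 36
print(combos[:3])    # [(1, 4, 1), (1, 4, 2), (1, 5, 1)]
```

One difference to keep in mind: called with no sequences at all, `itertools.product()` yields a single empty tuple, whereas the recursive function above yields nothing.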