检测序列是否为子序列的倍数(Python)
我有一个由零和一组成的元组,比如:
(1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1)
结果是:
(1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1) == (1, 0, 1, 1) * 3
我想要一个函数 f
,这个函数的作用是:如果 s
是一个非空的零和一的元组,那么 f(s)
会返回一个最短的子元组 r
,使得 s == r * n
,其中 n
是一个正整数。
举个例子,
f( (1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1) ) == (1, 0, 1, 1)
有没有什么简单的方法可以在 Python 中写这个函数 f
?
补充:
我现在使用的简单方法是
def f(s):
for i in range(1,len(s)):
if len(s)%i == 0 and s == s[:i] * (len(s)/i):
return s[:i]
7 个回答
简化Knoothe的解决方案。他的算法是对的,但他的实现方式太复杂了。这种实现的时间复杂度也是O(n)。
因为你的数组只由1和0组成,我用现有的str.find实现(Bayer Moore算法)来实现Knoothe的想法。这样做意外地简单,而且运行速度快得惊人。
def f(s):
s2 = ''.join(map(str, s))
return s[:(s2+s2).index(s2, 1)]
下面这个解决方案的时间复杂度是O(N^2),但它的好处是不会创建任何数据的副本(或者切片),因为它是基于迭代器的。
根据你输入数据的大小,避免复制数据可以显著提高速度,但当然,对于非常大的输入,这种方法的表现不如复杂度更低的算法(比如O(N*logN))那么好。
[这是我解决方案的第二个版本,下面是第一个版本。这个版本更容易理解,更接近于提问者的元组乘法,只是使用了迭代器。]
from itertools import izip, chain, tee
def iter_eq(seq1, seq2):
""" assumes the sequences have the same len """
return all( v1 == v2 for v1, v2 in izip(seq1, seq2) )
def dup_seq(seq, n):
""" returns an iterator which is seq chained to itself n times """
return chain(*tee(seq, n))
def is_reps(arr, slice_size):
if len(arr) % slice_size != 0:
return False
num_slices = len(arr) / slice_size
return iter_eq(arr, dup_seq(arr[:slice_size], num_slices))
s = (1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1)
for i in range(1,len(s)):
if is_reps(s, i):
print i, s[:i]
break
[我最初的解决方案]
from itertools import islice
def is_reps(arr, num_slices):
if len(arr) % num_slices != 0:
return False
slice_size = len(arr) / num_slices
for i in xrange(slice_size):
if len(set( islice(arr, i, None, num_slices) )) > 1:
return False
return True
s = (1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1)
for i in range(1,len(s)):
if is_reps(s, i):
print i, s[:i]
break
你可以通过使用类似的方式来避免调用set()
:
def is_iter_unique(seq):
""" a faster version of testing len(set(seq)) <= 1 """
seen = set()
for x in seq:
seen.add(x)
if len(seen) > 1:
return False
return True
并将这一行替换为:
if len(set( islice(arr, i, None, num_slices) )) > 1:
用:
if not is_iter_unique(islice(arr, i, None, num_slices)):
我认为我有一个O(n)的时间解决方案(实际上是2n+r,n是元组的长度,r是子元组),这个方案不使用后缀树,而是用了一种字符串匹配算法(像KMP,你可以在网上找到现成的)。
我们使用了一个鲜为人知的定理:
If x,y are strings over some alphabet,
then xy = yx if and only if x = z^k and y = z^l for some string z and integers k,l.
我现在要说的是,对于我们的问题,这意味着我们只需要判断给定的元组/列表(或字符串)是否是它自身的循环移位!
要判断一个字符串是否是它自身的循环移位,我们可以把这个字符串和它自己拼接在一起(其实不需要真的拼接,虚拟拼接就可以),然后检查是否有子字符串匹配(和它自己匹配)。
为了证明这一点,假设这个字符串是它自身的循环移位。
那么我们有给定的字符串y = uv = vu。因为uv = vu,所以我们必须有u = z^k和v = z^l,因此y = z^{k+l},这是根据上面的定理得出的。反向证明也很简单。
这里是Python代码。这个方法叫做powercheck。
def powercheck(lst):
count = 0
position = 0
for pos in KnuthMorrisPratt(double(lst), lst):
count += 1
position = pos
if count == 2:
break
return lst[:position]
def double(lst):
for i in range(1,3):
for elem in lst:
yield elem
def main():
print powercheck([1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1])
if __name__ == "__main__":
main()
还有这是我使用的KMP代码(来自David Eppstein)。
# Knuth-Morris-Pratt string matching
# David Eppstein, UC Irvine, 1 Mar 2002
def KnuthMorrisPratt(text, pattern):
'''Yields all starting positions of copies of the pattern in the text.
Calling conventions are similar to string.find, but its arguments can be
lists or iterators, not just strings, it returns all matches, not just
the first one, and it does not need the whole text in memory at once.
Whenever it yields, it will have read the text exactly up to and including
the match that caused the yield.'''
# allow indexing into pattern and protect against change during yield
pattern = list(pattern)
# build table of shift amounts
shifts = [1] * (len(pattern) + 1)
shift = 1
for pos in range(len(pattern)):
while shift <= pos and pattern[pos] != pattern[pos-shift]:
shift += shifts[pos-shift]
shifts[pos+1] = shift
# do the actual search
startPos = 0
matchLen = 0
for c in text:
while matchLen == len(pattern) or \
matchLen >= 0 and pattern[matchLen] != c:
startPos += shifts[matchLen]
matchLen -= shifts[matchLen]
matchLen += 1
if matchLen == len(pattern):
yield startPos
对于你的示例,这个输出是
[1,0,1,1]
正如预期的那样。
我把这个方法和shx2的代码(不是numpy的那个)进行了比较,生成了一个随机的50位字符串,然后复制它使总长度达到100万。输出是(十进制数字是time.time()的输出)
1362988461.75
(50, [1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1])
1362988465.96
50 [1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1]
1362988487.14
上述方法大约花了4秒,而shx2的方法花了大约21秒!
这是计时的代码。(shx2的方法叫做powercheck2)。
def rand_bitstring(n):
rand = random.SystemRandom()
lst = []
for j in range(1, n+1):
r = rand.randint(1,2)
if r == 2:
lst.append(0)
else:
lst.append(1)
return lst
def main():
lst = rand_bitstring(50)*200000
print time.time()
print powercheck(lst)
print time.time()
powercheck2(lst)
print time.time()