Python - 记忆化与科拉兹序列
当我在解决Project Euler的第14题时,我发现可以用一种叫做“备忘录”的方法来加快我的计算速度(我让程序运行了15分钟,结果还是没有返回答案)。问题是,我该怎么实现这个备忘录呢?我尝试过,但总是出现一个叫做keyerror的错误(返回的值无效)。这让我很烦恼,因为我相信我可以把备忘录应用到这个问题上,从而让它运行得更快。
lookup = {}
def countTerms(n):
arg = n
count = 1
while n is not 1:
count += 1
if not n%2:
n /= 2
else:
n = (n*3 + 1)
if n not in lookup:
lookup[n] = count
return lookup[n], arg
print max(countTerms(i) for i in range(500001, 1000000, 2))
谢谢。
3 个回答
0
这是我对PE14的解决方案:
memo = {1:1}
def get_collatz(n):
if n in memo : return memo[n]
if n % 2 == 0:
terms = get_collatz(n/2) + 1
else:
terms = get_collatz(3*n + 1) + 1
memo[n] = terms
return terms
compare = 0
for x in xrange(1, 999999):
if x not in memo:
ctz = get_collatz(x)
if ctz > compare:
compare = ctz
culprit = x
print culprit
3
记忆化的目的是为了避免重复计算已经算过的部分。在Collatz序列中,后面的序列完全由当前的值决定。所以我们希望尽可能多地检查表格,并在能尽快结束计算时就停止。
def collatz_sequence(start, table={}): # cheeky trick: store the (mutable) table as a default argument
"""Returns the Collatz sequence for a given starting number"""
l = []
n = start
while n not in l: # break if we find ourself in a cycle
# (don't assume the Collatz conjecture!)
if n in table:
l += table[n]
break
elif n%2 == 0:
l.append(n)
n = n//2
else:
l.append(n)
n = (3*n) + 1
table.update({n: l[i:] for i, n in enumerate(l) if n not in table})
return l
这个方法有效吗?我们来监视一下,确保记忆化的元素被使用了:
class NoisyDict(dict):
def __getitem__(self, item):
print("getting", item)
return dict.__getitem__(self, item)
def collatz_sequence(start, table=NoisyDict()):
# etc
In [26]: collatz_sequence(5)
Out[26]: [5, 16, 8, 4, 2, 1]
In [27]: collatz_sequence(5)
getting 5
Out[27]: [5, 16, 8, 4, 2, 1]
In [28]: collatz_sequence(32)
getting 16
Out[28]: [32, 16, 8, 4, 2, 1]
In [29]: collatz_sequence.__defaults__[0]
Out[29]:
{1: [1],
2: [2, 1],
4: [4, 2, 1],
5: [5, 16, 8, 4, 2, 1],
8: [8, 4, 2, 1],
16: [16, 8, 4, 2, 1],
32: [32, 16, 8, 4, 2, 1]}
补充:我早就知道可以优化!秘密在于函数中有两个地方(两个返回点),我们知道l
和table
没有共享的元素。之前我通过测试来避免用已经在table
中的元素调用table.update
,而这个版本的函数则利用了我们对控制流程的了解,节省了很多时间。
[collatz_sequence(x) for x in range(500001, 1000000)]
现在在我电脑上大约需要2秒,而用@welter的版本类似的表达式只需400毫秒。我认为这是因为这两个函数实际上计算的内容不同——我的版本生成整个序列,而@welter的版本只是找出序列的长度。所以我觉得我的实现无法达到同样的速度。
def collatz_sequence(start, table={}): # cheeky trick: store the (mutable) table as a default argument
"""Returns the Collatz sequence for a given starting number"""
l = []
n = start
while n not in l: # break if we find ourself in a cycle
# (don't assume the Collatz conjecture!)
if n in table:
table.update({x: l[i:] for i, x in enumerate(l)})
return l + table[n]
elif n%2 == 0:
l.append(n)
n = n//2
else:
l.append(n)
n = (3*n) + 1
table.update({x: l[i:] for i, x in enumerate(l)})
return l
附言 - 找出错误!
4
还有一种很不错的递归方法来实现这个功能,虽然可能会比poorsod的解决方案慢一些,但它和你最开始的代码更相似,所以你可能更容易理解。
lookup = {}
def countTerms(n):
if n not in lookup:
if n == 1:
lookup[n] = 1
elif not n % 2:
lookup[n] = countTerms(n / 2)[0] + 1
else:
lookup[n] = countTerms(n*3 + 1)[0] + 1
return lookup[n], n
print max(countTerms(i) for i in range(500001, 1000000, 2))