Python - 记忆化与科拉兹序列

6 投票
3 回答
4601 浏览
提问于 2025-04-17 19:34

当我在解决Project Euler的第14题时,我发现可以用一种叫做“备忘录”的方法来加快我的计算速度(我让程序运行了15分钟,结果还是没有返回答案)。问题是,我该怎么实现这个备忘录呢?我尝试过,但总是出现一个叫做keyerror的错误(返回的值无效)。这让我很烦恼,因为我相信我可以把备忘录应用到这个问题上,从而让它运行得更快。

lookup = {}

def countTerms(n):
   arg = n
   count = 1
   while n is not 1:
      count += 1
      if not n%2:
         n /= 2
      else:
         n = (n*3 + 1)
      if n not in lookup:
         lookup[n] = count

   return lookup[n], arg

print max(countTerms(i) for i in range(500001, 1000000, 2)) 

谢谢。

3 个回答

0

这是我对PE14的解决方案:

memo = {1:1}
def get_collatz(n):

if n in memo : return memo[n]

if n % 2 == 0:
    terms = get_collatz(n/2) + 1
else:
    terms = get_collatz(3*n + 1) + 1

memo[n] = terms
return terms

compare = 0
for x in xrange(1, 999999):
if x not in memo:
    ctz = get_collatz(x)
    if ctz > compare:
     compare = ctz
     culprit = x

print culprit
3

记忆化的目的是为了避免重复计算已经算过的部分。在Collatz序列中,后面的序列完全由当前的值决定。所以我们希望尽可能多地检查表格,并在能尽快结束计算时就停止。

def collatz_sequence(start, table={}):  # cheeky trick: store the (mutable) table as a default argument
    """Returns the Collatz sequence for a given starting number"""
    l = []
    n = start

    while n not in l:  # break if we find ourself in a cycle
                       # (don't assume the Collatz conjecture!)
        if n in table:
            l += table[n]
            break
        elif n%2 == 0:
            l.append(n)
            n = n//2
        else:
            l.append(n)
            n = (3*n) + 1

    table.update({n: l[i:] for i, n in enumerate(l) if n not in table})

    return l

这个方法有效吗?我们来监视一下,确保记忆化的元素被使用了:

class NoisyDict(dict):
    def __getitem__(self, item):
        print("getting", item)
        return dict.__getitem__(self, item)

def collatz_sequence(start, table=NoisyDict()):
    # etc



In [26]: collatz_sequence(5)
Out[26]: [5, 16, 8, 4, 2, 1]

In [27]: collatz_sequence(5)
getting 5
Out[27]: [5, 16, 8, 4, 2, 1]

In [28]: collatz_sequence(32)
getting 16
Out[28]: [32, 16, 8, 4, 2, 1]

In [29]: collatz_sequence.__defaults__[0]
Out[29]: 
{1: [1],
 2: [2, 1],
 4: [4, 2, 1],
 5: [5, 16, 8, 4, 2, 1],
 8: [8, 4, 2, 1],
 16: [16, 8, 4, 2, 1],
 32: [32, 16, 8, 4, 2, 1]}

补充:我早就知道可以优化!秘密在于函数中有两个地方(两个返回点),我们知道ltable没有共享的元素。之前我通过测试来避免用已经在table中的元素调用table.update,而这个版本的函数则利用了我们对控制流程的了解,节省了很多时间。

[collatz_sequence(x) for x in range(500001, 1000000)]现在在我电脑上大约需要2秒,而用@welter的版本类似的表达式只需400毫秒。我认为这是因为这两个函数实际上计算的内容不同——我的版本生成整个序列,而@welter的版本只是找出序列的长度。所以我觉得我的实现无法达到同样的速度。

def collatz_sequence(start, table={}):  # cheeky trick: store the (mutable) table as a default argument
    """Returns the Collatz sequence for a given starting number"""
    l = []
    n = start

    while n not in l:  # break if we find ourself in a cycle
                       # (don't assume the Collatz conjecture!)
        if n in table:
            table.update({x: l[i:] for i, x in enumerate(l)})
            return l + table[n]
        elif n%2 == 0:
            l.append(n)
            n = n//2
        else:
            l.append(n)
            n = (3*n) + 1

    table.update({x: l[i:] for i, x in enumerate(l)})
    return l

附言 - 找出错误!

4

还有一种很不错的递归方法来实现这个功能,虽然可能会比poorsod的解决方案慢一些,但它和你最开始的代码更相似,所以你可能更容易理解。

lookup = {}

def countTerms(n):
   if n not in lookup:
      if n == 1:
         lookup[n] = 1
      elif not n % 2:
         lookup[n] = countTerms(n / 2)[0] + 1
      else:
         lookup[n] = countTerms(n*3 + 1)[0] + 1

   return lookup[n], n

print max(countTerms(i) for i in range(500001, 1000000, 2))

撰写回答