我的Python代码出现MemoryError,如何优化?

0 投票
3 回答
739 浏览
提问于 2025-04-18 09:50

我在CodeEval上遇到了一个挑战,内容是:

有一个序列 011212201220200112 ...,它是这样构造的:首先是0,然后重复以下操作:把已经写好的部分右边的数字进行替换,0变成1,1变成2,2变成0。比如:

0 -> 01 -> 0112 -> 01121220 -> ...

你需要创建一个算法,来确定序列中第N个位置上的数字是什么。

输入示例:

0
5
101
25684

输出示例:

0
2
1
0

约束条件:

0 <= N <= 3000000000

我已经写了代码,但由于约束条件,N可能是一个很大的数字,我的代码在 numstring += temp 这行出现了 MemoryError 错误。

这是我的代码:

from string import maketrans

def predict(n):
    n = int(n)
    instring = '012'
    outstring = '120'
    tab = maketrans(instring, outstring)
    numstring = '0'
    temp = numstring
    while len(numstring) <= n:
        temp = temp.translate(tab)
        numstring += temp
        temp = numstring
    print numstring[n]

有没有什么方法可以优化这个?

注意: 输入和输出都是字符串类型,而不是整数。

3 个回答

0

如果你考虑到字符串可能变得多大,我想这个练习的目的就是为了把整个字符串都存储在内存里。

另外,你可能想试着把

numstring += temp
temp = numstring

替换成

temp = "%s%s" % (numstring, temp)

这样做在短期内可能会稍微节省一些内存,尽管这并不能真正解决根本的算法问题。

1

想想看,真的有必要把所有东西都存起来吗?

你从一个字符开始,这个字符可以生成两个字符,然后这两个字符又可以生成四个字符,依此类推。在一个特定的子字符串中,每个字符都是由前一个子字符串中的一个字符决定的。你不需要关注其他的字符,只需要关注那些会影响你答案的字符就可以了。

1

这种问题可以通过递归的思路轻松解决。首先,你要摆脱对字符串和字符串操作的思考。你需要做的转换是 0 -> 11 -> 22 -> 0。这意味着你只是对所有数字加 1,然后对 3 取余(也就是加 1,然后除以 3,最后取余数)。

那么,如何用递归来解决这个问题呢?如果你请求的 N0,那么你就到了字符串的起始位置,答案就是 0。如果请求的是更大的索引,我们该怎么办呢?正如汤姆所指出的,如果你把序列分成两半,第二半的每个字符都与第一半的一个字符有关。如果我们能计算出这个字符的索引,就可以递归地解决这个问题。

这个字符的索引很容易计算。序列本身是无限长的,但是,给定输入 N,你总是可以考虑长度为 2^(ceil(log2(N))) 的前缀,这样你总会有类似的情况:

a b c ... z | A B C ... Z
         middle   ^
                  N 

也就是说,N 在第二半部分。

如果从 N 中减去第一半的长度,就能得到我们想要的字符的索引。字符串的一半长度是 2^(floor(log2(N)))

   a b c ... z A B C ... Z      ~~>      A B C ... Z
              ^    ^                         ^
             2^x   N                        N - 2^x

所以我们需要递归地解决 N - 2^x 的问题(其中 x = floor(log2(N))),然后我们需要应用转换,以得到相对于我们在字符串第二半的结果。

换句话说,解决方案是:

def find_digit(n):
    if n == 0:
        return 0
    # bit_length() gives floor(log2(N))
    length = n.bit_length()
    return (1 + find_digit(n - 2 ** (length-1))) % 3

实际上:

In [22]: find_digit(0)
Out[22]: 0

In [23]: find_digit(5)
Out[23]: 2

In [24]: find_digit(101)
Out[24]: 1

In [25]: find_digit(25684)
Out[25]: 0

In [26]: find_digit(3000000000)
Out[26]: 0

In [27]: find_digit(3000000001)
Out[27]: 1

In [28]: find_digit(3000000002)
Out[28]: 1

In [29]: find_digit(30000000021265672555626541643155343185826435812641)
Out[29]: 2


In [30]: find_digit(30000000021265672555626541643155343185826435812641**10)
Out[30]: 1

请注意,这个解决方案在内存和计算上都很高效,只需进行 log(n) 次操作。然而,它仍然限制在 1000 次递归调用之内。

从递归解决方案中,很容易得到一个非递归的解决方案,这个方案可以处理巨大的输入:

def find_digit(n):
    start = 0
    while n:
        start += 1
        length = n.bit_length()
        n -= 2 ** (length - 1)
    return start % 3

如果你想输出一个字符串,只需将 return 行改为 return str(start % 3)。如果你想接收字符串输入,只需在函数顶部添加 n = int(n)

撰写回答