我的Python代码出现MemoryError,如何优化?
我在CodeEval上遇到了一个挑战,内容是:
有一个序列
011212201220200112 ...
,它是这样构造的:首先是0,然后重复以下操作:把已经写好的部分右边的数字进行替换,0变成1,1变成2,2变成0。比如:
0 -> 01 -> 0112 -> 01121220 -> ...
你需要创建一个算法,来确定序列中第N个位置上的数字是什么。
输入示例:
0
5
101
25684
输出示例:
0
2
1
0
约束条件:
0 <= N <= 3000000000
我已经写了代码,但由于约束条件,N可能是一个很大的数字,我的代码在 numstring += temp
这行出现了 MemoryError
错误。
这是我的代码:
from string import maketrans
def predict(n):
n = int(n)
instring = '012'
outstring = '120'
tab = maketrans(instring, outstring)
numstring = '0'
temp = numstring
while len(numstring) <= n:
temp = temp.translate(tab)
numstring += temp
temp = numstring
print numstring[n]
有没有什么方法可以优化这个?
注意: 输入和输出都是字符串类型,而不是整数。
3 个回答
如果你考虑到字符串可能变得多大,我想这个练习的目的就是为了不把整个字符串都存储在内存里。
另外,你可能想试着把
numstring += temp
temp = numstring
替换成
temp = "%s%s" % (numstring, temp)
这样做在短期内可能会稍微节省一些内存,尽管这并不能真正解决根本的算法问题。
想想看,真的有必要把所有东西都存起来吗?
你从一个字符开始,这个字符可以生成两个字符,然后这两个字符又可以生成四个字符,依此类推。在一个特定的子字符串中,每个字符都是由前一个子字符串中的一个字符决定的。你不需要关注其他的字符,只需要关注那些会影响你答案的字符就可以了。
这种问题可以通过递归的思路轻松解决。首先,你要摆脱对字符串和字符串操作的思考。你需要做的转换是 0 -> 1
、1 -> 2
和 2 -> 0
。这意味着你只是对所有数字加 1
,然后对 3
取余(也就是加 1
,然后除以 3
,最后取余数)。
那么,如何用递归来解决这个问题呢?如果你请求的 N
是 0
,那么你就到了字符串的起始位置,答案就是 0
。如果请求的是更大的索引,我们该怎么办呢?正如汤姆所指出的,如果你把序列分成两半,第二半的每个字符都与第一半的一个字符有关。如果我们能计算出这个字符的索引,就可以递归地解决这个问题。
这个字符的索引很容易计算。序列本身是无限长的,但是,给定输入 N
,你总是可以考虑长度为 2^(ceil(log2(N)))
的前缀,这样你总会有类似的情况:
a b c ... z | A B C ... Z
middle ^
N
也就是说,N
在第二半部分。
如果从 N
中减去第一半的长度,就能得到我们想要的字符的索引。字符串的一半长度是 2^(floor(log2(N)))
:
a b c ... z A B C ... Z ~~> A B C ... Z
^ ^ ^
2^x N N - 2^x
所以我们需要递归地解决 N - 2^x
的问题(其中 x = floor(log2(N))
),然后我们需要应用转换,以得到相对于我们在字符串第二半的结果。
换句话说,解决方案是:
def find_digit(n):
if n == 0:
return 0
# bit_length() gives floor(log2(N))
length = n.bit_length()
return (1 + find_digit(n - 2 ** (length-1))) % 3
实际上:
In [22]: find_digit(0)
Out[22]: 0
In [23]: find_digit(5)
Out[23]: 2
In [24]: find_digit(101)
Out[24]: 1
In [25]: find_digit(25684)
Out[25]: 0
In [26]: find_digit(3000000000)
Out[26]: 0
In [27]: find_digit(3000000001)
Out[27]: 1
In [28]: find_digit(3000000002)
Out[28]: 1
In [29]: find_digit(30000000021265672555626541643155343185826435812641)
Out[29]: 2
In [30]: find_digit(30000000021265672555626541643155343185826435812641**10)
Out[30]: 1
请注意,这个解决方案在内存和计算上都很高效,只需进行 log(n)
次操作。然而,它仍然限制在 1000 次递归调用之内。
从递归解决方案中,很容易得到一个非递归的解决方案,这个方案可以处理巨大的输入:
def find_digit(n):
start = 0
while n:
start += 1
length = n.bit_length()
n -= 2 ** (length - 1)
return start % 3
如果你想输出一个字符串,只需将 return
行改为 return str(start % 3)
。如果你想接收字符串输入,只需在函数顶部添加 n = int(n)
。