为什么我超过了最大递归深度

1 投票
1 回答
79 浏览
提问于 2025-04-14 16:18

我正在尝试制作一个倒计时解题器,灵感来自于一个热门的电视节目。这个解题器的功能是,给定6个数字和一个目标数字,它会不断尝试各种可能的组合,直到找到一个解决方案并返回结果。我并不想要一个超级优化的版本,只需要一个简单的函数,方便我将来加入到更大的程序中(这是我的最终目标)。

但是当我运行我的代码时,出现了“最大递归深度超出”的错误,我不太明白为什么。我原以为它应该只会进行5次递归调用,再加上1到2个像int()这样的函数调用。

这个函数的工作原理是这样的:数字列表作为字符串传入,然后在一个循环中,所有可能的两个数字组合会通过四种运算(记住a + b = b + a)进行计算,并与目标数字进行比较。计算的结果会替换原始列表中的两个数字,从而生成一个少一个元素的新列表。例如:

["7","9","50","25","75","100"] => ["63","50","25","75","100"]

在这里,7和9被63替换了(7 * 9)。这时,函数会用新的列表再次调用自己。这个过程会持续进行,直到列表中只剩下两个元素。如果没有找到解决方案,函数会返回“Q”,这会导致后续的每次调用也返回“Q”。

如果找到了答案,就会返回表示运算的字符串,例如“(7 * 9)”。由于是递归调用,在这个例子中,63是先计算出来的,当返回“(5 + 63)”时,63应该被替换成“(5 + (7 * 9))”。这意味着最终的方法会被返回。

如果我没有解释清楚,请留言,我可以再试一次。

import itertools as i

target=input("What is the target number? ")

def solve(l):
  global count
  if len(l) == 2:
    count+=1
    if int(l[0]) + int(l[1]) == target:
      return "("+l[0] +" + "+l[1]+")"
    count+=1
    if int(l[0]) - int(l[1]) == target:
      return "("+l[0] +" - "+l[1]+")"
    count+=1
    if int(l[0]) * int(l[1]) == target:
      return "("+l[0] +" * "+l[1]+")"
    count+=1
    if int(l[0]) / int(l[1]) == target:
      return "("+l[0] +" / "+l[1]+")"
    count+=1
    if int(l[1]) - int(l[0]) == target:
      return "("+l[0] +" - "+l[1]+")"
    count+=1
    if int(l[1]) / int(l[0]) == target:
      return "("+l[1] +" / "+l[0]+")"
    else:
      return "Q"
  else:
    ct = list(i.combinations(l, 2))
    for item in ct:
      count+=1
      resulta1=int(item[0])+int(item[1])
      if resulta1 == target:
        return "("+item[0] +" + "+item[1]+")"
      count+=1
      resulta2=int(item[0])-int(item[1])
      if resulta2 == target:
        return "("+item[0] +" - "+item[1]+")"
      count+=1
      resulta3=int(item[0])*int(item[1])
      if resulta3 == target:
        return "("+item[0] +" * "+item[1]+")"
      count+=1
      resulta4=int(item[0])/int(item[1])
      if resulta4 == target:
        return "("+item[0] +" / "+item[1]+")"
      count+=1
      resulta5=int(item[1])-int(item[0])
      if resulta5 == target:
        return "("+item[0] +" - "+item[1]+")"
      count+=1
      resulta6=int(item[1])+int(item[0])
      if resulta6 == target:
        return "("+item[1] +" / "+item[0]+")"
      intl = [el for el in l]
      intl.remove(item[0])
      intl.remove(item[1])
      newl=intl
      newl.append(str(resulta1))
      res1=solve(newl)
      if res1 != "Q":
        return res1.replace(str(resulta1),"("+item[0] +" + "+item[1]+")")
      newl=intl
      newl.append(str(resulta2))
      res2=solve(newl)
      if res2 != "Q":
        return res2.replace(str(resulta2),"("+item[0] +" - "+item[1]+")")
      newl=intl
      newl.append(str(resulta3))
      res3=solve(newl)
      if res3 != "Q":
        return res3.replace(str(resulta3),"("+item[0] +" * "+item[1]+")")
      newl=intl
      newl.append(str(resulta4))
      res4=solve(newl)
      if res4 != "Q":
        return res4.replace(str(resulta4),"("+item[0] +" / "+item[1]+")")
      newl=intl
      newl.append(str(resulta5))
      res5=solve(newl)
      if res5 != "Q":
        return res5.replace(str(resulta5),"("+item[1] +" - "+item[0]+")")
      newl=intl
      newl.append(str(resulta6))
      res6=solve(newl)
      if res6 != "Q":
        return res6.replace(str(resulta6),"("+item[1] +" / "+item[0]+")")
      else:
        return "Q"

lis=input("Enter numbers separated by commas: ").split(",")
count=0
solution = solve(lis)
print(count)

另外,我使用的是在线IDE Replit,它的最大递归深度可能和其他环境不同。

1 个回答

0

这里有一些建议:

  • 不要使用关键字 global;相反,应该使用函数的参数。
  • 不要写很多重复的“如果”语句;可以用一个循环来遍历所有选项。
  • 如果一个函数要进行数学运算,就应该把整数当作整数来处理,而不是把它们当作字符串,这样就不需要反复转换成整数了。
  • 可以用一个单独的函数来处理把表达式美观地打印成字符串。
  • 你代码中的递归函数的终止条件是 len(l) == 2;但你可能还需要一个终止条件是 len(l) == 1!比如说,如果你开始时有一个包含三个元素的列表,代码会去掉两个,最后只剩下一个。
  • 既然你已经使用了 itertools.combinations 来找出所有分割数字的方法,就没有必要单独检查 l[0]-l[1]l[1]-l[0];只需相信 combinations 会以任何可能的组合来分割数字。

根据你代码的逻辑,同时应用我自己的建议:

import itertools as i

from operator import add, sub, mul

def div(a, b):
    try:
        if a % b == 0:
            return a // b
        else:
            return float('inf')
    except ZeroDivisionError:
        return float('inf')

rsub = lambda a,b: sub(b,a)
rdiv = lambda a,b: div(b,a)

op_list = (add, sub, mul, div, rsub, rdiv)

def all_splits(seq):
    "all_splits([1,2,3,4,5]) --> ((1, 2), (3, 4, 5))  ((1, 3), (2, 4, 5))  ((1, 4), (2, 3, 5))  ((1, 5), (2, 3, 4))  ((2, 3), (1, 4, 5))  ((2, 4), (1, 3, 5))  ((2, 5), (1, 3, 4))  ((3, 4), (1, 2, 5))  ((3, 5), (1, 2, 4))  ((4, 5), (1, 2, 3))"
    s = set(range(len(seq)))
    for c in i.combinations(s, 2):
        yield tuple(seq[j] for j in c), tuple(seq[j] for j in s.difference(c))

def complement_target(target, op, left):
    '''left op right == target    <--->    right == target revop left'''
    revop = {add: sub, sub: rsub, mul: div, div: rdiv, rsub: add, rdiv: mul}
    return revop[op](target, left)

def solve(target, numbers):
    if len(numbers) == 1:
        if target == numbers[0]:
            yield numbers[0]
    elif len(numbers) == 2:
        a, b = numbers
        for op in op_list:
            if target == op(a, b):
                yield (a, op, b)
    else:
        for (a,b),right_numbers in all_splits(numbers):
            for left_op in op_list:
                left_target = left_op(a,b)
                for op in op_list:
                    right_target = complement_target(target, op, left_target)
                    if op(left_target, right_target) == target: # (x / y) * y ?=? x
                        for right_expr in solve(right_target, right_numbers):
                            yield ((a, left_op, b), op, right_expr)

def stringify(expr):
    opnames = {add:'{a} + {b}',sub:'{a} - {b}',mul:'{a} * {b}',div:'{a} / {b}',rsub:'{b} - {a}',rdiv:'{b} / {a}'}
    if isinstance(expr, int):
        return str(expr)
    elif len(expr) == 3:
        left_expr, op, right_expr = expr
        a=stringify(left_expr)
        b=stringify(right_expr)
        return ''.join(('(', opnames[op].format(a=a,b=b), ')'))
    else:
        raise ValueError(expr)

def main():
    for target,numbers in ((10,(1,2,3,4)), (9,(1,2,3,4,27))):
        for expr in solve(target, numbers):
            print(target, ' == ', stringify(expr))

if __name__ == '__main__':
    main()

输出:

10  ==  ((1 + 2) + (3 + 4))
10  ==  ((3 * 4) - (1 * 2))
10  ==  ((3 * 4) - (2 / 1))
...
10  ==  ((3 + 4) + (1 + 2))
10  ==  ((3 * 4) - (1 * 2))
10  ==  ((3 * 4) - (2 / 1))
9  ==  ((1 + 2) + ((27 - 3) / 4))
9  ==  ((1 + 2) - ((3 - 27) / 4))
9  ==  (((4 - 3) * 27) / (1 + 2))
...
9  ==  ((27 / 3) / ((4 - 1) - 2))
9  ==  ((27 / 3) / ((4 - 2) - 1))
9  ==  ((27 / 3) / ((4 / 2) - 1))

撰写回答