Python:优化树评估器

2 投票
2 回答
574 浏览
提问于 2025-04-15 18:41

我知道树是一种研究得很透彻的数据结构。

我正在写一个程序,它可以随机生成很多表达式树,然后根据某个适应度属性进行排序和选择。

我有一个叫做 MakeTreeInOrder() 的类,它可以把树转换成一个字符串,这个字符串可以被 'eval' 函数计算。

但是这个函数被调用的次数非常多,所以我希望能优化它的运行时间。

下面是一个构建树的函数,它会将连续的数字相加,用来做测试。

我在想,是否有更高效的方法来计算树结构中的表达式。我觉得这个问题应该有人研究过,

因为这个用得还挺多的。

import itertools
from collections import namedtuple 

#Further developing Torsten Marek's second suggestion

KS = itertools.count()
Node = namedtuple("Node", ["cargo", "args"]) 

def build_nodes (depth = 5):     
    if (depth <= 0): 
        this_node = Node((str(KS.next())), [None, None])
        return this_node 
    else:
        this_node = Node('+', []) 
        this_node.args.extend( 
          build_nodes(depth = depth - (i + 1))                             
          for i in range(2)) 

        return this_node

以下是我认为可以加速的代码。我希望能得到一些建议。

class MakeTreeInOrder(object):
    def __init__(self, node):
        object.__init__(self)
        self.node = node
        self.str = ''
    def makeit(self, nnode = ''):
        if nnode == '':
            nnode = self.node
        if nnode == None: return
        self.str +='('
        self.makeit(nnode.args[0])
        self.str += nnode.cargo
        self.makeit(nnode.args[1])
        self.str+=')'
        return self.str

def Main():
    this_tree = build_nodes()
    expression_generator = MakeTreeInOrder(this_tree)
    this_expression = expression_generator.makeit()
    print this_expression
    print eval(this_expression)

if __name__ == '__main__':
    rresult = Main()

2 个回答

1

你的例子里引入了numpy和random这两个库,但实际上并没有用到它们。而且里面有个“for i in range(2))”的循环,但里面什么都没有,这显然不是有效的Python代码。

你没有说明'cargo'和节点应该包含什么。看起来'cargo'是一个数字,因为它是从itertools.count().next()中来的。但这样说不太合理,因为你想要的结果应该是一个可以被Python执行的字符串。

如果你只是想一次性评估这个树结构,那么最快的办法就是直接在原地进行评估。不过,由于我没有看到你正在处理的数据示例,所以无法给出具体的例子。

如果你想让这个过程更快,可以考虑从更上游的地方入手。为什么要先生成树结构再去评估它呢?难道不能直接在生成树结构的代码中评估这些组成部分吗?如果你有像“+”和“*”这样的运算符,可以考虑使用operator.add和operator.mul,这样可以直接对数据进行操作,而不需要中间步骤。

==更新==

这个内容是基于Paul Hankin的回答。我所做的就是去掉了中间的树结构,直接评估表达式。

def build_tree2(n):
    if n == 0:
        return random.randrange(100)
    else:
        left = build_tree2(n-1)
        right = build_tree2(n-1)
        return left+right

这个方法的速度大约比Paul的解决方案快5倍。

可能你需要知道最佳解决方案的实际树结构,或者是前k个结果中的N个,前提是k远小于N。如果是这样的话,你可以在找到最佳值后,利用生成结果时的随机数生成器状态来重新生成那些树。例如:

def build_tree3(n, rng=random._inst):
    state = rng.getstate()
    return _build_tree3(n, rng.randrange), state

def _build_tree3(n, randrange):
    if n == 0:
        return randrange(100)
    else:
        left = _build_tree3(n-1, randrange)
        right = _build_tree3(n-1, randrange)
        return left+right

一旦你找到了最佳值,就用这个关键来重新生成树。

# Build Paul's tree data structure given a specific RNG
def build_tree4(n, rng):
    if n == 0:
        return ConstantNode(rng.randrange(100))
    else:
        left = build_tree4(n-1, rng)
        right = build_tree4(n-1, rng)
        return ArithmeticOperatorNode("+", left, right)

# This is a O(n log(n)) way to get the best k.
# An O(k log(k)) time solution is possible.
rng = random.Random()
best_5 = sorted(build_tree3(8, rng) for i in range(10000))[:5]
for value, state in best_5:
    rng.setstate(state)
    tree = build_tree4(8, rng)
    print tree.eval(), "should be", value
    print "  ", str(tree)[:50] + " ..."

这是我运行时的样子。

10793 should be 10793
   ((((((((92 + 75) + (35 + 69)) + ((39 + 79) + (6 +  ...
10814 should be 10814
   ((((((((50 + 63) + (6 + 21)) + ((75 + 98) + (76 +  ...
10892 should be 10892
   ((((((((51 + 25) + (5 + 32)) + ((40 + 71) + (17 +  ...
11070 should be 11070
   ((((((((7 + 83) + (77 + 56)) + ((16 + 29) + (2 + 1 ...
11125 should be 11125
   ((((((((69 + 80) + (11 + 64)) + ((33 + 21) + (95 + ...
3

在这里加入一些面向对象的概念可以让事情变得简单。为你树中的每个元素创建一个Node的子类,并使用一个叫做'eval'的方法来进行计算。

import random

class ArithmeticOperatorNode(object):
    def __init__(self, operator, *args):
        self.operator = operator
        self.children = args
    def eval(self):
        if self.operator == '+':
            return sum(x.eval() for x in self.children)
        assert False, 'Unknown arithmetic operator ' + self.operator
    def __str__(self):
        return '(%s)' % (' ' + self.operator + ' ').join(str(x) for x in self.children)

class ConstantNode(object):
    def __init__(self, constant):
        self.constant = constant
    def eval(self):
        return self.constant
    def __str__(self):
        return str(self.constant)

def build_tree(n):
    if n == 0:
        return ConstantNode(random.randrange(100))
    else:
        left = build_tree(n - 1)
        right = build_tree(n - 1)
        return ArithmeticOperatorNode('+', left, right)

node = build_tree(5)
print node
print node.eval()

要计算整个树,只需在最上层的节点上调用.eval()方法即可。

node = build_tree(5)
print node.eval()

我还添加了一个__str__方法,用来把树转换成字符串,这样你就可以看到它是如何推广到其他树的功能的。目前它只做加法,但希望你能明白如何扩展到其他所有的算术操作。

撰写回答