Python AST 转为 Dot 图

11 投票
2 回答
4663 浏览
提问于 2025-04-17 07:28

我正在分析Python代码生成的抽象语法树(AST),想要用更直观的方式来查看这个树,而不是用“ast.dump”这样的文本输出。

理论上,这个AST已经是一个树形结构,所以创建一个图形表示应该不难,但我不太明白该怎么做。

ast.walk好像是用广度优先搜索(BFS)的方法在遍历树,而visitX这些方法我又看不到父节点,似乎也找不到创建图形的方法……

看起来唯一的办法就是自己写一个深度优先搜索(DFS)的遍历函数,这样做合理吗?

2 个回答

8

太棒了,这个方法有效,而且非常简单。

class AstGraphGenerator(object):

    def __init__(self):
        self.graph = defaultdict(lambda: [])

    def __str__(self):
        return str(self.graph)

    def visit(self, node):
        """Visit a node."""
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    def generic_visit(self, node):
        """Called if no explicit visitor function exists for a node."""
        for _, value in ast.iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        self.visit(item)

            elif isinstance(value, ast.AST):
                self.graph[type(node)].append(type(value))
                self.visit(value)

这个方法和普通的NodeVisitor差不多,不过我用了一个叫defaultdict的东西,里面存储了每个节点的类型。然后我把这个字典传给pygraphviz.AGraph,就得到了我想要的结果。

唯一的问题是,节点的类型信息不够详细,但另一方面,使用ast.dump()又显得太啰嗦了。

最好的办法是获取每个节点的实际源代码,这可能吗?

补充:现在好多了,我在构造函数里也传入了源代码,尽量获取代码行,如果不行的话就只打印类型信息。

class AstGraphGenerator(object):

    def __init__(self, source):
        self.graph = defaultdict(lambda: [])
        self.source = source  # lines of the source code

    def __str__(self):
        return str(self.graph)

    def _getid(self, node):
        try:
            lineno = node.lineno - 1
            return "%s: %s" % (type(node), self.source[lineno].strip())

        except AttributeError:
            return type(node)

    def visit(self, node):
        """Visit a node."""
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    def generic_visit(self, node):
        """Called if no explicit visitor function exists for a node."""
        for _, value in ast.iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        self.visit(item)

            elif isinstance(value, ast.AST):
                node_source = self._getid(node)
                value_source = self._getid(value)
                self.graph[node_source].append(value_source)
                # self.graph[type(node)].append(type(value))
                self.visit(value)
6

如果你看看 ast.NodeVisitor 这个类,会发现它其实很简单。你可以选择继承这个类,或者根据自己的需要重新实现它的遍历方法。例如,当访问节点时,如果想要保留对父节点的引用,这样做非常简单。你只需要添加一个 visit 方法,并让它也接受父节点作为参数,然后在你自己的 generic_visit 方法中传递这个父节点。

顺便说一下,NodeVisitor.generic_visit 实现了深度优先搜索(DFS),所以你只需要添加传递父节点的功能就可以了。

撰写回答