优雅地测试Python AST相等性的方法(非引用或对象身份)

23 投票
4 回答
4437 浏览
提问于 2025-04-16 01:42

这里的术语可能不太准确,但我想说的是,eq?equal? 在 Scheme 语言中的区别,或者 C 语言中 ==strncmp 的区别;在这两种情况下,第一个方法会对两个内容相同但不同的字符串返回 false,而第二个方法则会返回 true。

我现在想要的是后者的操作,适用于 Python 的抽象语法树(AST)。

目前,我的做法是这样的:

import ast
def AST_eq(a, b):
    return ast.dump(a) == ast.dump(b)

这个方法似乎能用,但我总觉得它像是个即将出问题的灾难。有没有人知道更好的方法?

编辑:不幸的是,当我去比较两个 AST 的 __dict__ 时,这个比较默认使用各个元素的 __eq__ 方法。AST 是由其他 AST 组成的树结构,而它们的 __eq__ 方法显然是检查引用是否相同。所以直接用 == 或者 Thomas 提供的链接中的解决方案都不行。(而且,我也不想为每种 AST 节点类型都去创建一个子类来插入这个自定义的 __eq__ 方法。)

4 个回答

3

下面的代码在Python 2和3中都能运行,而且比使用itertools要快:

编辑:警告

显然,这段代码在某些(奇怪的)情况下可能会卡住。因此,我不建议使用它。

def compare_ast(node1, node2):

    if type(node1) != type(node2):
        return False
    elif isinstance(node1, ast.AST):
        for kind, var in vars(node1).items():
            if kind not in ('lineno', 'col_offset', 'ctx'):
                var2 = vars(node2).get(kind)
                if not compare_ast(var, var2):
                    return False
        return True
    elif isinstance(node1, list):
        if len(node1) != len(node2):
            return False
        for i in range(len(node1)):
            if not compare_ast(node1[i], node2[i]):
                return False
        return True
    else:
        return node1 == node2
5

我对@Yorik.sar的回答进行了修改,使其适用于Python 3.9及以上版本:

from itertools import zip_longest
from typing import Union


def compare_ast(node1: Union[ast.expr, list[ast.expr]], node2: Union[ast.expr, list[ast.expr]]) -> bool:
    if type(node1) is not type(node2):
        return False

    if isinstance(node1, ast.AST):
        for k, v in vars(node1).items():
            if k in {"lineno", "end_lineno", "col_offset", "end_col_offset", "ctx"}:
                continue
            if not compare_ast(v, getattr(node2, k)):
                return False
        return True

    elif isinstance(node1, list) and isinstance(node2, list):
        return all(compare_ast(n1, n2) for n1, n2 in zip_longest(node1, node2))
    else:
        return node1 == node2
8

我遇到了同样的问题。我尝试了这样做:首先把抽象语法树(AST)简化成一种更容易理解的表示方式(一个字典的树状结构):

def simplify(node):
    if isinstance(node, ast.AST):
        res = vars(node).copy()
        for k in 'lineno', 'col_offset', 'ctx':
            res.pop(k, None)
        for k, v in res.iteritems():
            res[k] = simplify(v)
        res['__type__'] = type(node).__name__
        return res
    elif isinstance(node, list):
        return map(simplify, node)
    else:
        return node

然后你可以直接比较这些表示方式:

data = open("/usr/lib/python2.7/ast.py").read()
a1 = ast.parse(data)
a2 = ast.parse(data)
print simplify(a1) == simplify(a2)

这样会返回 True

编辑

我刚明白其实不需要创建字典,所以你可以直接这样做:

def compare_ast(node1, node2):
    if type(node1) is not type(node2):
        return False
    if isinstance(node1, ast.AST):
        for k, v in vars(node1).iteritems():
            if k in ('lineno', 'col_offset', 'ctx'):
                continue
            if not compare_ast(v, getattr(node2, k)):
                return False
        return True
    elif isinstance(node1, list):
        return all(itertools.starmap(compare_ast, itertools.izip(node1, node2)))
    else:
        return node1 == node2

撰写回答