重载 __eq__ 返回自定义对象

2 投票
1 回答
597 浏览
提问于 2025-04-18 11:12

我正在用Python写一个领域特定语言(DSL),想要重载一些运算符,这样我就可以轻松地写出我的DSL表达式。比如,我想写 Var("a") + Var("b"),然后得到 Add(Var("a"), Var("b")) 这样的表示。为此,我重载了 __add__ 方法,这个方法对我来说效果很好。

不过,我也想重载 __eq__ 方法,达到类似的效果:我希望写 Var("a") == Var("b"),然后得到 Eq(Var("a"), Var("b")) 这样的表示。通过重载 __eq__ 方法并返回一个 Eq 的实例,我实现了这个目标。但是,重载 __eq__ 方法显然会干扰Python的标准行为,比如 Var("b") in [Var("a")] 会返回 True

有没有办法让我实现这个目标,也就是能够写 Var("a") == Var("b") 并得到 Eq(Var("a"), Var("b")),同时还能写 if Var("a") == Var("b"): blablabla 或者把表达式放进内置的容器里等等?

编辑

我尝试实现了 Eq 类的 __bool__ 方法,似乎有效(见下面的代码)。我是不是漏掉了什么,或者这是一个可行的解决方案?

class Expr:
    def __add__(self, other):
        return Add(self, other)

    def __eq__(self, other):
        return Eq(self, other)

    def __repr__(self):
        return str(self)

    def __add__(self, other):
        return Add(self, other)

    def __ne__(self, other):
        return Neq(self, other)

class Var(Expr):
    def __init__(self, name):
        self.name = name

    def __str__(self):
        return "Var(" + str(self.name) + ")"

    def equals(self, other):
        if type(self) is type(other):
            return self.name == other.name
        else:
            return False

    def __hash__(self):
        return 17 + 23 * hash(self.name)

class Add(Expr):
    def __init__(self, left, right):
        self.left = left
        self.right = right

    def __str__(self):
        return "Add(" + str(self.left) + ", " + str(self.right) + ")"

    def equals(self, other):
        if type(self) is type(other):
            return ( ( self.left.equals(other.left) and
                       self.right.equals(other.right) ) or
                     ( self.left.equals(other.right) and
                       self.right.equals(other.left) ) )
        else:
            return False

    def __hash__(self):
        return (17 + 23 * hash("+") +
                23 * 23 * hash(self.left) + 23 * 23 * hash(self.right))

class Eq(Expr):
    def __init__(self, left, right):
        self.left = left
        self.right = right

    def __str__(self):
        return "Eq(" + str(self.left) + ", " + str(self.right) + ")"

    def equals(self, other):
        if type(self) is type(other):
            return ( ( self.left.equals(other.left) and
                       self.right.equals(other.right) ) or
                     ( self.left.equals(other.right) and
                       self.right.equals(other.left) ) )
        else:
            return False

    def __bool__(self):
        return self.left.equals(self.right)

    def __hash__(self):
        return (17 + 23 * hash("==") +
                23 * 23 * hash(self.left) + 23 * 23 * hash(self.right))

class Neq(Expr):
    def __init__(self, left, right):
        self.left = left
        self.right = right

    def __str__(self):
        return "Neq(" + str(self.left) + ", " + str(self.right) + ")"

    def equals(self, other):
        if type(self) is type(other):
            return ( ( not self.left.equals(other.left) or
                       not self.right.equals(other.right) ) and
                     ( not self.left.equals(other.right) or
                       not self.right.equals(other.left) ) )
        else:
            return False

    def __bool__(self):
        return not self.left.equals(self.right)

    def __hash__(self):
        return (17 + 23 * hash("!=") +
                23 * 23 * hash(self.left) + 23 * 23 * hash(self.right))


a = Var("a")
aa = Var("a")
b = Var("b")
c = Var("c")


print("a + b", "=>", a + b)   # a + b => Add(Var(a), Var(b))
print("a == b", "=>", a == b) # a == b => Eq(Var(a), Var(b))
print("a != b", "=>", a != b) # a != b => Neq(Var(a), Var(b))

print("a if a == b else b", "=>", a if a == b else b)
# a if a == b else b => Var(b)
print("a if a == aa else b", "=>", a if a == aa else b)
# a if a == aa else b => Var(a)


l = [a, a+b]
print("l", "=>", l)               # l => [Var(a), Add(Var(a), Var(b))]
print("b in l", "=>", b in l)     # b in l => False
print("a in l", "=>", a in l)     # a in l => True
print("aa in l", "=>", aa in l)   # aa in l => True
print("a+b in l", "=>", a+b in l) # a+b in l => True
print("b+a in l", "=>", b+a in l) # b+a in l => True
print("a+c in l", "=>", a+c in l) # a+c in l => False


if a == b:
    print("a == b is True")
else:
    print("a == b is False")        # a == b is False
if a == aa:
    print("a == aa is True")        # a == aa is True
else:
    print("a == aa is False")

if a != b:
    print("a != b is True")         # a != b is True
else:
    print("a != b is False")
if a != aa:
    print("a != aa is True")
else:
    print("a != aa is False")       # a != aa is False


if a == b or a == aa:
    print("a == b or a == aa is True")   # a == b or a == aa is True
else:
    print("a == b or a == aa is False")
if a == aa and a == b:
    print("a == aa and a == b is True")
else:
    print("a == aa and a == b is False") # a == aa and a == b is False
if not a == aa:
    print("not a == aa is True")
else:
    print("not a == aa is False")        # not a == aa is False
if not a == b:
    print("not a == b is True")          # not a == b is True
else:
    print("not a == b is False")


if a == 3:
    print("a == 3 is True")
else:
    print("a == 3 is False")             # a == 3 is False
if a != 3:
    print("a != 3 is True")              # a != 3 is True
else:
    print("a != 3 is False")
if 3 == a:
    print("3 == a is True")
else:
    print("3 == a is False")             # 3 == a is False
if 3 != a:
    print("3 != a is True")              # 3 != a is True
else:
    print("3 != a is False")


if a == 'a':
    print("a == 'a' is True")
else:
    print("a == 'a' is False")           # a == 'a' is False
if a != 'a':
    print("a != 'a' is True")            # a != 'a' is True
else:
    print("a != 'a' is False")
if 'a' == a:
    print("'a' == a is True")
else:
    print("'a' == a is False")           # 'a' == a is False
if 'a' != a:
    print("'a' != a is True")            # 'a' != a is True
else:
    print("'a' != a is False")


s = {a}
print("s", "=>", s)             # s => {Var(a)}
print("a in s", "=>", a in s)   # a in s => True
print("b in s", "=>", b in s)   # b in s => False
print("aa in s", "=>", aa in s) # aa in s => True

d = {a: 1, b: 2}
print("d", "=>", d)             # d => {Var(b): 2, Var(a): 1}
print("d[a]", "=>", d[a])       # d[a] => 1
print("d[b]", "=>", d[b])       # d[b] => 2
print("c in d", "=>", c in d)   # c in d => False
print("aa in d", "=>", aa in d) # aa in d => True
print("d[aa]", "=>", d[aa])     # d[aa] => 1

1 个回答

2

不可以,你必须选择一种行为。.__eq__()这个方法的使用场景是无法可靠地检测到的。

如果你需要同时具备这两种行为,那你就得使用其他的运算符或者方法来表示你想要的功能。

撰写回答