使用父类实现重载子类中的__mul__:导致问题

2 投票
2 回答
3810 浏览
提问于 2025-04-16 07:23

我正在尝试实现类 C__mul__ 方法,而这个类是从类 P 继承来的。类 P 已经有了 __mul__ 方法的实现,但这个实现只适用于 P 类型的元素,也就是说只能用 P() * P() 这样的方式。

所以在 C.__mul__ 中,我想实现一个简单的乘法,当传入的参数是浮点数(float)时就直接进行乘法。如果参数不是浮点数,我想使用 P.__mul__ 方法……但这样会出现问题,因为在 P.__mul__ 中,它是通过 return P(something) 来返回结果的……

所以基本上,经过一些操作后,原本是 C 类型的信息就丢失了。

下面的代码可以更好地解释这个问题。

有没有什么办法解决这个问题呢?

class MyFloat(object):
  def __init__(self, a):
    self.a = a

  def __mul__(self, other):
    return MyFloat(self.a * other.a)

  def __repr__(self):
    return str(self.a)


class MyFloatExt(MyFloat):
  def __init__(self, a):
    MyFloat.__init__(self, a)

  def __add__(self, other):
    return MyFloatExt(self.a + other.a)

  def __mul__(self, other):
    if type(other) == (int, long, float):
      return MyFloatExt(self.a * other)
    else:
      return MyFloat.__mul__(self, other)

a = MyFloatExt(0.5)
b = MyFloatExt(1.5)

c = a + b
print c

d = a * b
print d

e = d * c
print e

print isinstance(e, MyFloat)
f = e * 0.5
print f

2 个回答

2

这里有两个问题

  1. 在你为 MyFloatExt 实现的 __mul__ 方法中,你从来没有检查过 other 是否是 MyFloatExt 的一个实例。

  2. isinstance(e, MyFloat) 这个检查总是会返回真,因为 MyFloatExt 是从 MyFloat 继承来的。

要解决这些问题:

def __mul__(self, other):

    # check if we deal with a MyFloatExt instance
    if isinstance(other, MyFloatExt):
        return MyFloatExt(self.a * other.a)

    if type(other) == (int, long, float):
        return MyFloatExt(self.a * other)

    else:
        return MyFloat.__mul__(self, other)

# do the correct check
print isinstance(e, MyFloatExt)
6

首先,你在 MyFloatExt 里的 __mul__ 方法应该这样进行类型检查:

isinstance(other,(int,long,float))

或者更好一些:

isinstance(other,Number) #from numbers import Number

另外,你还需要把 MyFloat 里的 __mul__ 方法改成这样:

class MyFloat(object):
#...
  def __mul__(self, other):
    return type(self)(self.a * other.a)
#...

这样可以创建你实际类型的实例。

而且,你可以选择调用 super,而不是直接调用 MyFloat.__mul__,这样做是为了让你的类型层次结构更好地演变。

完整代码:

from numbers import Number
class MyFloat(object):
  def __init__(self, a):
    self.a = a

  def __mul__(self, other):
    return type(self)(self.a * other.a)

  def __repr__(self):
    return str(self.a)


class MyFloatExt(MyFloat):
  def __init__(self, a):
    super(MyFloatExt,self).__init__(a)

  def __add__(self, other):
    return type(self)(self.a + other.a)

  def __mul__(self, other):
    if isinstance(other,Number):
      return type(self)(self.a * other)
    else:
      return super(MyFloatExt,self).__mul__(other)


a = MyFloatExt(0.5)
b = MyFloatExt(1.5)

c = a + b
print c

d = a * b
print d


e = d * c
print e

print isinstance(e, MyFloat)

f = e * 0.5
print f

print map(type,[a,b,c,d,e,f]) == [MyFloatExt]*6

撰写回答