使用父类实现重载子类中的__mul__:导致问题
我正在尝试实现类 C
的 __mul__
方法,而这个类是从类 P
继承来的。类 P
已经有了 __mul__
方法的实现,但这个实现只适用于 P
类型的元素,也就是说只能用 P() * P()
这样的方式。
所以在 C.__mul__
中,我想实现一个简单的乘法,当传入的参数是浮点数(float)时就直接进行乘法。如果参数不是浮点数,我想使用 P.__mul__
方法……但这样会出现问题,因为在 P.__mul__
中,它是通过 return P(something)
来返回结果的……
所以基本上,经过一些操作后,原本是 C
类型的信息就丢失了。
下面的代码可以更好地解释这个问题。
有没有什么办法解决这个问题呢?
class MyFloat(object):
def __init__(self, a):
self.a = a
def __mul__(self, other):
return MyFloat(self.a * other.a)
def __repr__(self):
return str(self.a)
class MyFloatExt(MyFloat):
def __init__(self, a):
MyFloat.__init__(self, a)
def __add__(self, other):
return MyFloatExt(self.a + other.a)
def __mul__(self, other):
if type(other) == (int, long, float):
return MyFloatExt(self.a * other)
else:
return MyFloat.__mul__(self, other)
a = MyFloatExt(0.5)
b = MyFloatExt(1.5)
c = a + b
print c
d = a * b
print d
e = d * c
print e
print isinstance(e, MyFloat)
f = e * 0.5
print f
2 个回答
2
这里有两个问题
在你为
MyFloatExt
实现的__mul__
方法中,你从来没有检查过other
是否是MyFloatExt
的一个实例。isinstance(e, MyFloat)
这个检查总是会返回真,因为MyFloatExt
是从MyFloat
继承来的。
要解决这些问题:
def __mul__(self, other):
# check if we deal with a MyFloatExt instance
if isinstance(other, MyFloatExt):
return MyFloatExt(self.a * other.a)
if type(other) == (int, long, float):
return MyFloatExt(self.a * other)
else:
return MyFloat.__mul__(self, other)
# do the correct check
print isinstance(e, MyFloatExt)
6
首先,你在 MyFloatExt
里的 __mul__
方法应该这样进行类型检查:
isinstance(other,(int,long,float))
或者更好一些:
isinstance(other,Number) #from numbers import Number
另外,你还需要把 MyFloat
里的 __mul__
方法改成这样:
class MyFloat(object):
#...
def __mul__(self, other):
return type(self)(self.a * other.a)
#...
这样可以创建你实际类型的实例。
而且,你可以选择调用 super
,而不是直接调用 MyFloat.__mul__
,这样做是为了让你的类型层次结构更好地演变。
完整代码:
from numbers import Number
class MyFloat(object):
def __init__(self, a):
self.a = a
def __mul__(self, other):
return type(self)(self.a * other.a)
def __repr__(self):
return str(self.a)
class MyFloatExt(MyFloat):
def __init__(self, a):
super(MyFloatExt,self).__init__(a)
def __add__(self, other):
return type(self)(self.a + other.a)
def __mul__(self, other):
if isinstance(other,Number):
return type(self)(self.a * other)
else:
return super(MyFloatExt,self).__mul__(other)
a = MyFloatExt(0.5)
b = MyFloatExt(1.5)
c = a + b
print c
d = a * b
print d
e = d * c
print e
print isinstance(e, MyFloat)
f = e * 0.5
print f
print map(type,[a,b,c,d,e,f]) == [MyFloatExt]*6