在Python中实现复数比较?
我知道,对于复数来说,比较运算符是无法普遍定义的。这就是为什么在 Python 中尝试直接比较复数时会抛出一个 TypeError
异常。我明白这是为什么(请不要偏题去解释为什么两个复数不能比较)。
不过,在这个特定的情况下,我想根据复数的大小来实现复数的比较。换句话说,对于复数 z1
和 z2
,当且仅当 abs(z1) > abs(z2)
时,z1 > z2
成立,其中 abs()
是计算复数大小的函数,就像 numpy.abs()
一样。
我想出了一个解决方案(至少我认为我想到了),如下所示:
import numpy as np
class CustomComplex(complex):
def __lt__(self, other):
return np.abs(self) < np.abs(other)
def __le__(self, other):
return np.abs(self) <= np.abs(other)
def __eq__(self, other):
return np.abs(self) == np.abs(other)
def __ne__(self, other):
return np.abs(self) != np.abs(other)
def __gt__(self, other):
return np.abs(self) > np.abs(other)
def __ge__(self, other):
return np.abs(self) >= np.abs(other)
complex = CustomComplex
这个方法似乎有效,但我有几个问题:
- 这样做可以吗,还是有更好的选择?
- 我希望我的包能够与内置的
complex
数据类型和numpy.complex
无缝兼容。怎么才能优雅地做到这一点,而不重复代码呢?
2 个回答
我会跳过所有可能让这个主意变得糟糕的理由,按照你的要求。
这样做合适吗,还是有更好的选择?
其实不需要用numpy,因为普通的 abs
函数就能处理复数,而且速度更快*。如果你想减少代码量,还有一个很方便的 total_ordering
在 functools
里,适合做简单的比较(不过这可能会慢一些):
from functools import total_ordering
@total_ordering
class CustomComplex(complex):
def __eq__(self, other):
return abs(self) == abs(other)
def __lt__(self, other):
return abs(self) < abs(other)
(这就是你需要的全部代码。)
我希望我的包能够透明地与内置的复数数据类型以及numpy的复数类型一起工作。怎么做才能优雅地实现,而不重复代码呢?
当正确的参数是普通的复数(或任何数字)时,它会自动工作:
>>> CustomComplex(1+7j) < 2+8j
True
但是如果你想使用操作符 <
等,而不是函数,这就是你能做到的最好方式。complex
类型不允许你设置 __lt__
,而且 TypeError 是硬编码的。
如果你想对普通的 complex
数字进行这样的比较,你必须定义并使用自己的比较函数,而不是普通的操作符。或者直接用 abs(a) < abs(b)
,这样清晰明了,也不算太啰嗦。
* 比较内置的 abs
和 numpy.abs
的时间:
>>> timeit.timeit('abs(7+6j)')
0.10257387161254883
>>> timeit.timeit('np.abs(7+6j)', 'import numpy as np')
1.6638610363006592
我怕我会跑题(是的,我完全读过你的帖子 :-))。好吧,Python确实允许你这样比较复杂的数字,因为你可以单独定义所有的运算符,尽管我强烈建议你不要像你那样重新定义 __eq__
:你是在说 1 == -1
!
在我看来,问题就出在这里,迟早会让你感到困扰(或者让使用你这个包的任何人感到困扰):在使用等式和不等式时,普通人(以及大多数Python代码)会做一些简单的假设,比如 -1 != 1
,还有 (a <= b) && (b <= a)
意味着 a == b
。而从纯数学的角度来看,这两个假设不可能同时成立。
另一个经典的假设是 a <= b
等价于 -b <= -a
。但在你的预排序中, a <= b
等价于 -a <= -b
!
话虽如此,我会尝试回答你的两个问题:
- 1:在我看来,这是一种有害的方式(如上所述),但我没有更好的替代方案……
- 2:我认为混入类(mixin)可能是一种优雅的方式来减少代码重复
代码示例(基于你自己的代码,但没有经过广泛测试):
import numpy as np
class ComplexOrder(Object):
def __lt__(self, other):
return np.absolute(self) < np.absolute(other)
# ... keep want you want (including or not eq and ne)
def __ge__(self, other):
return np.absolute(self) >= np.absolute(other)
class OrderedComplex(ComplexOrder, complex):
def __init__(self, real, imag = 0):
complex.__init__(self, real, imag)
class NPOrderedComplex64(ComplexOrder, np.complex64):
def __init__(self, real = 0):
np.complex64.__init__(self, real)