更复杂情况下Python泛型的类型提示
假设我有一个叫做 Number
的类:
from typing import TypeVar, Generic
T = TypeVar('T')
class Number(Generic[T]):
value: T
这里的 T
是一个类型变量,可以是 int
(整数)、float
(浮点数)、Decimal
(小数)、Fraction
(分数)等等。
现在我想定义一些方法,比如 __add__
。简单来说,我们可以设置参数和返回值的类型都和输入值一样,比如:
class Number(Generic[T]):
value: T
def __add__(self, other: T) -> T: ...
但是,当我们尝试像 int + float
这样的操作时,它们是可以相加的,但上面的类型提示就不管用了。
问题是:类型提示系统怎么才能获取到 T
的信息,告诉我们只有某些类型(而不仅仅是 T
,甚至 T
可能都不行)可以和类型 T
一起操作,比如:
class Number(Generic[T]):
value: T
def __add__(self, other: T_addable) -> T_addition_result: ...
我尝试过使用 Protocol
、covariant
(协变)、contravariant
(逆变)
from typing import TypeVar, Protocol, runtime_checkable
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
T_contra = TypeVar("T_contra", contravariant=True)
@runtime_checkable
class SupportsAdd(Protocol[T_contra, T_co]):
__slots__ = ()
def __add__(self, x: T_contra) -> T_co: ...
但这无法从 T
中确定出 T_contra
和 T_co
。通常这只能用作 SupportsAdd[Any, Any]
,这并没有帮助。
1 个回答
第一个问题是,T
是未绑定的。它可以是任何东西,比如你可以构造一个 Number[str]
或者 Number[list[set[int]] | dict[tuple[int, int], str]]
。
有两种(如果算上像 AnyStr
这样的约束类型变量,那就是三种)方法来限制它:
- 把类型变量绑定到最具体的超类型。可惜的是,所有数字类型的超类型
numbers.Number
并不被 mypy 支持。 - 利用协议,就像你已经做的那样。
这是我开始的地方:
T = TypeVar('T')
class SupportsAdd(Protocol[T]):
def __add__(self, other: T) -> T: ...
class Number(Generic[T]):
value: SupportsAdd[T]
def __init__(self, value: SupportsAdd[T]) -> None:
self.value = value
def __add__(self, other: T) -> T:
return self.value + other
现在,这几乎能工作,但我需要明确地标注 T
:
n = Number[float](3)
reveal_type(n + 1.0) # float
...而且它并不总是给出想要的结果:
reveal_type(n + 1) # float, runtime type: int
好吧,也许我们可以把 T
分成一个协变部分和一个反变部分,就像你写的那样。但这对我来说并没有奏效。我想到的其他方法也没有成功。
这时我决定退一步思考。这里复杂的根源是什么呢?
关于数字类型的事情是,它们在像 mypy 这样的工具中是特殊处理的。
像 __add__
这样的操作符只被标注为 (Self, Self) -> Self
,而混合使用数字类型的能力是通过一些规则来实现的,比如“int
可以赋值给 float
”,“float
可以赋值给 complex
”,等等。所以如果你执行 3 + 4.2
,mypy 会把它看作是 3
被赋值为一个 float,然后进行 float + float
的运算,而不是 int + float
。
所以问题在于你的 Number
类没有被特殊处理:你不能让 Number[int]
一般情况下可以赋值给 Number[float]
。基本上,你无法构造一个类定义,使得对于一个实例 n
,assert_type(n + 1.0, float)
和 assert_type(n + 1, int)
都能通过。
我希望我错了,但我看不出有什么解决办法。