如何在使用Typing和Mypy的泛型中区分基类和派生类
考虑以下代码:
from typing import TypeVar
import dataclasses
@dataclasses.dataclass
class A:
pass
@dataclasses.dataclass
class B(A):
pass
T = TypeVar("T", A, B)
def fun(
x1: T,
x2: T,
) -> int:
if type(x1) != type(x2):
raise TypeError("must be same type!")
if type(x1) == A:
return 5
elif type(x1) == B:
return 10
else:
raise TypeError("Type not handled")
fun(x1=A(), x2=A()) # OK
fun(x1=B(), x2=B()) # OK
fun(x1=B(), x2=A()) # Will throw TypeError, how can I get mypy to say this is an error?
fun(x1=A(), x2=B()) # Will throw TypeError, how can I get mypy to say this is an error?
Mypy在这里没有发现任何问题。它似乎总是把传入的对象当作类型为A
的基类对象来理解。
有没有办法让这个通用类型更加严格,也就是说,它能对确切的类类型敏感?比如,如果x1
是类型B
,那么x2
也必须严格是类型B
?如果x1
是类型A
,那么x2
也必须严格是类型A
?
1 个回答
0
这个问题挺有意思的,最开始我考虑用以下方式来解决:
from typing import overload
import dataclasses
@dataclasses.dataclass
class A:
pass
@dataclasses.dataclass
class B(A):
pass
@overload
def fun(x1: B, x2: B) -> int:
...
@overload
def fun(x1: A, x2: A) -> int:
...
def fun(
x1: A | B,
x2: A | B,
) -> int:
if type(x1) != type(x2):
raise TypeError("must be same type!")
if type(x1) == A:
return 5
elif type(x1) == B:
return 10
else:
raise TypeError("Type not handled")
fun(x1=A(), x2=A())
fun(x1=B(), x2=B())
fun(x1=B(), x2=A())
fun(x1=A(), x2=B())
我一开始以为这可能是TypeVar
工作方式的一个特殊情况,但我发现即使我们在重载中指定必须是A, A
或者B, B
,最后的两行代码仍然不会报错。最后的两行代码只是使用了重载的A, A
,因为A, B
仍然是A, A
的一个子类型。Python完全不区分直接实例和子类型——只要A和B之间有结构上的区别,你可以通过Protocol
来强制执行结构类型。
即使你把函数的参数改成list[A]
,B的类型在这个列表中仍然是有效的,原因就是这样。
如果你想让B继承A的所有属性,并且让它们成为不同的类型,我建议你这样做,把A作为一个隐藏的基类,而把A2暴露给最终用户:
from typing import TypeVar
import dataclasses
@dataclasses.dataclass
class A:
pass
@dataclasses.dataclass
class B(A):
pass
@dataclasses.dataclass
class A2(A):
# Note, this class would be empty in practice as well
pass
T = TypeVar("T", A2, B)
def fun(
x1: T,
x2: T,
) -> int:
if type(x1) != type(x2):
raise TypeError("must be same type!")
if type(x1) == A2:
return 5
elif type(x1) == B:
return 10
else:
raise TypeError("Type not handled")
fun(x1=A2(), x2=A2()) # OK
fun(x1=B(), x2=B()) # OK
fun(x1=B(), x2=A2()) # Will throw TypeError, how can I get mypy to say this is an error?
fun(x1=A2(), x2=B()) # Will throw TypeError, how can I get mypy to say this is an error?
希望这对你有帮助!