如何在使用Typing和Mypy的泛型中区分基类和派生类

0 投票
1 回答
32 浏览
提问于 2025-04-13 13:40

考虑以下代码:

from typing import TypeVar
import dataclasses


@dataclasses.dataclass
class A:
    pass


@dataclasses.dataclass
class B(A):
    pass


T = TypeVar("T", A, B)


def fun(
    x1: T,
    x2: T,
) -> int:
    if type(x1) != type(x2):
        raise TypeError("must be same type!")

    if type(x1) == A:
        return 5

    elif type(x1) == B:
        return 10
    else:
        raise TypeError("Type not handled")


fun(x1=A(), x2=A())  # OK
fun(x1=B(), x2=B())  # OK
fun(x1=B(), x2=A())  # Will throw TypeError, how can I get mypy to say this is an error?
fun(x1=A(), x2=B())  # Will throw TypeError, how can I get mypy to say this is an error?

Mypy在这里没有发现任何问题。它似乎总是把传入的对象当作类型为A的基类对象来理解。

有没有办法让这个通用类型更加严格,也就是说,它能对确切的类类型敏感?比如,如果x1是类型B,那么x2也必须严格是类型B?如果x1是类型A,那么x2也必须严格是类型A

1 个回答

0

这个问题挺有意思的,最开始我考虑用以下方式来解决:

from typing import overload
import dataclasses


@dataclasses.dataclass
class A:
    pass


@dataclasses.dataclass
class B(A):
    pass


@overload
def fun(x1: B, x2: B) -> int:
    ...


@overload
def fun(x1: A, x2: A) -> int:
    ...


def fun(
    x1: A | B,
    x2: A | B,
) -> int:
    if type(x1) != type(x2):
        raise TypeError("must be same type!")

    if type(x1) == A:
        return 5

    elif type(x1) == B:
        return 10
    else:
        raise TypeError("Type not handled")


fun(x1=A(), x2=A())
fun(x1=B(), x2=B())
fun(x1=B(), x2=A())
fun(x1=A(), x2=B())

我一开始以为这可能是TypeVar工作方式的一个特殊情况,但我发现即使我们在重载中指定必须是A, A或者B, B,最后的两行代码仍然不会报错。最后的两行代码只是使用了重载的A, A,因为A, B仍然是A, A的一个子类型。Python完全不区分直接实例和子类型——只要A和B之间有结构上的区别,你可以通过Protocol来强制执行结构类型。

即使你把函数的参数改成list[A],B的类型在这个列表中仍然是有效的,原因就是这样。

如果你想让B继承A的所有属性,并且让它们成为不同的类型,我建议你这样做,把A作为一个隐藏的基类,而把A2暴露给最终用户:

from typing import TypeVar
import dataclasses


@dataclasses.dataclass
class A:
    pass


@dataclasses.dataclass
class B(A):
    pass

@dataclasses.dataclass
class A2(A):
    # Note, this class would be empty in practice as well
    pass



T = TypeVar("T", A2, B)

def fun(
    x1: T,
    x2: T,
) -> int:
    if type(x1) != type(x2):
        raise TypeError("must be same type!")

    if type(x1) == A2:
        return 5

    elif type(x1) == B:
        return 10
    else:
        raise TypeError("Type not handled")


fun(x1=A2(), x2=A2())  # OK
fun(x1=B(), x2=B())  # OK
fun(x1=B(), x2=A2())  # Will throw TypeError, how can I get mypy to say this is an error?
fun(x1=A2(), x2=B())  # Will throw TypeError, how can I get mypy to say this is an error?

希望这对你有帮助!

撰写回答