为字典的特定键值对提供类型提示

0 投票
1 回答
105 浏览
提问于 2025-04-14 16:34

假设我在Python中有一个枚举类(Enum class):

class MyEnum(Enum):
    A = "a"
    B = "b"

我有一个函数,这个函数会为每个可能的枚举值(在这个例子中有两个)返回一个特定的类型:假设这两个值都返回一个数据框(DataFrame)。我想给这个函数添加类型提示,为此我使用了TypedDict,像这样:

import pandas as pd
from typing import TypedDict

class ReturnedType(TypedDict):
    MyEnum.A.value: pd.DataFrame
    MyEnum.B.value: pd.DataFrame

然后:

def foo(...) -> ReturnedType

但是显然TypedDict不允许用其他变量来定义字段名称,这导致我的mypy检查失败。

在这种情况下,给这样一个函数添加类型提示的最“Pythonic”的方法是什么呢?

这里有一个最小可重现示例(MWE):

from typing import Dict, TypedDict
from enum import Enum

class MyEnum(Enum):
    A = 'a'
    B = 'b'

class MyClass(TypedDict):
    """The class defines the shape of the dictionary output
    by any RoadCodec"""

    MyEnum.A.value: int
    MyEnum.B.value: float

def foo() -> MyClass:
    res: MyClass = {MyEnum.A.value: 3, MyEnum.B.value: 4.4}

当我运行mypy检查时,我得到了这个错误:

TypedDict定义中的无效语句;预期为 "field_name: field_type" [misc]

另外请注意,我是在Python3.8下运行的,所以StrEnum不可用。

1 个回答

1

很遗憾,这个问题没有简单好看的解决办法。

不过,确实有一种麻烦的方法可以做到。我还没有用mypy测试过,但至少在pylance/pyright/vscode中,这个方法是有效的:

from enum import Enum
from typing import Dict, Union, overload, Literal
class MyEnum(Enum):
    A = "a"
    B = "b"

class ReturnedType(Dict[MyEnum, Union[int, str]]):

    @overload
    def __getitem__(self, __key: Literal[MyEnum.A]) -> int:
        ...
    
    @overload
    def __getitem__(self, __key: Literal[MyEnum.B]) -> str:
        ...

    def __getitem__(self, __key: MyEnum) -> Union[int, str]:
        return super()[__key]
    
    @overload
    def get(self, __key: Literal[MyEnum.A]) -> Union[int, None]:
        ...

    @overload
    def get(self, __key: Literal[MyEnum.A], __default: int) -> int:
        ...

    @overload
    def get(self, __key: Literal[MyEnum.B]) -> Union[str, None]:
        ...

    @overload
    def get(self, __key: Literal[MyEnum.B], __default: str) -> str:
        ...

    def get(self, __key: MyEnum, __default: Union[int, str, None] = None) -> Union[int, str, None]:
        return super().get(__key, __default)

    @overload
    def __setitem__(self, __key: Literal[MyEnum.A], __value: int):
        ...

    @overload
    def __setitem__(self, __key: Literal[MyEnum.B], __value: str):
        ...
    
    def __setitem__(self, __key: MyEnum, __value: Union[int, str]):
        super()[__key] = __value
    
def foo() -> ReturnedType:
    rv = ReturnedType()
    rv[MyEnum.A] = 3
    rv[MyEnum.B] = "hello"
    return rv

d: int = foo()[MyEnum.A]
a: str = foo()[MyEnum.A] # incompatible!

撰写回答