为字典的特定键值对提供类型提示
假设我在Python中有一个枚举类(Enum class):
class MyEnum(Enum):
A = "a"
B = "b"
我有一个函数,这个函数会为每个可能的枚举值(在这个例子中有两个)返回一个特定的类型:假设这两个值都返回一个数据框(DataFrame)。我想给这个函数添加类型提示,为此我使用了TypedDict
,像这样:
import pandas as pd
from typing import TypedDict
class ReturnedType(TypedDict):
MyEnum.A.value: pd.DataFrame
MyEnum.B.value: pd.DataFrame
然后:
def foo(...) -> ReturnedType
但是显然TypedDict
不允许用其他变量来定义字段名称,这导致我的mypy检查失败。
在这种情况下,给这样一个函数添加类型提示的最“Pythonic”的方法是什么呢?
这里有一个最小可重现示例(MWE):
from typing import Dict, TypedDict
from enum import Enum
class MyEnum(Enum):
A = 'a'
B = 'b'
class MyClass(TypedDict):
"""The class defines the shape of the dictionary output
by any RoadCodec"""
MyEnum.A.value: int
MyEnum.B.value: float
def foo() -> MyClass:
res: MyClass = {MyEnum.A.value: 3, MyEnum.B.value: 4.4}
当我运行mypy检查时,我得到了这个错误:
TypedDict定义中的无效语句;预期为 "field_name: field_type" [misc]
另外请注意,我是在Python3.8下运行的,所以StrEnum不可用。
1 个回答
1
很遗憾,这个问题没有简单好看的解决办法。
不过,确实有一种麻烦的方法可以做到。我还没有用mypy测试过,但至少在pylance/pyright/vscode中,这个方法是有效的:
from enum import Enum
from typing import Dict, Union, overload, Literal
class MyEnum(Enum):
A = "a"
B = "b"
class ReturnedType(Dict[MyEnum, Union[int, str]]):
@overload
def __getitem__(self, __key: Literal[MyEnum.A]) -> int:
...
@overload
def __getitem__(self, __key: Literal[MyEnum.B]) -> str:
...
def __getitem__(self, __key: MyEnum) -> Union[int, str]:
return super()[__key]
@overload
def get(self, __key: Literal[MyEnum.A]) -> Union[int, None]:
...
@overload
def get(self, __key: Literal[MyEnum.A], __default: int) -> int:
...
@overload
def get(self, __key: Literal[MyEnum.B]) -> Union[str, None]:
...
@overload
def get(self, __key: Literal[MyEnum.B], __default: str) -> str:
...
def get(self, __key: MyEnum, __default: Union[int, str, None] = None) -> Union[int, str, None]:
return super().get(__key, __default)
@overload
def __setitem__(self, __key: Literal[MyEnum.A], __value: int):
...
@overload
def __setitem__(self, __key: Literal[MyEnum.B], __value: str):
...
def __setitem__(self, __key: MyEnum, __value: Union[int, str]):
super()[__key] = __value
def foo() -> ReturnedType:
rv = ReturnedType()
rv[MyEnum.A] = 3
rv[MyEnum.B] = "hello"
return rv
d: int = foo()[MyEnum.A]
a: str = foo()[MyEnum.A] # incompatible!