如何为真正可选的参数添加类型提示

0 投票
1 回答
31 浏览
提问于 2025-04-12 06:48

我有一个函数,里面有类型提示:

class HDF5DataTypes(Enum):
    SCALAR = "scalar"
    ARRAY = "array"
    UNKNOWN = "unknown"

@overload
def index_hdf5_to_value(file_or_group: Union[h5py.File, h5py.Group], indexes: List[str], expected_output_type: Literal[HDF5DataTypes.SCALAR]) -> np.float_:
    ...
@overload
def index_hdf5_to_value(file_or_group: Union[h5py.File, h5py.Group], indexes: List[str], expected_output_type: Literal[HDF5DataTypes.ARRAY]) -> npt.NDArray:
    ...
@overload
def index_hdf5_to_value(file_or_group: Union[h5py.File, h5py.Group], indexes: List[str], expected_output_type: Literal[HDF5DataTypes.UNKNOWN]) -> Union[npt.NDArray, np.float_]:
    ...
def index_hdf5_to_value(file_or_group: Union[h5py.File, h5py.Group], indexes: List[str], expected_output_type: HDF5DataTypes=HDF5DataTypes.UNKNOWN) -> Union[npt.NDArray, np.float_]:
    '''Given a file or group, returns the output of indexing the file or group with the indexes down until it gets to the dataset, at which point it gives back the value of the dataset (either the scalar or numpy array).
    '''
    dataset = index_hdf5(file_or_group, indexes, h5py.Dataset)
    if len(dataset.shape) == 0:
        if expected_output_type == HDF5DataTypes.ARRAY:
            raise ValueError(f"Expected output to be an array, but it was a scalar")
        return cast(np.float_, dataset[()])
    else:
        if expected_output_type == HDF5DataTypes.SCALAR:
            raise ValueError(f"Expected output to be a scalar, but it was an array")
        return cast(npt.NDArray, dataset[:])

但是,当我用 index_hdf5_to_value(hdf5_file, ["key1", "key2"]) 来调用它时,出现了错误 No overloads for "index_hdf5_to_value" match the provided arguments Argument types: (File, list[str])

简单来说,它不高兴是因为我没有提供第三个参数。

我可以把第三个参数的类型提示为 Optional,但我担心这样会让人觉得可以用 index_hdf5_to_value(hdf5_file, ["key1", "key2"], None) 来调用这个函数,而实际上这是不可以的。

我应该如何正确地给这个函数加类型提示,以告诉用户第三个参数是可选的,但不能设置为 None?(也就是说,它是可选的,但不是 Optional。)

1 个回答

2

尽管名字叫做 Optional,但它并不意味着“可选”。它的意思是“可能是 None”。在参数上加上 Optional 的标记并没有帮助——mypy 仍然不允许你省略这个参数。

你需要提供一个重载,真正允许在没有 expected_output_type 参数的情况下被调用。这可以是一个单独的重载:

@overload
def index_hdf5_to_value(
        file_or_group: Union[h5py.File, h5py.Group],
        indexes: List[str]) -> Union[npt.NDArray, np.float_]:
    ...

或者你可以给 HDF5DataTypes.UNKNOWN 的重载添加一个默认参数值:

@overload
def index_hdf5_to_value(
        file_or_group: Union[h5py.File, h5py.Group],
        indexes: List[str],
        expected_output_type: Literal[HDF5DataTypes.UNKNOWN] = HDF5DataTypes.UNKNOWN
        ) -> Union[npt.NDArray, np.float_]:
    ...

撰写回答