是否可以在Python类型提示中使用Tensorflow数据类型tf.dtypes.DType,例如tf.int32
from typing import (
Union,
)
import tensorflow as tf
import numpy as np
def f(
a: Union[tf.int32, tf.float32] # <----
):
return a * 2
def g(a: Union[np.int32, np.float32]):
return a * 2
def test_a():
f(tf.cast(1.0, dtype=tf.float32)) # <----
g(np.float32(1.0)) # Numpy type has no issue
它会导致下面的错误,不知道这是否可能
python3.8/typing.py:149: in _type_check
raise TypeError(f"{msg} Got {arg!r:.100}.")
E TypeError: Union[arg, ...]: each arg must be a type. Got tf.int32.
我假设您希望您的职能部门接受:
tf.float32
np.float32
float
tf.int32
np.int32
int
并且总是返回,比如说,
tf.float32
。不完全确定这是否涵盖了您的用例,但我会为您的输入参数设置一个宽泛的类型,并将其转换为您函数中所需的类型^{} 可以与类型注释一起使用,通过减少昂贵的图重传次数来提高性能。例如,即使输入是非张量值,用tf.Tensor注释的参数也会转换为张量
python3 stack66968102.py
的输出:mypy stack66968102.py ignore-missing-imports
的输出:相关问题 更多 >
编程相关推荐