python类型提示可以使用tensorflow数据类型吗?

2024-04-18 05:45:45 发布

您现在位置:Python中文网/ 问答频道 /正文

是否可以在Python类型提示中使用Tensorflow数据类型tf.dtypes.DType,例如tf.int32

from typing import (
    Union,
)
import tensorflow as tf
import numpy as np


def f(
    a: Union[tf.int32, tf.float32]  # <----
): 
    return a * 2


def g(a: Union[np.int32, np.float32]):
    return a * 2


def test_a():
    f(tf.cast(1.0, dtype=tf.float32))  # <----
    g(np.float32(1.0))                 # Numpy type has no issue

它会导致下面的错误,不知道这是否可能

python3.8/typing.py:149: in _type_check
    raise TypeError(f"{msg} Got {arg!r:.100}.")
E   TypeError: Union[arg, ...]: each arg must be a type. Got tf.int32.

Tags: importtyping类型returntfdefastype
1条回答
网友
1楼 · 发布于 2024-04-18 05:45:45

我假设您希望您的职能部门接受:

  • tf.float32
  • np.float32
  • float
  • tf.int32
  • np.int32
  • int

并且总是返回,比如说,tf.float32。不完全确定这是否涵盖了您的用例,但我会为您的输入参数设置一个宽泛的类型,并将其转换为您函数中所需的类型

^{}可以与类型注释一起使用,通过减少昂贵的图重传次数来提高性能。例如,即使输入是非张量值,用tf.Tensor注释的参数也会转换为张量

from typing import TYPE_CHECKING
import tensorflow as tf
import numpy as np


@tf.function(experimental_follow_type_hints=True)
def foo(x: tf.Tensor) -> tf.float32:
    if x.dtype == tf.int32:
        x = tf.dtypes.cast(x, tf.float32)
    return x * 2

a = tf.cast(1.0, dtype=tf.float32)
b = tf.cast(1.0, dtype=tf.int32)

c = np.float32(1.0)
d = np.int32(1.0)

e = 1.0
f = 1

for var in [a, b, c, d, e, f]:
    print(f"input: {var},\tinput type: {type(var)},\toutput: {foo(var)}\toutput type: {type(foo(var))}")

if TYPE_CHECKING:
    reveal_locals()

python3 stack66968102.py的输出:

input: 1.0,     input type: <class 'tensorflow.python.framework.ops.EagerTensor'>,      output: 2.0     output dtype: <dtype: 'float32'>
input: 1,       input type: <class 'tensorflow.python.framework.ops.EagerTensor'>,      output: 2.0     output dtype: <dtype: 'float32'>
input: 1.0,     input type: <class 'numpy.float32'>,    output: 2.0     output dtype: <dtype: 'float32'>
input: 1,       input type: <class 'numpy.int32'>,      output: 2.0     output dtype: <dtype: 'float32'>
input: 1.0,     input type: <class 'float'>,    output: 2.0     output dtype: <dtype: 'float32'>
input: 1,       input type: <class 'int'>,      output: 2.0     output dtype: <dtype: 'float32'>

mypy stack66968102.py ignore-missing-imports的输出:

stack66968102.py:27: note: Revealed local types are:
stack66968102.py:27: note:     a: Any
stack66968102.py:27: note:     b: Any
stack66968102.py:27: note:     c: numpy.floating[numpy.typing._32Bit*]
stack66968102.py:27: note:     d: numpy.signedinteger[numpy.typing._32Bit*]
stack66968102.py:27: note:     e: builtins.float
stack66968102.py:27: note:     f: builtins.int
stack66968102.py:27: note:     tf: Any
stack66968102.py:27: note:     var: Any

相关问题 更多 >