如何正确使用`@tf.功能`在将keras层/模型子类化时?

2024-04-23 18:22:34 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个自定义的tf.keras.layers.Layer,它只使用TF运算符进行某种位解包(将整数转换为布尔值(0或1个浮点值))。在

class CharUnpack(keras.layers.Layer):

    def __init__(self, name="CharUnpack", *args, **kwargs):
        super(CharUnpack, self).__init__(trainable=False, name=name, *args, **kwargs)
        # Range [7, 6, ..., 0] to bit-shift integers
        self._shifting_range = tf.reshape(
            tf.dtypes.cast(
                tf.range(7, -1, -1, name='shifter_range'),
                tf.uint8,
                name='shifter_cast'),
            (1, 1, 8),
            name='shifter_reshape')
        # Constant value 0b00000001 to use as bitwise and operator
        self._selection_bit = tf.constant(0x01, dtype=tf.uint8, name='and_selection_bit')

    def call(self, inputs):
        return tf.dtypes.cast(
            tf.reshape(
                tf.bitwise.bitwise_and(
                    tf.bitwise.right_shift(
                        tf.expand_dims(inputs, 2),
                        self._shifting_range,
                    ),
                    self._selection_bit,
                ),
                [x if x else -1 for x in self.compute_output_shape(inputs.shape)]
            ),
            tf.float32
        )

    def compute_output_shape(self, input_shape):
        try:
            if len(input_shape) > 1:
                output_shape = tf.TensorShape(tuple(list(input_shape[:-1]) + [input_shape[-1] * 8]))
            else:
                output_shape = tf.TensorShape((input_shape[0] * 8,))
        except TypeError:
            output_shape = input_shape
        return output_shape

    def compute_output_signature(self, input_signature):
        return tf.TensorSpec(self.compute_output_shape(input_signature.shape), tf.float32)

我尝试对这个层进行基准测试,以提高时间性能,如TF guide所示。在

^{pr2}$ ^{3}$

如你所见,我可以得到10倍的加速!!! 因此,我将@tf.function修饰符添加到我的CharUnpack.call方法中:

+    @tf.function
     def call(self, inputs):
         return tf.dtypes.cast(

现在我希望eager和{}调用花费的时间相似,但我没有得到任何改进。在

Function: 0.009667591999459546
Eager: 0.10346330100037449

此外,在第2.1节中,SO answer声明默认情况下模型是图形编译的(这应该是逻辑的),但情况似乎并非如此。。。在

如何正确使用@tf.function修饰符使我的层始终图形化编译?在


Tags: nameselfinputoutputreturntfdefbit