Keras model.input不保存名称

2024-04-28 14:02:04 发布

您现在位置:Python中文网/ 问答频道 /正文

我保存了一个Keras模型,定义如下:

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
query.appliance_slot.asr_hyp [(None, None)]            0         
_________________________________________________________________
query.appliance_slot.asr_hyp [(None, None)]            0         
_________________________________________________________________
query.appliance_slot.asr_hyp [(None, None)]            0         
_________________________________________________________________
query.appliance_slot.asr_hyp [(None, None)]            0         
_________________________________________________________________
query.appliance_slot.asr_hyp [(None, None)]            0         
_________________________________________________________________
tf_op_layer_stack (TensorFlo multiple                  0         

...

实际模型定义代码涉及输入层:

            asr_query_input_1 = tf.keras.Input(
                shape=(None,), name="query.appliance_slot.asr_hyp_1", dtype=tf.int32
            )
            asr_query_input_2 = tf.keras.Input(
                shape=(None,), name="query.appliance_slot.asr_hyp_2", dtype=tf.int32
            )
            asr_query_input_3 = tf.keras.Input(
                shape=(None,), name="query.appliance_slot.asr_hyp_3", dtype=tf.int32
            )

但是,当我尝试加载此保存的模型并检查model.inputs时,我看到输入层的名称已更改为通用层:

>>> model.inputs

[<tf.Tensor 'input_1_1:0' shape=(None, None) dtype=int32>,
 <tf.Tensor 'input_2_1:0' shape=(None, None) dtype=int32>,
 <tf.Tensor 'input_3_1:0' shape=(None, None) dtype=int32>,
 <tf.Tensor 'input_4_1:0' shape=(None, None) dtype=int32>,
 <tf.Tensor 'input_5_1:0' shape=(None, None) dtype=int32>,
...

这会导致通过TFRecords通过字典将输入传递到模型中的问题。有人知道为什么会发生这个问题吗


其他信息:

请注意,我已将此问题与模型保存和加载过程隔离开来;在新构建的模型上运行model.inputs时,我看到了正确的行为:

[<tf.Tensor 'query.appliance_slot.asr_hyp_1:0' shape=(None, None) dtype=int32>, <tf.Tensor 'query.appliance_slot.asr_hyp_2:0' shape=(None, None) dtype=int32>, <tf.Tensor 'query.appliance_slot.asr_hyp_3:0' shape=(None, None) dtype=int32>, ...

模型保存和加载代码:

保存:

model.save(full_path, save_format="tf")

装载:

model = tf.keras.models.load_model(model_artifacts_folder)

Tags: 模型noneinputmodeltfasrquerykeras