如何制作服务于输入端接收器的特性

2024-04-20 03:27:14 发布

您现在位置:Python中文网/ 问答频道 /正文

我创建了一个具有张量流-伯特语言模型的二元分类器。下面是link示例代码。我能做预测。现在我想导出这个模型。我不确定我是否正确定义了功能规格。在

导出模型的代码。在

feature_spec = {'x': tf.VarLenFeature(tf.string)}  

def serving_input_receiver_fn():  
  serialized_tf_example = tf.placeholder(dtype=tf.string, shape=[1],name='input_example_tensor')
  receiver_tensors = {'examples': serialized_tf_example}
  features = tf.parse_example(serialized_tf_example, feature_spec)
  return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

# Export the estimator
export_path = f'/content/drive/My Drive/binary_class/bert/export'

estimator.export_saved_model(
    export_path,
    serving_input_receiver_fn=serving_input_receiver_fn)

错误

^{pr2}$

Tags: 代码模型inputstringexampletfexportserialized
1条回答
网友
1楼 · 发布于 2024-04-20 03:27:14

notebook中的create_model函数需要一些参数。这些是将传递给模型的特性。在

通过将serving_input_fn函数更新为following,服务函数可以正常工作。在

更新的代码

def serving_input_fn():
  feature_spec = {
      "input_ids" : tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
      "input_mask" : tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
      "segment_ids" : tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
      "label_ids" :  tf.FixedLenFeature([], tf.int64)

  }
  serialized_tf_example = tf.placeholder(dtype=tf.string, 
                                         shape=[None],
                                         name='input_example_tensor')
  receiver_tensors = {'example': serialized_tf_example}
  features = tf.parse_example(serialized_tf_example, feature_spec)
  return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

相关问题 更多 >