How do I save a trained model (Estimator) and load it back to test it with data in TensorFlow?

Posted 2024-03-28 20:02:12


I have this snippet for my model:

import pandas as pd
import tensorflow as tf
from tensorflow.contrib import learn
from tensorflow.contrib.learn.python import SKCompat
#Assume my dataset is using X['train'] as input and y['train'] as output

regressor = SKCompat(learn.Estimator(model_fn=lstm_model(TIMESTEPS, RNN_LAYERS, DENSE_LAYERS),model_dir=LOG_DIR))
validation_monitor = learn.monitors.ValidationMonitor(X['val'], y['val'], every_n_steps=PRINT_STEPS, early_stopping_rounds=1000)
regressor.fit(X['train'], y['train'],
              monitors=[validation_monitor],
              batch_size=BATCH_SIZE,
              steps=TRAINING_STEPS)

#After training this model I want to save it in a folder, so I can use the trained model for implementing in my algorithm to predict the output
#What is the correct format to use here to save my model in a folder called 'saved_model'
regressor.export_savedmodel('/saved_model/')

#I want to import it later in some other code, How can I import it?
#is there any function like import model from file?

How do I save this estimator? I tried tf.contrib.learn.Estimator.export_savedmodel, but I did not succeed. Any help is appreciated.


1 answer
User
#1 · Posted 2024-03-28 20:02:12

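The export_savedmodel function requires a serving_input_receiver_fn argument: a function that takes no arguments and defines the inputs that the exported model and the predictor receive. You therefore have to write your own serving_input_receiver_fn, in which the model input type matches the model input in the training script and the predictor input type matches the predictor input in the testing script. In addition, if you build a custom model, you must define export_outputs using tf.estimator.export.PredictOutput. Its input is a dictionary whose keys must match the predictor output names used in the testing script.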

For example:

Training script

import tensorflow as tf

def serving_input_receiver_fn():
    serialized_tf_example = tf.placeholder(dtype=tf.string, shape=[None], name='input_tensors')
    receiver_tensors      = {"predictor_inputs": serialized_tf_example}
    feature_spec          = {"words": tf.FixedLenFeature([25],tf.int64)}
    features              = tf.parse_example(serialized_tf_example, feature_spec)
    return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

def estimator_spec_for_softmax_classification(logits, labels, mode):
    predicted_classes = tf.argmax(logits, 1)
    if (mode == tf.estimator.ModeKeys.PREDICT):
        export_outputs = {'predict_output': tf.estimator.export.PredictOutput({"pred_output_classes": predicted_classes, 'probabilities': tf.nn.softmax(logits)})}
        return tf.estimator.EstimatorSpec(mode=mode, predictions={'class': predicted_classes, 'prob': tf.nn.softmax(logits)}, export_outputs=export_outputs) # IMPORTANT!!!

    onehot_labels = tf.one_hot(labels, 31, 1, 0)
    loss          = tf.losses.softmax_cross_entropy(onehot_labels=onehot_labels, logits=logits)
    if (mode == tf.estimator.ModeKeys.TRAIN):
        optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
        train_op  = optimizer.minimize(loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    eval_metric_ops = {'accuracy': tf.metrics.accuracy(labels=labels, predictions=predicted_classes)}
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

def model_custom(features, labels, mode):
    bow_column           = tf.feature_column.categorical_column_with_identity("words", num_buckets=1000)
    bow_embedding_column = tf.feature_column.embedding_column(bow_column, dimension=50)   
    bow                  = tf.feature_column.input_layer(features, feature_columns=[bow_embedding_column])
    logits               = tf.layers.dense(bow, 31, activation=None)

    return estimator_spec_for_softmax_classification(logits=logits, labels=labels, mode=mode)

def main():
    # ...
    # preprocess-> features_train_set and labels_train_set
    # ...
    classifier     = tf.estimator.Estimator(model_fn = model_custom)
    train_input_fn = tf.estimator.inputs.numpy_input_fn(x={"words": features_train_set}, y=labels_train_set, batch_size=batch_size_param, num_epochs=None, shuffle=True)
    classifier.train(input_fn=train_input_fn, steps=100)

    full_model_dir = classifier.export_savedmodel(export_dir_base="C:/models/directory_base", serving_input_receiver_fn=serving_input_receiver_fn)
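
The testing script needs the directory that export_savedmodel returned (full_model_dir above). export_savedmodel writes each export into a subdirectory of export_dir_base named after a Unix timestamp, so when training and testing run as separate processes, one option is to look up the newest such subdirectory. A minimal sketch using only the standard library; the helper name latest_export_dir is hypothetical and not part of TensorFlow:

```python
import os


def latest_export_dir(export_dir_base):
    """Return the newest timestamp-named export under export_dir_base, or None."""
    # Keep only directories whose names are purely numeric (the timestamps);
    # this skips unrelated files and any in-progress "temp-..." directories.
    candidates = [
        name for name in os.listdir(export_dir_base)
        if name.isdigit() and os.path.isdir(os.path.join(export_dir_base, name))
    ]
    if not candidates:
        return None
    # Compare as integers so the most recent timestamp wins.
    return os.path.join(export_dir_base, max(candidates, key=int))
```

For instance, latest_export_dir("C:/models/directory_base") would give the path to hand to the loader in the testing script, provided at least one export has finished.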

Testing script

[code block not preserved in the source]
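
A minimal sketch of what a testing script for the model above could look like. This is not the answer author's original code. It assumes the TensorFlow 1.x API used in the training script, and it relies on the names defined there: the receiver key "predictor_inputs", the feature "words" with 25 int64 values, and the output key "pred_output_classes". The helper pad_word_ids, its pad_id of 0, and the function name predict_class are hypothetical additions for illustration.

```python
def pad_word_ids(word_ids, length=25, pad_id=0):
    """Truncate or right-pad a sequence of ids to exactly `length` entries."""
    # The feature spec in the training script is FixedLenFeature([25], tf.int64),
    # so every example must carry exactly 25 values. pad_id=0 is an assumption.
    ids = list(word_ids)[:length]
    return ids + [pad_id] * (length - len(ids))


def predict_class(export_dir, word_ids):
    """Load the exported SavedModel and return the predicted class for one input."""
    # Imported here so that pad_word_ids can be used without TensorFlow.
    import tensorflow as tf

    # export_dir is the path returned by export_savedmodel (full_model_dir).
    predictor = tf.contrib.predictor.from_saved_model(export_dir)

    # Build a tf.train.Example whose feature name matches the feature spec.
    example = tf.train.Example(features=tf.train.Features(feature={
        "words": tf.train.Feature(
            int64_list=tf.train.Int64List(value=pad_word_ids(word_ids))),
    }))

    # The input key must match receiver_tensors in serving_input_receiver_fn;
    # the output key must match the dictionary passed to PredictOutput.
    output = predictor({"predictor_inputs": [example.SerializeToString()]})
    return output["pred_output_classes"][0]
```

The two dictionary keys are where the matching rules from the answer show up: a mismatch in either one makes the predictor call or the output lookup fail. The version note that follows refers to the answer's own code rather than to this sketch.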

(Code tested with Python 3.6.3 and TensorFlow 1.4.0.)
