tensorflow:模型未能恢复权重
我正在尝试做文本摘要的实验。我发现了一个GitHub上的代码,它使用了亚马逊食品评论的数据集和Tensorflow 1.1.0来编写代码。
我通过在推理代码上加一个循环来运行这个代码,这样我就可以检查多个摘要,效果很好。
checkpoint = "./best_model.ckpt"
loaded_graph = tf.Graph()
with tf.Session(graph=loaded_graph) as sess:
# Load saved model
loader = tf.train.import_meta_graph(checkpoint + '.meta')
loader.restore(sess, checkpoint)
input_data = loaded_graph.get_tensor_by_name('input:0')
logits = loaded_graph.get_tensor_by_name('predictions:0')
text_length = loaded_graph.get_tensor_by_name('text_length:0')
summary_length = loaded_graph.get_tensor_by_name('summary_length:0')
keep_prob = loaded_graph.get_tensor_by_name('keep_prob:0')
while True:
input_sentence = input()
text = text_to_seq(input_sentence)
#Multiply by batch_size to match the model's input parameters
answer_logits = sess.run(logits, {input_data: [text]*batch_size,
summary_length: [np.random.randint(5,8)],
text_length: [len(text)]*batch_size,
keep_prob: 1.0})[0]
# Remove the padding from the tweet
pad = vocab_to_int["<PAD>"]
print('Original Text:', input_sentence)
print('\nText')
print(' Word Ids: {}'.format([i for i in text]))
print(' Input Words: {}'.format(" ".join([int_to_vocab[i] for i in text])))
print('\nSummary')
print(' Word Ids: {}'.format([i for i in answer_logits if i != pad]))
print(' Response Words: {}'.format(" ".join([int_to_vocab[i] for i in answer_logits if i != pad])))
这里有一段保存模型的代码:
checkpoint = "best_model.ckpt"
with tf.Session(graph=train_graph) as sess:
sess.run(tf.global_variables_initializer())
# train the model
if update_loss <= min(summary_update_loss):
print('New Record!')
stop_early = 0
saver = tf.train.Saver()
saver.save(sess, checkpoint)
完整的代码可以在上面的链接找到。
当我停止模型后重新运行这段代码时,这次只是把训练模型的部分注释掉,结果却完全不行。
为了确认权重是否正确加载,我尝试恢复模型并从中断的地方继续训练。但是损失值非常糟糕,几乎和刚开始训练时的损失值一样。这让我得出结论,权重没有正确加载到模型中。
然后我尝试使用tf.saved_model.builder.SavedModelBuilder
来保存模型,并试图从中断的地方重新训练,但问题依旧。它又给出了像是从头开始训练时的损失值。
相关问题:
- 暂无相关问题
2 个回答
0
每次保存一个tensorflow模型的状态时,会生成五个文件:
- checkpoint(检查点)
- events.out.tfevents(事件输出文件)
- model-xxxx.data.00000-0f-00001(模型数据文件)
- model-xxxx.index(模型索引文件),以及
- model-xxxx.meta(模型元数据文件)
假设我的模型保存在一个叫“saved_models”的文件夹里,以上所有文件都直接放在这个文件夹下。你可以恢复你的模型,继续之前的训练,或者也可以对保存的模型进行测试,看看它在测试数据上的表现。
import tensorflow as tf
checkpoint_dir = "saved_models"
with tf.compat.v1.Session() as sess:
saver = tf.compat.v1.train.Saver()
ckpt = tf.compat.v1.train_get_checkpoint_state(checkpoint_dir)
saver.restore(sess, ckpt.model_checkpoint_path)
saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
print("Model checkpoint has been successfully restored.")
# resume training or evaluate
0
你现在只加载了一个元文件,这个文件里并不包含变量的值。你可以使用下面的代码。
saver_path = 'path to your checkpoint'
checkpoint = tf.train.get_checkpoint_state(saver_path)
input_checkpoint = checkpoint.model_checkpoint_path
saver.restore(session, input_checkpoint)