tensorflow:模型未能恢复权重

1 投票
2 回答
941 浏览
提问于 2025-05-18 21:18

我正在尝试做文本摘要的实验。我发现了一个GitHub上的代码,它使用了亚马逊食品评论的数据集Tensorflow 1.1.0来编写代码。

我通过在推理代码上加一个循环来运行这个代码,这样我就可以检查多个摘要,效果很好。

checkpoint = "./best_model.ckpt"

loaded_graph = tf.Graph()
with tf.Session(graph=loaded_graph) as sess:
    # Load saved model
    loader = tf.train.import_meta_graph(checkpoint + '.meta')
    loader.restore(sess, checkpoint)

    input_data = loaded_graph.get_tensor_by_name('input:0')
    logits = loaded_graph.get_tensor_by_name('predictions:0')
    text_length = loaded_graph.get_tensor_by_name('text_length:0')
    summary_length = loaded_graph.get_tensor_by_name('summary_length:0')
    keep_prob = loaded_graph.get_tensor_by_name('keep_prob:0')

    while True:
        input_sentence = input()
        text = text_to_seq(input_sentence)

        #Multiply by batch_size to match the model's input parameters
        answer_logits = sess.run(logits, {input_data: [text]*batch_size, 
                                      summary_length: [np.random.randint(5,8)], 
                                      text_length: [len(text)]*batch_size,
                                      keep_prob: 1.0})[0] 

        # Remove the padding from the tweet
        pad = vocab_to_int["<PAD>"] 

        print('Original Text:', input_sentence)

        print('\nText')
        print('  Word Ids:    {}'.format([i for i in text]))
        print('  Input Words: {}'.format(" ".join([int_to_vocab[i] for i in text])))

        print('\nSummary')
        print('  Word Ids:       {}'.format([i for i in answer_logits if i != pad]))
        print('  Response Words: {}'.format(" ".join([int_to_vocab[i] for i in answer_logits if i != pad])))

这里有一段保存模型的代码:

checkpoint = "best_model.ckpt" 
with tf.Session(graph=train_graph) as sess:
    sess.run(tf.global_variables_initializer())
    # train the model
    if update_loss <= min(summary_update_loss):
        print('New Record!') 
        stop_early = 0
        saver = tf.train.Saver() 
        saver.save(sess, checkpoint)

完整的代码可以在上面的链接找到。

当我停止模型后重新运行这段代码时,这次只是把训练模型的部分注释掉,结果却完全不行。

为了确认权重是否正确加载,我尝试恢复模型并从中断的地方继续训练。但是损失值非常糟糕,几乎和刚开始训练时的损失值一样。这让我得出结论,权重没有正确加载到模型中。

然后我尝试使用tf.saved_model.builder.SavedModelBuilder来保存模型,并试图从中断的地方重新训练,但问题依旧。它又给出了像是从头开始训练时的损失值。

相关问题:

  • 暂无相关问题
暂无标签

2 个回答

0

每次保存一个tensorflow模型的状态时,会生成五个文件:

  1. checkpoint(检查点)
  2. events.out.tfevents(事件输出文件)
  3. model-xxxx.data.00000-0f-00001(模型数据文件)
  4. model-xxxx.index(模型索引文件),以及
  5. model-xxxx.meta(模型元数据文件)

假设我的模型保存在一个叫“saved_models”的文件夹里,以上所有文件都直接放在这个文件夹下。你可以恢复你的模型,继续之前的训练,或者也可以对保存的模型进行测试,看看它在测试数据上的表现。

import tensorflow as tf

checkpoint_dir = "saved_models"
with tf.compat.v1.Session() as sess:
    saver = tf.compat.v1.train.Saver()
    ckpt = tf.compat.v1.train_get_checkpoint_state(checkpoint_dir)
    saver.restore(sess, ckpt.model_checkpoint_path)
    saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
    print("Model checkpoint has been successfully restored.")
    # resume training or evaluate
0

你现在只加载了一个元文件,这个文件里并不包含变量的值。你可以使用下面的代码。

saver_path = 'path to your checkpoint'
checkpoint = tf.train.get_checkpoint_state(saver_path)
input_checkpoint = checkpoint.model_checkpoint_path
saver.restore(session, input_checkpoint)

撰写回答