我有一个使用迭代器训练我的网络的模型;遵循Google现在推荐的新的数据集API管道模型。在
我读取tfrecord文件,向网络提供数据,进行良好的训练,一切进展顺利,我在训练结束时保存了模型,以便以后可以对其进行推理。该规范的简化版本如下:
""" Training and saving """
training_dataset = tf.contrib.data.TFRecordDataset(training_record)
training_dataset = training_dataset.map(ds._path_records_parser)
training_dataset = training_dataset.batch(BATCH_SIZE)
with tf.name_scope("iterators"):
training_iterator = Iterator.from_structure(training_dataset.output_types, training_dataset.output_shapes)
next_training_element = training_iterator.get_next()
training_init_op = training_iterator.make_initializer(training_dataset)
def train(num_epochs):
# compute for the number of epochs
for e in range(1, num_epochs+1):
session.run(training_init_op) #initializing iterator here
while True:
try:
images, labels = session.run(next_training_element)
session.run(optimizer, feed_dict={x: images, y_true: labels})
except tf.errors.OutOfRangeError:
saver_name = './saved_models/ucf-model'
print("Finished Training Epoch {}".format(e))
break
""" Restoring """
# restoring the saved model and its variables
session = tf.Session()
saver = tf.train.import_meta_graph(r'saved_models\ucf-model.meta')
saver.restore(session, tf.train.latest_checkpoint('.\saved_models'))
graph = tf.get_default_graph()
# restoring relevant tensors/ops
accuracy = graph.get_tensor_by_name("accuracy/Mean:0") #the tensor that when evaluated returns the mean accuracy of the batch
testing_iterator = graph.get_operation_by_name("iterators/Iterator") #my iterator used in testing.
next_testing_element = graph.get_operation_by_name("iterators/IteratorGetNext") #the GetNext operator for my iterator
# loading my testing set tfrecords
testing_dataset = tf.contrib.data.TFRecordDataset(testing_record_path)
testing_dataset = testing_dataset.map(ds._path_records_parser, num_threads=4, output_buffer_size=BATCH_SIZE*20)
testing_dataset = testing_dataset.batch(BATCH_SIZE)
testing_init_op = testing_iterator.make_initializer(testing_dataset) #to initialize the dataset
with tf.Session() as session:
session.run(testing_init_op)
while True:
try:
images, labels = session.run(next_testing_element)
accuracy = session.run(accuracy, feed_dict={x: test_images, y_true: test_labels}) #error here, x, y_true not defined
except tf.errors.OutOfRangeError:
break
我的问题主要是当我恢复模型时。如何向网络提供测试数据?在
testing_iterator = graph.get_operation_by_name("iterators/Iterator")
,next_testing_element = graph.get_operation_by_name("iterators/IteratorGetNext")
还原迭代器时,我得到以下错误:
GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element.
testing_init_op = testing_iterator.make_initializer(testing_dataset))
初始化我的数据集。我得到了这个错误:AttributeError: 'Operation' object has no attribute 'make_initializer'
另一个问题是,由于使用了迭代器,所以不需要在训练模型中使用占位符,因为迭代器直接向图形提供数据。但是这样,当我把数据输入到“精度”操作时,如何恢复第3行到最后一行的feed_dict键?在
编辑:如果有人建议在迭代器和网络输入之间添加占位符,那么我可以试着运行图形,方法是计算“精度”张量,同时向占位符提供数据,完全忽略迭代器。在
我无法解决与初始化迭代器相关的问题,但由于我使用map方法预处理数据集,并且应用了用py_func包装的Python操作定义的转换,这些转换无法序列化以存储\还原,所以在我想恢复数据集时,我必须初始化数据集。在
所以,剩下的问题是,当我恢复图形时,如何将数据馈送到图形中。我放了一个tf.身份迭代器输出和网络输入之间的节点。恢复后,我将数据输入到identity节点。我后来发现的一个更好的解决方案是使用
placeholder_with_default()
,如this answer所述。在我建议使用^{} ,它正是为此目的而设计的。它的详细程度要低得多,并且不需要更改现有代码,特别是如何定义迭代器。在
例如,当我们在步骤5完成后保存所有内容时。请注意,我甚至都懒得知道使用了什么种子。在
然后,如果我们从第6步继续,我们得到相同的输出。在
^{pr2}$还原保存的元图时,可以使用名称还原初始化操作,然后再次使用它初始化输入管道进行推断。在
也就是说,在创建图形时,您可以
然后通过执行以下操作来恢复此操作:
^{pr2}$下面是一个自包含的代码片段,用于比较恢复前后随机初始化模型的结果。在
保存迭代器
然后可以恢复相同的模型进行推理,如下所示:
恢复保存的迭代器
相关问题 更多 >
编程相关推荐