我正在尝试实现一个tensorflow示例来训练图像并保存模型以便在opencv上使用。 火车的一部分和保存.h5文件是完美的,甚至我测试它的一些测试图像。 当我试图冻结图形时,我得到了一个列表索引超出范围的错误。 我是tensorflow的新手,请帮帮我
代码如下:
import os
import tensorflow as tf
import keras_preprocessing
from keras_preprocessing import image
from keras_preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from tensorflow.python.tools import freeze_graph
rock_dir = os.path.join('tmp/rps/rock')
paper_dir = os.path.join('tmp/rps/paper')
scissors_dir = os.path.join('tmp/rps/scissors')
print('total training rock images:', len(os.listdir(rock_dir)))
print('total training paper images:', len(os.listdir(paper_dir)))
print('total training scissors images:', len(os.listdir(scissors_dir)))
rock_files = os.listdir(rock_dir)
print(rock_files[:10])
paper_files = os.listdir(paper_dir)
print(paper_files[:10])
scissors_files = os.listdir(scissors_dir)
print(scissors_files[:10])
TRAINING_DIR = "tmp/rps/"
training_datagen = ImageDataGenerator(
rescale = 1./255,
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest')
VALIDATION_DIR = "tmp/rps-test-set/"
validation_datagen = ImageDataGenerator(rescale = 1./255)
train_generator = training_datagen.flow_from_directory(TRAINING_DIR,target_size=(150,150),class_mode='categorical')
validation_generator = validation_datagen.flow_from_directory(VALIDATION_DIR,target_size=(150,150),class_mode='categorical')
tf.compat.v1.reset_default_graph()
with tf.Graph().as_default():
with tf.compat.v1.Session() as sess:
model = tf.keras.models.Sequential([
# Note the input shape is the desired size of the image 150x150 with 3 bytes color
# This is the first convolution
tf.keras.layers.Conv2D(64, (3,3), activation='relu', input_shape=(150, 150, 3)),
tf.keras.layers.MaxPooling2D(2, 2),
# The second convolution
tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2,2),
# The third convolution
tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2,2),
# The fourth convolution
tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2,2),
# Flatten the results to feed into a DNN
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.5),
# 512 neuron hidden layer
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dense(3, activation='softmax')
])
model.summary()
model.compile(loss = 'categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
history = model.fit_generator(
train_generator,
epochs=1,
validation_data = validation_generator,
verbose = 1)
model.save("rps.h5")
init = tf.compat.v1.global_variables_initializer()
print(init)
sess.run(init)
print(sess.run(init))
saver = tf.compat.v1.train.Saver()
saver.save(sess, 'MODLE/tensorflowModel.ckpt')
tf.io.write_graph(sess.graph.as_graph_def(), '.', 'MODLE/tensorflowModel.pbtxt', as_text=True)
# tf.io.write_graph(sess.graph.as_graph_def(), '.', 'MODLE/tensorflowModel.pb', as_text=False)
freeze_graph.freeze_graph(input_graph='MODLE/tensorflowModel.pbtxt', input_saver='',
input_binary=False, input_checkpoint='MODLE/tensorflowModel.ckpt',
output_node_names='softmax', restore_op_name='save/restore_all',
filename_tensor_name='save/Const:0', output_graph='MODLE/tensorflowModel.pb',
clear_devices=True, initializer_nodes='')
这就是错误
Use standard file APIs to check for files with this prefix.
Traceback (most recent call last):
File "C:/Users/Sufyan/Desktop/PYTHON/TENSORFLOW/rock paper.py", line 95, in <module>
clear_devices=True, initializer_nodes='')
File "C:\Users\Sufyan\Desktop\PYTHON\TENSORFLOW\venv\lib\site-packages\tensorflow\python\tools\freeze_graph.py", line 363, in freeze_graph
checkpoint_version=checkpoint_version)
File "C:\Users\Sufyan\Desktop\PYTHON\TENSORFLOW\venv\lib\site-packages\tensorflow\python\tools\freeze_graph.py", line 190, in freeze_graph_with_def_protos
var_list=var_list, write_version=checkpoint_version)
File "C:\Users\Sufyan\Desktop\PYTHON\TENSORFLOW\venv\lib\site-packages\tensorflow\python\training\saver.py", line 831, in __init__
self.build()
File "C:\Users\Sufyan\Desktop\PYTHON\TENSORFLOW\venv\lib\site-packages\tensorflow\python\training\saver.py", line 843, in build
self._build(self._filename, build_save=True, build_restore=True)
File "C:\Users\Sufyan\Desktop\PYTHON\TENSORFLOW\venv\lib\site-packages\tensorflow\python\training\saver.py", line 880, in _build
build_save=build_save, build_restore=build_restore)
File "C:\Users\Sufyan\Desktop\PYTHON\TENSORFLOW\venv\lib\site-packages\tensorflow\python\training\saver.py", line 486, in _build_internal
names_to_saveables)
File "C:\Users\Sufyan\Desktop\PYTHON\TENSORFLOW\venv\lib\site-packages\tensorflow\python\training\saving\saveable_object_util.py", line 341, in validate_and_slice_inputs
for converted_saveable_object in saveable_objects_for_op(op, name):
File "C:\Users\Sufyan\Desktop\PYTHON\TENSORFLOW\venv\lib\site-packages\tensorflow\python\training\saving\saveable_object_util.py", line 207, in saveable_objects_for_op
variable, "", name)
File "C:\Users\Sufyan\Desktop\PYTHON\TENSORFLOW\venv\lib\site-packages\tensorflow\python\training\saving\saveable_object_util.py", line 83, in __init__
self.handle_op = var.op.inputs[0]
File "C:\Users\Sufyan\Desktop\PYTHON\TENSORFLOW\venv\lib\site-packages\tensorflow\python\framework\ops.py", line 2355, in __getitem__
return self._inputs[i]
IndexError: list index out of range
我希望在这里得到答案
目前没有回答
相关问题 更多 >
编程相关推荐