TensorFlow graph error: list index out of range

Posted 2024-04-24 12:26:05

I am trying to implement a TensorFlow example that trains on images and saves the model so it can be used with OpenCV. The training part and saving the .h5 file work fine, and I even tested it on some test images. But when I try to freeze the graph, I get a "list index out of range" error. I am new to TensorFlow, please help me.

Here is the code:

import os
import tensorflow as tf
import keras_preprocessing
from keras_preprocessing import image
from keras_preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from tensorflow.python.tools import freeze_graph

rock_dir = os.path.join('tmp/rps/rock')
paper_dir = os.path.join('tmp/rps/paper')
scissors_dir = os.path.join('tmp/rps/scissors')

print('total training rock images:', len(os.listdir(rock_dir)))
print('total training paper images:', len(os.listdir(paper_dir)))
print('total training scissors images:', len(os.listdir(scissors_dir)))

rock_files = os.listdir(rock_dir)
print(rock_files[:10])

paper_files = os.listdir(paper_dir)
print(paper_files[:10])

scissors_files = os.listdir(scissors_dir)
print(scissors_files[:10])

TRAINING_DIR = "tmp/rps/"
training_datagen = ImageDataGenerator(
      rescale = 1./255,
      rotation_range=40,
      width_shift_range=0.2,
      height_shift_range=0.2,
      shear_range=0.2,
      zoom_range=0.2,
      horizontal_flip=True,
      fill_mode='nearest')

VALIDATION_DIR = "tmp/rps-test-set/"
validation_datagen = ImageDataGenerator(rescale = 1./255)

train_generator = training_datagen.flow_from_directory(TRAINING_DIR,target_size=(150,150),class_mode='categorical')

validation_generator = validation_datagen.flow_from_directory(VALIDATION_DIR,target_size=(150,150),class_mode='categorical')
tf.compat.v1.reset_default_graph()
with tf.Graph().as_default():
     with tf.compat.v1.Session() as sess:
        model = tf.keras.models.Sequential([
                # Note the input shape is the desired size of the image 150x150 with 3 bytes color
                # This is the first convolution
                tf.keras.layers.Conv2D(64, (3,3), activation='relu', input_shape=(150, 150, 3)),
                tf.keras.layers.MaxPooling2D(2,2),
                # The second convolution
                tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
                tf.keras.layers.MaxPooling2D(2,2),
                # The third convolution
                tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
                tf.keras.layers.MaxPooling2D(2,2),
                # The fourth convolution
                tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
                tf.keras.layers.MaxPooling2D(2,2),
                # Flatten the results to feed into a DNN
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dropout(0.5),
                # 512 neuron hidden layer
                tf.keras.layers.Dense(512, activation='relu'),
                tf.keras.layers.Dense(3, activation='softmax')
                ])

        model.summary()

        model.compile(loss = 'categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

        history = model.fit_generator(
            train_generator,
            epochs=1,
            validation_data = validation_generator,
            verbose = 1)

        model.save("rps.h5")



        init = tf.compat.v1.global_variables_initializer()
        print(init)
        sess.run(init)
        print(sess.run(init))
        saver = tf.compat.v1.train.Saver()
        saver.save(sess, 'MODLE/tensorflowModel.ckpt')
        tf.io.write_graph(sess.graph.as_graph_def(), '.', 'MODLE/tensorflowModel.pbtxt', as_text=True)
         # tf.io.write_graph(sess.graph.as_graph_def(), '.', 'MODLE/tensorflowModel.pb', as_text=False)

        freeze_graph.freeze_graph(input_graph='MODLE/tensorflowModel.pbtxt', input_saver='',
                                   input_binary=False, input_checkpoint='MODLE/tensorflowModel.ckpt',
                                   output_node_names='softmax', restore_op_name='save/restore_all',
                                   filename_tensor_name='save/Const:0', output_graph='MODLE/tensorflowModel.pb',
                                   clear_devices=True, initializer_nodes='')
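
For reference, this is a minimal sketch of the TF 2.x way to freeze the same Keras model with convert_variables_to_constants_v2 instead of freeze_graph (it assumes the rps.h5 file saved above exists); I have not verified it on this exact setup:

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

# Load the Keras model saved above (assumes rps.h5 exists).
model = tf.keras.models.load_model("rps.h5")

# Wrap the model call in a tf.function and get a concrete function
# with the model's input signature.
full_model = tf.function(lambda x: model(x))
concrete_func = full_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Replace the variables with constants and write the frozen GraphDef.
frozen_func = convert_variables_to_constants_v2(concrete_func)
tf.io.write_graph(frozen_func.graph.as_graph_def(),
                  logdir='MODLE', name='frozen_graph.pb', as_text=False)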

This is the error:

Use standard file APIs to check for files with this prefix.
Traceback (most recent call last):
  File "C:/Users/Sufyan/Desktop/PYTHON/TENSORFLOW/rock paper.py", line 95, in <module>
    clear_devices=True, initializer_nodes='')
  File "C:\Users\Sufyan\Desktop\PYTHON\TENSORFLOW\venv\lib\site-packages\tensorflow\python\tools\freeze_graph.py", line 363, in freeze_graph
    checkpoint_version=checkpoint_version)
  File "C:\Users\Sufyan\Desktop\PYTHON\TENSORFLOW\venv\lib\site-packages\tensorflow\python\tools\freeze_graph.py", line 190, in freeze_graph_with_def_protos
    var_list=var_list, write_version=checkpoint_version)
  File "C:\Users\Sufyan\Desktop\PYTHON\TENSORFLOW\venv\lib\site-packages\tensorflow\python\training\saver.py", line 831, in __init__
    self.build()
  File "C:\Users\Sufyan\Desktop\PYTHON\TENSORFLOW\venv\lib\site-packages\tensorflow\python\training\saver.py", line 843, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "C:\Users\Sufyan\Desktop\PYTHON\TENSORFLOW\venv\lib\site-packages\tensorflow\python\training\saver.py", line 880, in _build
    build_save=build_save, build_restore=build_restore)
  File "C:\Users\Sufyan\Desktop\PYTHON\TENSORFLOW\venv\lib\site-packages\tensorflow\python\training\saver.py", line 486, in _build_internal
    names_to_saveables)
  File "C:\Users\Sufyan\Desktop\PYTHON\TENSORFLOW\venv\lib\site-packages\tensorflow\python\training\saving\saveable_object_util.py", line 341, in validate_and_slice_inputs
    for converted_saveable_object in saveable_objects_for_op(op, name):
  File "C:\Users\Sufyan\Desktop\PYTHON\TENSORFLOW\venv\lib\site-packages\tensorflow\python\training\saving\saveable_object_util.py", line 207, in saveable_objects_for_op
    variable, "", name)
  File "C:\Users\Sufyan\Desktop\PYTHON\TENSORFLOW\venv\lib\site-packages\tensorflow\python\training\saving\saveable_object_util.py", line 83, in __init__
    self.handle_op = var.op.inputs[0]
  File "C:\Users\Sufyan\Desktop\PYTHON\TENSORFLOW\venv\lib\site-packages\tensorflow\python\framework\ops.py", line 2355, in __getitem__
    return self._inputs[i]
IndexError: list index out of range
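
Since the traceback shows the IndexError coming from the Saver that freeze_graph builds internally, it may also help to see which node names actually exist in the graph written out above (so output_node_names='softmax' can be checked against real names). A minimal sketch for listing them, assuming the MODLE/tensorflowModel.pbtxt file written above:

import tensorflow as tf
from google.protobuf import text_format

# Parse the graph written above (assumes MODLE/tensorflowModel.pbtxt exists).
graph_def = tf.compat.v1.GraphDef()
with open('MODLE/tensorflowModel.pbtxt') as f:
    text_format.Merge(f.read(), graph_def)

# Print every node name and op type to see which output node names are available.
for node in graph_def.node:
    print(node.name, node.op)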

I hope to get an answer here.


Tags: in, py, build, layers, tf, tensorflow, line, training