Checking a .tfrecords file

Posted 2024-04-16 04:44:46


I created an images.tfrecords file using the following code:

from PIL import Image
import numpy as np
import tensorflow as tf
import glob

images = glob.glob(r'E:\Projects/FYPT/vehicle/bus/*.jpg')

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

tfrecords_filename = r'E:\Projects/FYPT/vehicle/images.tfrecords'

writer = tf.python_io.TFRecordWriter(tfrecords_filename)

original_images = []

for img_path in images:
    img = np.array(Image.open(img_path))

    height = img.shape[0]
    width = img.shape[1]

    # Keep the original images in a list
    # so the round trip can be checked later
    original_images.append(img)

    # Serialize the raw pixel buffer (tostring() is a deprecated alias of tobytes())
    img_raw = img.tobytes()

    example = tf.train.Example(features=tf.train.Features(feature={
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'image_raw': _bytes_feature(img_raw)
    }))

    writer.write(example.SerializeToString())

writer.close()

Then I tried to inspect the file with tf.TFRecordReader() by printing the output of serialized_example:

[code block not preserved in the source]

However, it gives a warning and does not print serialized_example (the original post attached a screenshot of the command line).

What am I doing wrong? How should I print the output of serialized_example?


1 answer

#1 · Posted 2024-04-16 04:44:46

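You get that warning because you are using tf.train.string_input_producer(), which returns a queue, and input pipelines based on the QueueRunner API are deprecated and will not be supported in future versions.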

Queue-based solution - not recommended!

serialized_example is just a string (each example was written to the images.tfrecords file by tf.python_io.TFRecordWriter as a serialized string).

You need to parse each example to get its features. In your case:

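features = tf.parse_single_example(serialized_example,
                                   features={"image_raw": tf.FixedLenFeature([], tf.string),
                                             "height": tf.FixedLenFeature([], tf.int64)})

# image_raw holds raw pixel bytes, not JPEG data, so use decode_raw
# rather than tf.image.decode_jpeg; the result is a flat uint8 vector
img_raw = tf.decode_raw(features["image_raw"], tf.uint8)
img_height = features["height"]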

# initialize global and local variables
init_op = tf.group(tf.local_variables_initializer(),
                   tf.global_variables_initializer())

with tf.Session() as sess:
  sess.run(init_op)

  # start the queue runner threads
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)

  try:
    while not coord.should_stop():
      img_raw_value, img_height_value = sess.run([img_raw, img_height])
      print(img_raw_value.shape)
      print(img_height_value)
  except tf.errors.OutOfRangeError:
    print("End of data")
  finally:
    coord.request_stop()

  # wait for all threads to terminate
  coord.join(threads)

Dataset API - highly recommended!

A detailed explanation of how to build an input pipeline can be found in the TensorFlow API documentation.

In your case, you should define a _parse_function like this:

[code block not preserved in the source]

Then create a dataset that reads all the examples from the TFRecord file and extracts the features:

dataset = tf.data.TFRecordDataset([tfrecords_filename])
dataset = dataset.map(_parse_function)
# here you could batch and shuffle

iterator = dataset.make_one_shot_iterator()

next_element = iterator.get_next()

with tf.Session() as sess: 
  while True:
    try:
      val = sess.run(next_element)
      print("img_raw:", val[0].shape)
      print("height:", val[1])
      print("width:", val[2])
    except tf.errors.OutOfRangeError:
      print("End of dataset")
      break 
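For a quick check that needs no graph or session, one option is to walk the file's record framing directly. The sketch below is not part of the original answer: it assumes the documented TFRecord layout (an 8-byte little-endian length, a 4-byte CRC of the length, the payload, then a 4-byte CRC of the payload), skips CRC verification entirely, and uses a hypothetical helper name, iter_tfrecord_payloads. It only yields the serialized bytes of each example; turning them into features still requires parsing, for instance with tf.train.Example.FromString.

```python
import struct


def iter_tfrecord_payloads(path):
    """Yield the raw payload bytes of each record in a TFRecord file.

    Assumption: CRC fields are read but NOT verified, so a corrupted file
    is only detected if it is truncated.
    """
    with open(path, "rb") as f:
        while True:
            header = f.read(12)  # 8-byte length + 4-byte CRC of the length
            if not header:
                return  # clean end of file
            if len(header) < 12:
                raise ValueError("truncated record header")
            (length,) = struct.unpack("<Q", header[:8])
            payload = f.read(length)
            footer = f.read(4)  # 4-byte CRC of the payload (not verified)
            if len(payload) < length or len(footer) < 4:
                raise ValueError("truncated record")
            yield payload
```

Iterating over iter_tfrecord_payloads(tfrecords_filename) and printing len(payload) for each record gives the number of examples and their serialized sizes, which is often enough to confirm that the writer produced what was expected.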

I hope this helps.
