关于从tfrecords读取数据时的形状

2024-06-01 00:25:22 发布

您现在位置:Python中文网/ 问答频道 /正文

我要读tfrecords的'image'(2000)和'landmarks'(388)。在

这是代码的一部分。在

filename_queue = tf.train.string_input_producer([savepath])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example, features={'label': tf.FixedLenFeature([], tf.string), 'img_raw':tf.FixedLenFeature([], tf.string), })

image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [224, 224, 3])
image = tf.cast(image, tf.float32)

label = tf.decode_raw(features['label'], tf.float64) # problem is here
label = tf.cast(label, tf.float32)
label = tf.reshape(label, [388])

错误是

^{pr2}$

当我将“float64”更改为“float32”时:

 label = tf.decode_raw(features['label'], tf.float32) # problem is here

 #Error: InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 582 values, but the requested shape has 388

或“浮动16”:

label = tf.decode_raw(features['label'], tf.float16) # problem is here

#Error: InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 1164 values, but the requested shape has 388

下面是我如何制作tfrecords:(为了简单起见,我简化了一些代码)

writer = tf.python_io.TFRecordWriter(savepath)
for i in range(number_of_images):
    img = Image.open(ImagePath[i])  # load one image from path
    landmark = landmark_read_from_csv[i]  # shape of landmark_read_from_csv is (number_of_images, 388)
    example = tf.train.Example(features=tf.train.Features(feature={
    "label": tf.train.Feature(bytes_list=tf.train.BytesList(value=[landmark.tobytes()])),
    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()]))}))
    writer.write(example.SerializeToString())
writer.close()

我有三个问题:

  1. 为什么数据类型更改后形状会改变?在
  2. 如何选择合适的数据类型?(因为有时我可以用'tf.float64型'但有时'tf.uint8型'使用不同的数据集)
  3. 创建tfrecords的代码有问题吗?在

Tags: 代码imageimgrawisexampletftfrecords
1条回答
网友
1楼 · 发布于 2024-06-01 00:25:22

我最近遇到了一个非常相似的问题,从我的个人经验来看,我很有信心能够推断出你所问问题的答案,尽管我不是百分之百确定。在

  1. 列表项形状改变是因为不同的数据类型在编码为byte列表时具有不同的长度,并且由于float16的长度是float32的一半,同一字节列表可以作为nfloat32值的序列或两倍数量的float16值来读取。换言之,当您更改数据类型时,您尝试解码的byte列表不会发生变化,但会改变的是您对该数组列表所做的分区。

  2. 您应该检查用于生成tfrecord文件的数据类型,并在读取字节列表时使用相同的数据类型对字节列表进行解码(可以使用.dtype属性检查numpy数组的数据类型)。

  3. 我看不到但我可能错了。

相关问题 更多 >