TensorFlow在读写TFRecords文件时设置图像的形状?

2021-12-08 06:20:45 发布

您现在位置:Python中文网/ 问答频道 /正文

我在尝试使用TFRecords格式时设置图像数据的形状时遇到问题。我已经讨论了how-to for reading data,并从MNIST示例中获取了converting the image data to a TFRecords和{a3}的代码。然而,这个示例代码最初期望图像以一种格式使用,其中所有像素数据都在一个长向量中。在

我一直在尝试修改这段代码,以使用仍然保持原始图像形状的NumPy数组。所以在我下面的代码中,images是一个形状为[number_of_images, height, width, channels]的NumPy数组。我不确定我的问题是如何将数据写入TFRecords还是如何将其读回。但是,当我试图设置解码图像的形状时,我得到了错误ValueError: Shapes (?,) and (464, 624, 3) must have the same rank(注意:464x624x3是图像尺寸)。关于我可能做错了什么有什么建议吗?在

相关代码(与示例代码略有不同)

def convert_to_tfrecord(images, labels, name, data_directory):
    number_of_examples = labels.shape[0]
    rows = images.shape[1]  # images is the 4D ndarray with the images in their original shape.
    cols = images.shape[2]
    depth = images.shape[3]
    ...
    for index in range(number_of_examples):
        image_raw = images[index].tostring()
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(rows),
            'width': _int64_feature(cols),
            'channels': _int64_feature(depth),
            'image': _bytes_feature(image_raw),
            ...
        }))
        writer.write(example.SerializeToString())

...

def read_and_decode(filename_queue):
    ...
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            ...
        })
    ...
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image.set_shape([464, 624, 3])  # This is where the error occurs.
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
    ...