如何结合密集层使用TensorFlow数据集API

2024-04-18 00:09:34 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在试用datasets#using_high-level_apis" rel="nofollow noreferrer">TensorFlow documentation中显示的输入管道的数据集API,并使用几乎相同的代码:

tr_data = Dataset.from_tensor_slices((train_images, train_labels))
tr_data = tr_data.map(input_parser, NUM_CORES, output_buffer_size=2000)
tr_data = tr_data.batch(BATCH_SIZE)
tr_data = tr_data.repeat(EPOCHS)

iterator = dataset.make_one_shot_iterator()
next_example, next_label = iterator.get_next()

# Script throws error here
loss = model_function(next_example, next_label)

with tf.Session(...) as sess:
    sess.run(tf.global_variables_initializer())

     while True:
        try:
            train_loss = sess.run(loss)
        except tf.errors.OutOfRangeError:
            print("End of training dataset.")
            break

这应该更快,因为它避免了使用慢的feed-dicts。但是我不能用我的模型,它是一个简化的LeNet架构。问题是我的model_function()中的tf.layers.dense,它需要一个已知的输入形状(我猜是因为它必须事先知道权重的数目)。但是next_example和{}只能通过在会话中运行它们来获得它们的形状。在计算它们之前,它们的形状是未定义的?

声明model_function()会引发以下错误:

ValueError: The last dimension of the inputs to Dense should be defined. Found None.

现在,我不知道我是按预期的方式使用这个数据集API,还是有解决办法。在

提前谢谢!在

编辑1: 下面是我的模型,它在第一个密集层抛出错误

^{pr2}$

编辑2:

这里你看到张量的指纹。请注意,下一个例子没有形状

next_example: Tensor("IteratorGetNext:0", dtype=float32)
next_label: Tensor("IteratorGetNext:1", shape=(?, 4), dtype=float32)


Tags: 数据apidatamodelexampletffunctiontrain
1条回答
网友
1楼 · 发布于 2024-04-18 00:09:34

我自己找到了答案。在

接下来的thread最简单的解决方法是,如果事先知道图像大小,只需使用tf.Tensor.set_shape设置形状。在

def input_parser(img_path, label):

    # read the img from file
    img_file = tf.read_file(img_path)
    img_decoded = tf.image.decode_image(img_file, channels=1)
    img_decoded = tf.image.convert_image_dtype(img_decoded, dtype=tf.float32)
    img_decoded.set_shape([90,160,1]) # This line was missing

    return img_decoded, label

如果tensorflow文档包含这行代码,那就太好了。在

相关问题 更多 >

    热门问题