使用tf.data API替换tf.placeholder和feed_dict

2024-04-18 01:00:13 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个现有的TensorFlow模型,它使用tf.placeholder作为模型输入,使用tf.Session()的feed_dict参数运行以传入数据。以前,整个数据集被读入内存并以这种方式传递。

我想使用更大的数据集,并利用tf.data API的性能改进。我已经定义了一个tf.data.TextLinedataset和它的一次迭代,但是我很难弄清楚如何将数据输入到模型中来训练它。

起初,我试图将feed_dict定义为从占位符到iterator.get_next()的字典,但这给了我一个错误,即feed的值不能是tf.Tensor对象。更深入的研究使我明白,这是因为iterator.get_next()返回的对象已经是图形的一部分,这与您将要馈送到feed_dict中的对象不同——而且出于性能原因,我根本不应该尝试使用feed_dict。

所以现在我去掉了输入tf.placeholder,并用一个参数替换它到定义我的模型的类的构造函数中;在我的训练代码中构造模型时,我将iterator.get_next()的输出传递给该参数。这看起来已经有点笨拙了,因为它打破了模型定义和数据集/训练过程之间的分离。我现在得到一个错误,表示(我相信)我的模型输入的张量必须来自迭代器中的张量。

我用这种方法是否走在正确的轨道上,只是在如何设置图表和会话方面做了一些错误的事情,或者类似的事情?(数据集和模型都是在会话外部初始化的,错误发生在尝试创建会话之前。)

或者我完全不了解这一点,需要做一些不同的事情,比如使用估计器API并在输入函数中定义所有内容?

下面是一些演示最小示例的代码:

import tensorflow as tf
import numpy as np

class Network:
    def __init__(self, x_in, input_size):
        self.input_size = input_size
        # self.x_in = tf.placeholder(dtype=tf.float32, shape=(None, self.input_size))  # Original
        self.x_in = x_in
        self.output_size = 3

        tf.reset_default_graph()  # This turned out to be the problem

        self.layer = tf.layers.dense(self.x_in, self.output_size, activation=tf.nn.relu)
        self.loss = tf.reduce_sum(tf.square(self.layer - tf.constant(0, dtype=tf.float32, shape=[self.output_size])))

data_array = np.random.standard_normal([4, 10]).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(data_array).batch(2)

model = Network(x_in=dataset.make_one_shot_iterator().get_next(), input_size=dataset.output_shapes[-1])

Tags: 数据in模型selfinputdatasizeget

热门问题