How do I convert basic feed_dict-based TensorFlow code to use "Dataset"?

I understand that there are advantages (especially as I expand the scope of the models I build and the size of the datasets they process) to using TensorFlow's new tf.data.Dataset as the idiomatic way to feed my data pipeline. However, I'm having trouble mapping my existing feed_dict-based code onto this new model.

One problem I'm facing is that I can't untangle how batching and epochs interact, or how they interleave with the logging and validation I do regularly.

For example, how would something like the following map onto using Dataset?

# Load and process data into tensors of dimension (N, C_i) for input and (N, C_o) for output
# where N is the number of examples and C_ is the number of channels, and the values are activations
train_x, train_y, valid_x, valid_y = load_data(file, [segments], ...)
train_size = len(train_x)

train_stats_feed = {input_activation: train_x, correct_output: train_y, is_train: False}
valid_stats_feed = {input_activation: valid_x, correct_output: valid_y, is_train: False}

with tf.Session(config=tf.ConfigProto(...)) as sess:
    sess.run(tf.global_variables_initializer())

    # Some analysis; not always done but the code needs to support it
    train_writer.add_summary(sess.run(merged, feed_dict=train_stats_feed), 0)
    test_writer.add_summary(sess.run(merged, feed_dict=valid_stats_feed), 0)

    test_writer.add_summary(sess.run(gs_summary), 0)

    print(log_fmt.format(0, float(sess.run(accuracy, feed_dict=valid_stats_feed)),
                         float(sess.run(loss, feed_dict=valid_stats_feed))))

    for ep in range(epochs):
        # Slice the training data into random batches
        batch_indices = np.array_split(np.random.permutation(train_size), int(train_size/mb_size))

        for mini_batch_indices in batch_indices:
            sess.run(train_step, feed_dict={input_activation: train_x[mini_batch_indices],
                                            correct_output: train_y[mini_batch_indices], is_train: True})

            gs = int(sess.run(global_step))
            if gs % log_steps == 0:
                test_writer.add_summary(sess.run(merged, feed_dict=valid_stats_feed), gs)
                train_writer.add_summary(sess.run(merged, feed_dict=train_stats_feed), gs)

                acc = float(sess.run(accuracy, feed_dict=valid_stats_feed))
                sess.run(validation_accuracy.assign(acc))

                print(log_fmt.format(gs, acc, float(sess.run(loss, feed_dict=valid_stats_feed))))

        print(ep_fmt.format(ep + 2))
        test_writer.add_summary(sess.run(gs_summary), ep + 1)

In case they're needed, some of the less obvious definitions from the above:

# Preliminaries

# Some basic preliminaries, the details of which are not important to the question
# Mostly pretty standard; obvious things omitted from MWE for brevity
global_step = tf.Variable(0, trainable=False, name='global_step')
validation_accuracy = tf.Variable(0.0, trainable=False, name='validation_accuracy', dtype=tf.float32)

is_train = tf.placeholder(tf.bool, [], name='is_train')
input_activation = tf.placeholder(tf.float32, shape=[None, in_nodes], name='inputs')
correct_output = tf.placeholder(tf.float32, shape=[None, out_nodes], name='correct_outputs')

network_output = tf.identity(out_activations)
correct_predictions = correct_fn(correct_output, network_output)
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))
error = cost_fn(correct_output, network_output)
loss = error + FLAGS.regularization_weight * sum(tf.nn.l2_loss(w) for w in layer_weights)

train_step = tf.train.MomentumOptimizer(learning_rate, momentum=momentum).minimize(loss, global_step=global_step)

# Logging
train_writer = tf.summary.FileWriter(trainlogfile, tf.get_default_graph())
test_writer = tf.summary.FileWriter(testlogfile, tf.get_default_graph())
gs_summary = tf.summary.scalar('global_step_at_epoch', global_step)
merged = tf.summary.merge_all()

1 Answer

Here are a few lines to get training started. The same logic also applies to validation:

# Define placeholders for the input data and labels
inputs_placeholder = tf.placeholder(train_x.dtype, train_x.shape)
labels_placeholder = tf.placeholder(train_y.dtype, train_y.shape)
# Define a Dataset object using the above placeholders
dataset = tf.data.Dataset.from_tensor_slices((inputs_placeholder, labels_placeholder))
# Define the batch size
batch_size = 128
dataset = dataset.batch(batch_size)
# Define an initializable iterator
iterator = dataset.make_initializable_iterator()
# Get one batch
next_example, next_label = iterator.get_next()
# Calculate the loss from the model function you are using
loss = some_model(next_example, next_label)
# Set the number of epochs here
num_epochs = 100
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(num_epochs):
        # Re-initialize the iterator at the start of every epoch
        sess.run(iterator.initializer, feed_dict={inputs_placeholder: train_x,
                                                  labels_placeholder: train_y})
        while True:
            try:
                _loss = sess.run(loss)
            except tf.errors.OutOfRangeError:
                # Iterator exhausted: the epoch is over
                break
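
Extending that snippet to the epoch/validation interleaving the question asks about, here is a minimal sketch of one common TF 1.x pattern. It assumes the question's variables (train_x, train_y, valid_x, valid_y, train_size, mb_size, epochs, in_nodes, out_nodes) are in scope, and build_model is a hypothetical stand-in for the model-construction code from the preliminaries. Shuffling and batching move into the pipeline definition, and a reinitializable iterator switches the same graph between the training and validation data:

import numpy as np
import tensorflow as tf

# Placeholders shaped [None, ...] so that both the training and validation
# arrays (which have different lengths N) can be fed at initialization time
inputs_placeholder = tf.placeholder(tf.float32, [None, in_nodes])
labels_placeholder = tf.placeholder(tf.float32, [None, out_nodes])

# Shuffling and batching are part of the pipeline, replacing the manual
# np.random.permutation / np.array_split at the top of each epoch
train_dataset = (tf.data.Dataset
                 .from_tensor_slices((inputs_placeholder, labels_placeholder))
                 .shuffle(buffer_size=train_size)
                 .batch(mb_size))
# The validation pipeline is the same, minus the shuffle
valid_dataset = (tf.data.Dataset
                 .from_tensor_slices((inputs_placeholder, labels_placeholder))
                 .batch(mb_size))

# A reinitializable iterator can read from either pipeline
iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
                                           train_dataset.output_shapes)
next_x, next_y = iterator.get_next()
train_init = iterator.make_initializer(train_dataset)
valid_init = iterator.make_initializer(valid_dataset)

# Hypothetical: builds loss/accuracy/train_step on top of the iterator output
loss, accuracy, train_step = build_model(next_x, next_y)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for ep in range(epochs):
        # One epoch of training: initialize on the training arrays and run
        # until the iterator raises OutOfRangeError
        sess.run(train_init, feed_dict={inputs_placeholder: train_x,
                                        labels_placeholder: train_y})
        while True:
            try:
                sess.run(train_step)
            except tf.errors.OutOfRangeError:
                break

        # Validation pass over the same graph: reinitialize on the
        # validation arrays and average the per-batch accuracy
        sess.run(valid_init, feed_dict={inputs_placeholder: valid_x,
                                        labels_placeholder: valid_y})
        batch_accs = []
        while True:
            try:
                batch_accs.append(sess.run(accuracy))
            except tf.errors.OutOfRangeError:
                break
        print('epoch {}: validation accuracy {:.4f}'.format(
            ep + 1, float(np.mean(batch_accs))))

The periodic logging keyed off global_step can sit unchanged inside the inner training loop; only the manual index permutation and the feed_dict data selection go away. Note that averaging per-batch accuracy is exact only when every batch has the same size; for an exact figure, weight each batch's accuracy by its size.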
