当模型需要输入张量时,如何正确使用TFRecord数据集张量进行训练和验证?

2024-04-19 10:03:15 发布

您现在位置:Python中文网/ 问答频道 /正文

我有两个由许多图像组成的TFRecord数据集。对于培训,我有以下(伪)代码:

training_dataset = tf.data.TFRecordDataset(records)
iterator = training_dataset.make_initializable_iterator()
train_image_tensor, train_label_tensor= iterator.get_next() # image and label tensor
preds = build_model(input_tensor=train_image_tensor, 'efficientnet-b5', training=True)
cross_entropy_op = tf.reduce_mean(tf.losses.softmax_cross_entropy(onehot_labels=train_label_tensor, logits=preds, weights=1.0))
(..)
while True:
   sess.run([cross_entropy_op])
   (...)

现在我想添加一个验证循环,在每个历元之后使用另一个TFRecord数据集验证我的模型。问题是,当我已经构建图形时,我必须向模型(train_image_tensor)提供输入张量。所以我需要在验证循环中交换模型的输入张量。有什么简单的方法吗

谢谢


Tags: 数据模型imagetruetftfrecordtrainingtrain