两个都喂tf.data.Dataset和NumPy数组到mod

2024-04-27 00:34:20 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个表示模型的类,其设置如下:

class Model:
  def __init__(self):
    self.setup_graph()

  def setup_graph():
    # sets up the model
    ....

  def train(self, dataset):
    # dataset is a tf.data.Dataset iterator, from which I can get 
    # tf.Tensor objects directly, which become part of the graph
    ....

  def predict(self, sample):
    # sample is a single NumPy array representing a sample,
    # which could be fed to a tf.placeholder using feed_dict
    ....

在培训期间,我希望利用TensorFlow的tf.data.Dataset的效率,但我仍然希望能够在单个样本上获得模型的输出。在我看来,这需要重新创建用于预测的图表。这是真的吗,或者我可以创建一个TF图,在这个图中,我可以使用来自tf.data.Dataset的样本运行,也可以用我提供给tf.placeholder的给定样本运行?在


Tags: thesample模型selfwhichdataistf
1条回答
网友
1楼 · 发布于 2024-04-27 00:34:20

您可以像往常一样使用数据集、迭代器等创建模型。然后,如果要用feed_dict传递一些自定义数据,只需将值传递给get_next()生成的张量:

import tensorflow as tf
import numpy as np

dataset = (tf.data.Dataset
    .from_tensor_slices(np.ones((100, 3), dtype=np.float32))
    .batch(5))
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()

output = 2 * batch

with tf.Session() as sess:
    print('From iterator:')
    print(sess.run(output))
    print('From feed_dict:')
    print(sess.run(output, feed_dict={batch: [[1, 2, 3]]}))

输出:

^{pr2}$

原则上,使用可初始化的、可重新初始化的或可馈送的迭代器可以达到相同的效果,但如果您真的只想测试单个数据样本,我认为这是最快、干扰更少的方法。在

相关问题 更多 >