张量来自不同的图

2024-03-28 10:05:15 发布

您现在位置:Python中文网/ 问答频道 /正文

我不熟悉tensorflow。正在尝试从tfrecords创建输入管道。 下面是我的代码片段,用于创建批并输入到我的estimator

def generate_input_fn(image,label,batch_size=BATCH_SIZE):
    logging.info('creating batches...')    
    dataset = tf.data.Dataset.from_tensors((image, label)) #<-- dataset is 'TensorDataset'
    dataset = dataset.repeat().batch(batch_size)
    iterator=dataset.make_initializable_iterator()
    iterator.initializer
    return iterator.get_next()

iterator=dataset.make_initializable_iterator()

ValueError: Tensor("count:0", shape=(), dtype=int64, device=/device:CPU:0) must be from the same graph as Tensor("TensorDataset:0", shape=(), dtype=variant).

我想我无意中使用了来自不同图形的张量,但我不知道如何以及在哪一行代码中。我不知道哪个张量是计数:0 或者whichone十orDataset:0。在

有人能帮我调试一下吗。在

错误日志:

^{pr2}$

如果我将函数修改为:

image_placeholder=tf.placeholder(image.dtype,shape=image.shape)
label_placeholder=tf.placeholder(label.dtype,shape=label.shape)
dataset = tf.data.Dataset.from_tensors((image_placeholder, label_placeholder))

即添加占位符,然后我得到输出:

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
2018-03-18 01:56:55.902917: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
Killed

Tags: 代码fromimageinfosizetftensorflowbatch
1条回答
网友
1楼 · 发布于 2024-03-28 10:05:15

当您调用estimator.train(input_fn)时,将创建一个新的图,其中包含在估计器的model_fn中定义的图,以及在{}中定义的图。在

因此,如果这些函数中的任何一个从它们的作用域之外引用张量,这些张量将不是同一个图的一部分,您将得到一个错误。在


简单的解决方案是确保您定义的每个张量都在input_fn或{}内部。在

例如:

def generate_input_fn(batch_size):
    # Create the images and labels tensors here
    images = tf.placeholder(tf.float32, [None, 224, 224, 3])
    labels = tf.placeholder(tf.int64, [None])

    dataset = tf.data.Dataset.from_tensors((images, labels))
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(1)
    iterator = dataset.make_initializable_iterator()

    return iterator.get_next()

相关问题 更多 >