Tensorflow:从numpy数组创建minibatch>2 GB

2024-04-19 18:56:41 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图给我的模型提供小批量的numpy数组,但我仍然坚持批处理。使用'tf.train.shuffle_批处理'引发错误,因为“images”数组大于2 GB。我试图绕过它并创建占位符,但是当我试图为数组提供数据时,它们仍然由tf.张量物体。我主要关心的是,我在model类下定义了操作,并且在运行会话之前不会调用这些对象。有人知道如何处理这个问题吗?在

def main(mode, steps):
  config = Configuration(mode, steps)



  if config.TRAIN_MODE:

      images, labels = read_data(config.simID)

      assert images.shape[0] == labels.shape[0]

      images_placeholder = tf.placeholder(images.dtype,
                                                images.shape)
      labels_placeholder = tf.placeholder(labels.dtype,
                                                labels.shape)

      dataset = tf.data.Dataset.from_tensor_slices(
                (images_placeholder, labels_placeholder))

      # shuffle
      dataset = dataset.shuffle(buffer_size=1000)

      # batch
      dataset = dataset.batch(batch_size=config.batch_size)

      iterator = dataset.make_initializable_iterator()

      image, label = iterator.get_next()

      model = Model(config, image, label)

      with tf.Session() as sess:

          sess.run(tf.global_variables_initializer())

          sess.run(iterator.initializer, 
                   feed_dict={images_placeholder: images,
                          labels_placeholder: labels})

          # ...

          for step in xrange(steps):

              sess.run(model.optimize)

Tags: configsizelabelsmodeltfbatch数组steps
1条回答
网友
1楼 · 发布于 2024-04-19 18:56:41

您正在使用tf.Datainitializable iterator向模型提供数据。这意味着您可以根据占位符参数化数据集,然后调用迭代器的初始值设定项op来准备使用它。在

如果您使用可初始化的迭代器或来自tf.Data的任何其他迭代器向模型提供输入,则不应使用sess.runfeed_dict参数来尝试进行数据馈送。相反,根据iterator.get_next()的输出定义模型,并省略sess.run中的{}。在

大致如下:

iterator = dataset.make_initializable_iterator()
image_batch, label_batch = iterator.get_next()

# use get_next outputs to define model
model = Model(config, image_batch, label_batch) 

# placeholders fed in while initializing the iterator
sess.run(iterator.initializer, 
            feed_dict={images_placeholder: images,
                       labels_placeholder: labels})

for step in xrange(steps):
     # iterator will feed image and label in the background
     sess.run(model.optimize) 

迭代器将在后台向模型提供数据,不需要通过feed_dict进行额外的数据馈送。在

相关问题 更多 >