TensorFlow: about mnist.train.next_batch()

Posted 2024-04-29 13:32:05


When I searched for mnist.train.next_batch(), I found https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/mnist.py

In that file there is this code:

def next_batch(self, batch_size, fake_data=False, shuffle=True):
  """Return the next `batch_size` examples from this data set."""
  if fake_data:
    fake_image = [1] * 784
    if self.one_hot:
      fake_label = [1] + [0] * 9
    else:
      fake_label = 0
    return [fake_image for _ in xrange(batch_size)], [
        fake_label for _ in xrange(batch_size)
    ]
  start = self._index_in_epoch
  # Shuffle for the first epoch
  if self._epochs_completed == 0 and start == 0 and shuffle:
    perm0 = numpy.arange(self._num_examples)
    numpy.random.shuffle(perm0)
    self._images = self.images[perm0]
    self._labels = self.labels[perm0]
  # Go to the next epoch
  if start + batch_size > self._num_examples:
    # Finished epoch
    self._epochs_completed += 1
    # Get the rest examples in this epoch
    rest_num_examples = self._num_examples - start
    images_rest_part = self._images[start:self._num_examples]
    labels_rest_part = self._labels[start:self._num_examples]
    # Shuffle the data
    if shuffle:
      perm = numpy.arange(self._num_examples)
      numpy.random.shuffle(perm)
      self._images = self.images[perm]
      self._labels = self.labels[perm]
    # Start next epoch
    start = 0
    self._index_in_epoch = batch_size - rest_num_examples
    end = self._index_in_epoch
    images_new_part = self._images[start:end]
    labels_new_part = self._labels[start:end]
    return (numpy.concatenate((images_rest_part, images_new_part), axis=0),
            numpy.concatenate((labels_rest_part, labels_new_part), axis=0))
  else:
    self._index_in_epoch += batch_size
    end = self._index_in_epoch
    return self._images[start:end], self._labels[start:end]

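I know that mnist.train.next_batch(batch_size=100) means it randomly picks 100 examples from the MNIST dataset. Now, here are my questions:

  1. What does shuffle=True mean?
  2. If I call next_batch(batch_size=100, fake_data=False, shuffle=False), will it pick 100 examples at a time sequentially, from the start of the MNIST dataset to the end, rather than randomly?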

2 answers

You can check this with the following code:

# mnist.train.next_batch
# SHUFFLE = FALSE

import matplotlib.pyplot as plt
import numpy as np
# Note: this tutorial module only exists in TensorFlow 1.x
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("data", one_hot=True)

image_index = 10  # Extract image 10 from MNIST every time you run the code
image_index -= 1  # Convert to a zero-based index
# _index_in_epoch holds the current position in the epoch.
# Rewind it to zero by requesting a negative batch size.
mnist.train.next_batch(-mnist.train._index_in_epoch, shuffle=False)
# Skip the first image_index examples, then read image 10 with next_batch
mnist.train.next_batch(image_index, shuffle=False)
batch_x, batch_y = mnist.train.next_batch(1, shuffle=False)

print('\n' + "mnist.train.next_batch:")
plt.imshow(batch_x.reshape([28, 28]), cmap='Greys')
plt.show()
print(batch_y, np.argmax(batch_y), mnist.train._index_in_epoch)

# Extract image 10 directly from mnist.train.images for comparison
image_x = mnist.train.images[image_index]
image_y = mnist.train.labels[image_index]

print('\n' + "mnist.train.images:")
plt.imshow(image_x.reshape([28, 28]), cmap='Reds')
plt.show()
print(image_y, np.argmax(image_y), mnist.train._index_in_epoch)

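1. With shuffle=True, the order of the examples in the data is randomized. 2. Yes, with shuffle=False it should respect whatever order the examples have in the numpy arrays.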
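
To see both behaviours without downloading MNIST, here is a minimal sketch that mimics the batching logic of the next_batch code quoted in the question, using only numpy. ToyDataSet and its _permute helper are hypothetical names invented for this illustration, not part of TensorFlow, and the fake_data branch is left out. It suggests that shuffle=True does not draw random samples on every call: the whole set is permuted once per epoch and then consumed in consecutive slices, so each example appears once per epoch.

```python
import numpy as np


class ToyDataSet:
    """Hypothetical stand-in for the MNIST DataSet; it only mimics next_batch."""

    def __init__(self, images, labels):
        self._images = images
        self._labels = labels
        self._num_examples = images.shape[0]
        self._index_in_epoch = 0
        self._epochs_completed = 0

    def _permute(self):
        # Apply the same random permutation to images and labels
        perm = np.random.permutation(self._num_examples)
        self._images = self._images[perm]
        self._labels = self._labels[perm]

    def next_batch(self, batch_size, shuffle=True):
        start = self._index_in_epoch
        # Shuffle once before the very first batch
        if self._epochs_completed == 0 and start == 0 and shuffle:
            self._permute()
        if start + batch_size > self._num_examples:
            # Epoch finished: keep the leftover examples, then start a new epoch
            self._epochs_completed += 1
            rest = self._num_examples - start
            images_rest = self._images[start:]
            labels_rest = self._labels[start:]
            if shuffle:
                self._permute()
            self._index_in_epoch = batch_size - rest
            end = self._index_in_epoch
            return (np.concatenate((images_rest, self._images[:end]), axis=0),
                    np.concatenate((labels_rest, self._labels[:end]), axis=0))
        self._index_in_epoch += batch_size
        end = self._index_in_epoch
        return self._images[start:end], self._labels[start:end]


data = np.arange(10)

# shuffle=False: batches follow the stored order and wrap around at the end
ordered = ToyDataSet(data.reshape(10, 1), data)
_, y1 = ordered.next_batch(4, shuffle=False)  # labels 0, 1, 2, 3
_, y2 = ordered.next_batch(4, shuffle=False)  # labels 4, 5, 6, 7
_, y3 = ordered.next_batch(4, shuffle=False)  # labels 8, 9, 0, 1

# shuffle=True: random order, but one epoch still covers every example once
shuffled = ToyDataSet(data.reshape(10, 1), data)
_, s1 = shuffled.next_batch(5)
_, s2 = shuffled.next_batch(5)
print(y1, y2, y3)
print(sorted(np.concatenate((s1, s2)).tolist()))
```

The third ordered batch shows the wrap-around case: the last two examples of one epoch are concatenated with the first two of the next.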
