我发现以下方法可以在tensorflow中获取mnist数据集:
def get_input_fn(dataset_split, batch_size, capacity=10000, min_after_dequeue=3000):
def _input_fn():
images_batch, labels_batch = tf.train.shuffle_batch(
tensors=[dataset_split.images, dataset_split.labels.astype(np.int32)],
batch_size=batch_size,
capacity=capacity,
min_after_dequeue=min_after_dequeue,
enqueue_many=True,
num_threads=4)
features_map = {'images': images_batch}
return features_map, labels_batch
return _input_fn
data = tf.contrib.learn.datasets.mnist.load_mnist()
train_input_fn = get_input_fn(data.train, batch_size=256)
eval_input_fn = get_input_fn(data.validation, batch_size=5000)
数据变量是数据集对象。 这种方法我不太清楚,我也不知道如何将60K数据集转换为10K数据集。在
当我执行以下操作时:
^{pr2}$我得到错误:
AttributeError: 'Datasets' object has no attribute 'take'
谢谢你的帮助!在
来自contrib模块的此函数已弃用。您可以使用
tf.keras.datasets.mnist.load_data()
。根据https://www.tensorflow.org/api_docs/python/tf/keras/datasets/mnist/load_data,它返回因此,为了对它应用任何函数,需要将其加载到dataset对象中。在
^{pr2}$然后可以将shuffle、batch、take或任何映射函数应用于
dataset_train
或dataset_test
对象相关问题 更多 >
编程相关推荐