如何从tensorflow中的Dataset类中获取10kmnist图像的子集?

2024-05-16 12:46:39 发布

您现在位置:Python中文网/ 问答频道 /正文

我发现以下方法可以在tensorflow中获取mnist数据集:

def get_input_fn(dataset_split, batch_size, capacity=10000, min_after_dequeue=3000):

  def _input_fn():
    images_batch, labels_batch = tf.train.shuffle_batch(
        tensors=[dataset_split.images, dataset_split.labels.astype(np.int32)],
        batch_size=batch_size,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue,
        enqueue_many=True,
        num_threads=4)
    features_map = {'images': images_batch}
    return features_map, labels_batch

  return _input_fn

    data = tf.contrib.learn.datasets.mnist.load_mnist()

    train_input_fn = get_input_fn(data.train, batch_size=256)
    eval_input_fn = get_input_fn(data.validation, batch_size=5000)

数据变量是数据集对象。 这种方法我不太清楚,我也不知道如何将60K数据集转换为10K数据集。在

当我执行以下操作时:

^{pr2}$

我得到错误:

AttributeError: 'Datasets' object has no attribute 'take'

但是文档提供了这种方法: enter image description here

谢谢你的帮助!在


Tags: 数据方法inputsizegetbatchmindataset
1条回答
网友
1楼 · 发布于 2024-05-16 12:46:39

来自contrib模块的此函数已弃用。您可以使用tf.keras.datasets.mnist.load_data()。根据https://www.tensorflow.org/api_docs/python/tf/keras/datasets/mnist/load_data,它返回

Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. 

因此,为了对它应用任何函数,需要将其加载到dataset对象中。在

^{pr2}$

然后可以将shuffle、batch、take或任何映射函数应用于dataset_traindataset_test对象

相关问题 更多 >