从TFRecordDatas获取数据集为numpy数组

2024-06-08 12:54:28 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用新的tf.dataAPI为CIFAR10数据集创建一个迭代器。我正在从两个.tfrecord文件读取数据。保存训练数据的人(train.tf记录)另一个是保存测试数据的(测试.tf记录). 这一切都很好。然而,在某些时候,我需要两个数据集(训练数据和测试数据)作为numpy数组。在

是否可以从tf.data.TFRecorddataset对象中以numpy数组的形式检索数据集?在


Tags: 文件数据对象numpydatatftfrecord记录
1条回答
网友
1楼 · 发布于 2024-06-08 12:54:28

您可以使用^{}转换和^{}来完成此操作。 作为复习,dataset.batch(n) 将占用n个连续的n个元素,并通过连接每个组件将它们转换为一个元素。这要求所有元素的每个组件都有一个固定的形状。如果n大于dataset中的元素数量(或者如果n没有精确划分元素的数量),那么最后一批可以更小。因此,您可以为n选择一个较大的值并执行以下操作:

import numpy as np
import tensorflow as tf

# Insert your own code for building `dataset`. For example:
dataset = tf.data.TFRecordDataset(...)  # A dataset of tf.string records.
dataset = dataset.map(...)  # Extract components from each tf.string record.

# Choose a value of `max_elems` that is at least as large as the dataset.
max_elems = np.iinfo(np.int64).max
dataset = dataset.batch(max_elems)

# Extracts the single element of a dataset as one or more `tf.Tensor` objects.
# No iterator needed in this case!
whole_dataset_tensors = tf.contrib.data.get_single_element(dataset)

# Create a session and evaluate `whole_dataset_tensors` to get arrays.
with tf.Session() as sess:
    whole_dataset_arrays = sess.run(whole_dataset_tensors)

相关问题 更多 >