当`tf.data.TFRecordsDataset`的num_parallel_read=2并从tf.dataset中提取时,它会在开始之前提取8个文件,而不是预期的2个

2024-04-25 05:43:58 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在从生成器数据集中向TFRecordDataset馈送tfrecord文件名,该生成器数据集处理根据需要在本地获取和放置tfrecord文件。tfrecords文件很大,总体上不适合本地文件系统。为了提高传输速度,我将它们放在ram磁盘上,因此我关心一次排队的人数

ds = tf.data.Dataset.from_generator(generator=self.gen_tfrecords_files, output_types=tf.string, output_shapes=())
ds = tf.data.TFRecordDataset(filenames=ds, num_parallel_reads=2)

我希望TFRecordDataset从我的生成器中一次采样2个文件,但是它在开始训练之前从生成器中提取了8个文件。我没有任何重要的预取在管道中

有人能解释为什么TFRecordDatasetnum_parallel_read所指定的文件要多出这么多吗

TF 1.14.0


Tags: 文件数据outputdataparallel文件名tftfrecords