Slow serialization of an ndarray to TFRecord

Posted 2024-04-27 03:36:10


I want to serialize a large numpy ndarray to TFRecord. The problem is that the process is painfully slow: for an array of shape (1000000, 65) it takes almost a minute, while serializing it to other binary formats (HDF5, npy, Parquet, ...) takes less than a second. I'm pretty sure there is a faster way to do this, but I just can't figure it out:

import numpy as np
import tensorflow as tf

X = np.random.randn(1000000, 65)

def write_tf_dataset(data: np.ndarray, path: str):
    with tf.io.TFRecordWriter(path=path) as writer:
        for record in data:
            feature = {'X': tf.train.Feature(float_list=tf.train.FloatList(value=record[:42])),
                       'Y': tf.train.Feature(float_list=tf.train.FloatList(value=record[42:64])),
                       'Z': tf.train.Feature(float_list=tf.train.FloatList(value=[record[64]]))}
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            serialized = example.SerializeToString()
            writer.write(serialized)

write_tf_dataset(X, 'X.tfrecord')
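
For context, a rough timing sketch of the comparison described above (a minimal sketch that reuses X and write_tf_dataset from the snippet; the file paths are placeholders):

import time

start = time.perf_counter()
np.save("X.npy", X)  # plain .npy dump of the same array: well under a second
print(f"np.save: {time.perf_counter() - start:.2f} s")

start = time.perf_counter()
write_tf_dataset(X, "X.tfrecord")  # the TFRecord writer above: close to a minute
print(f"write_tf_dataset: {time.perf_counter() - start:.2f} s")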

How can I improve the performance of write_tf_dataset? My real X is 200 times larger than the one in the snippet.

I'm not the first to complain about slow TFRecord serialization. Based on this Tensorflow Github issue, I made a second version of the function:

import pickle

def write_tf_dataset(data: np.ndarray, path: str):
    with tf.io.TFRecordWriter(path=path) as writer:
        for record in data:
            feature = {
                'X': tf.io.serialize_tensor(record[:42]).numpy(),
                'Y': tf.io.serialize_tensor(record[42:64]).numpy(),
                'Z': tf.io.serialize_tensor(record[64]).numpy(),
            }
            serialized = pickle.dumps(feature)
            writer.write(serialized)

...but it performed even worse. Any ideas?
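
For reference, a common way to combine tf.io.serialize_tensor with TFRecord is to wrap the serialized bytes in a BytesList feature, so that each record is still a parseable tf.train.Example rather than a pickled dict. A minimal sketch of that variant, assuming the numpy/tensorflow imports from the snippets above (not benchmarked here):

def write_tf_dataset_bytes(data: np.ndarray, path: str):
    with tf.io.TFRecordWriter(path=path) as writer:
        for record in data:
            # Each feature holds one serialized tensor as raw bytes
            feature = {
                'X': tf.train.Feature(bytes_list=tf.train.BytesList(
                    value=[tf.io.serialize_tensor(record[:42]).numpy()])),
                'Y': tf.train.Feature(bytes_list=tf.train.BytesList(
                    value=[tf.io.serialize_tensor(record[42:64]).numpy()])),
                'Z': tf.train.Feature(bytes_list=tf.train.BytesList(
                    value=[tf.io.serialize_tensor(record[64]).numpy()])),
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())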


1 Answer

Answered by a forum user on 2024-04-27 03:36:10

The workaround is to use the multiprocessing package. You can either parallelize the serialization and write everything to a single TFRecord file, or have each process write its own file (multiple small TFRecord files are generally preferred over a single large one, because reading from several sources in parallel is faster):

import multiprocessing
import os

import numpy as np
import tensorflow as tf


def serialize_example(record):
    feature = {
        "X": tf.train.Feature(float_list=tf.train.FloatList(value=record[:42])),
        "Y": tf.train.Feature(float_list=tf.train.FloatList(value=record[42:64])),
        "Z": tf.train.Feature(float_list=tf.train.FloatList(value=[record[64]])),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()


def write_tfrecord(tfrecord_path, records):
    with tf.io.TFRecordWriter(tfrecord_path) as writer:
        for item in records:
            serialized = serialize_example(item)
            writer.write(serialized)


if __name__ == "__main__":
    np.random.seed(1234)
    data = np.random.randn(1000000, 65)

    # Option 1: serialize records in parallel with a pool, write to a single file
    tfrecord_path = "/home/appuser/data/data.tfrecord"
    with multiprocessing.Pool(processes=4) as pool:
        with tf.io.TFRecordWriter(tfrecord_path) as writer:
            for example in pool.map(serialize_example, data):
                writer.write(example)

    # Option 2: write each shard to its own file in a separate process
    procs = []
    n_shard = 4
    num_per_shard = int(np.ceil(len(data) / n_shard))
    for shard_id in range(n_shard):
        filename = f"data_{shard_id + 1:04d}_of_{n_shard:04d}.tfrecord"
        tfrecord_path = os.path.join("/home/appuser/data", filename)

        start_index = shard_id * num_per_shard
        end_index = min((shard_id + 1) * num_per_shard, len(data))

        args = (tfrecord_path, data[start_index:end_index])
        p = multiprocessing.Process(target=write_tfrecord, args=args)
        p.start()
        procs.append(p)

    for proc in procs:
        proc.join()
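
A possible follow-up is reading the shards back with tf.data; interleaving over several files is where the multi-file layout pays off. A minimal sketch, assuming the shard filenames produced by Option 2 above and the 42/22/1 float split from the question:

import tensorflow as tf

# Match the shard files written by Option 2
files = tf.data.Dataset.list_files("/home/appuser/data/data_*_of_0004.tfrecord")

# Read several shards in parallel
dataset = files.interleave(
    tf.data.TFRecordDataset,
    cycle_length=4,
    num_parallel_calls=tf.data.AUTOTUNE,
)

# Parse each serialized tf.train.Example back into fixed-length float features
feature_spec = {
    "X": tf.io.FixedLenFeature([42], tf.float32),
    "Y": tf.io.FixedLenFeature([22], tf.float32),
    "Z": tf.io.FixedLenFeature([1], tf.float32),
}

def parse_example(serialized):
    return tf.io.parse_single_example(serialized, feature_spec)

dataset = dataset.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)

for record in dataset.take(1):
    print(record["X"].shape, record["Y"].shape, record["Z"].shape)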
