如何在tensorflow\u数据集中加速将tensor转换成numpy数组的代码?

2024-04-26 13:45:24 发布

您现在位置:Python中文网/ 问答频道 /正文

尽管我想在tensorflow\u数据集中将tensor转换为numpy数组,但我的代码的速度会逐渐减慢。 现在,我使用的lsun/卧室数据集有超过300万张图像。 如何加速我的代码?你知道吗

我的代码保存tuple,它每100000个图像就有一个numpy数组。你知道吗

train_tf = tfds.load("lsun/bedroom", data_dir="{$my_directory}", download=False)
train_tf = train_tf["train"]
for data in train_tf:
    if d_cnt==0 and d_cnt%100001==0:
        train = (tfds.as_numpy(data["image"]), )
    else:
        train += (tfds.as_numpy(data["image"]), )

    if d_cnt%100000==0 and d_cnt!=0:
        with open("{$my_directory}/lsun.pickle%d"%(d_cnt), "wb") as f:
            pickle.dump(train, f)

    d_cnt += 1

Tags: 数据代码图像numpydataifmytf
1条回答
网友
1楼 · 发布于 2024-04-26 13:45:24

您的if条件永远不会在第一次传递之后执行,因此您的train变量会不断累积。你知道吗

我想你希望有这样的条件:

if d_cnt!=0 and d_cnt%100001==0:
    train = (tfds.as_numpy(data["image"]), )

相关问题 更多 >