如何从张量元组创建TF数据集?(和最佳做法)

2024-04-18 18:18:06 发布

您现在位置:Python中文网/ 问答频道 /正文

我在创建自定义数据集时遇到一些问题。 现在我设法得到了一个生成器函数:

def dataset_generator():
for path in pathlist_img:
    img, labels = process_path(path)
    yield img, labels 

它返回两个tensorflow张量:第一个是shape=(7201280,3),dtype=uint8,第二个是shape=(?,14),dtype=float32,其中“?”表示它取决于图像(它是一个对象检测数据集,因此识别的实例数不是固定的)

我希望有一个与我的标签相关联的图像数据集,这就是为什么我会产生元组

问题是我的数据集

dataset = tf.data.Dataset.from_generator(dataset_generator, ((tf.uint8, tf.uint8, tf.uint8),(tf.float32, tf.float32)))

这只是一个小问题

<FlatMapDataset shapes: ((<unknown>, <unknown>, <unknown>), (<unknown>, <unknown>)), types: ((tf.uint8, tf.uint8, tf.uint8), (tf.float32, tf.float32))>

而且似乎什么都没有,或者至少我不能从中得到任何东西

是否有一些从图像和标签文件构建数据集的最佳实践? 我怎样才能解决这个问题


Tags: 数据path图像imglabelstf标签generator
1条回答
网友
1楼 · 发布于 2024-04-18 18:18:06

尝试这样重写:

img = tf.ones((730, 1280, 3), dtype=tf.uint8)
label = tf.ones((tf.random.uniform(shape=[], minval=1, maxval=10), 14),
                dtype=tf.float32)
def gen():
    yield img,label

dataset = tf.data.Dataset.from_generator(
      gen,
      (tf.uint8, tf.float32),
)

在输出类型中,您应该只指定总体张量的总体数据类型,而不是每个维度的数据类型

相关问题 更多 >