如何从十进制的字符串张量中读取数据集名称

2024-05-16 09:55:30 发布

您现在位置:Python中文网/ 问答频道 /正文

我是tensorflow的新手,我有一个张量(字符串类型),我在其中存储了所有需要的图像的图像路径,我想用这些图像来训练一个模型。在

问题:如何读取张量到队列,然后批处理它。

我的方法是:给我错误

    img_names = dataset['f0']
    file_length = len(img_names)
    type(img_names)
    tf_img_names = tf.stack(img_names)
    filename_queue = tf.train.string_input_producer(tf_img_names, num_epochs=num_epochs, shuffle=False)
    wd=getcwd()
    print('In input pipeline')
    tf_img_queue = tf.FIFOQueue(file_length,dtypes=[tf.string])
    col_Image = tf_img_queue.dequeue(filename_queue)
    ### Read Image
    img_file = tf.read_file(wd+'/'+col_Image)
    image = tf.image.decode_png(img_file, channels=num_channels)
    image = tf.cast(image, tf.float32) / 255.
    image = tf.image.resize_images(image,[image_width, image_height])
    min_after_dequeue = 100
    capacity = min_after_dequeue + 3 * batch_size
    image_batch, label_batch = tf.train.batch([image, onehot], batch_size=batch_size, capacity=capacity, allow_smaller_final_batch = True, min_after_dequeue=min_after_dequeue)

错误:类型错误:应为字符串或缓冲区'

我不知道我的方法是否正确


Tags: 图像imageimgnamesqueuetf错误batch
1条回答
网友
1楼 · 发布于 2024-05-16 09:55:30

您不必创建另一个队列。您可以定义一个读卡器,它将为您出列元素。你可以试试下面的方法并评论一下。在

reader = tf.IdentityReader()
key, value = reader.read(filename_queue)
dir = tf.constant(wd)
path = tf.string_join([dir,tf.constant("/"),value])
img_file = tf.read_file(path)

为了检查你的路径是否正确,请

^{pr2}$

寻找你的反馈。在

相关问题 更多 >