TensorFlow dataset.map()方法不适用于内置的tf.keras.preprocessing.image函数

2024-03-29 14:59:08 发布

您现在位置:Python中文网/ 问答频道 /正文

我在数据集中加载如下内容:

import tensorflow_datasets as tfds

ds = tfds.load(
    'caltech_birds2010',
    split='train',
    as_supervised=False)

这个函数可以很好地工作:

import tensorflow as tf

@tf.function
def pad(image,label):
    return (tf.image.resize_with_pad(image,32,32),label)

ds = ds.map(pad)

但是当我尝试映射一个不同的内置函数时

from tf.keras.preprocessing.image import random_rotation

@tf.function
def rotate(image,label):
    return (random_rotation(image,90), label)

ds = ds.map(rotate)

我得到以下错误:

AttributeError: 'Tensor' object has no attribute 'ndim'

这不是唯一一个给我带来问题的函数,无论是否使用@tf.function装饰器,它都会发生

非常感谢您的帮助


Tags: 函数imageimportmapreturntftensorflowdef
1条回答
网友
1楼 · 发布于 2024-03-29 14:59:08

我会尝试在这里使用tf.py_函数进行随机_旋转。例如:

def rotate(image, label):
    im_shape = image.shape
    [image, label,] = tf.py_function(random_rotate,[image, label],
                                     [tf.float32, tf.string])
    image.set_shape(im_shape)
    return image, label

ds = ds.map(rotate)

尽管我认为根据What is the difference in purpose between tf.py_function and tf.function?,它们在这里做了类似的事情,但是tf.py_函数对于通过tensorflow执行python代码来说更为直接,尽管tf.function具有性能优势

相关问题 更多 >