我注意到有一个preprocess_input
函数,它根据您想在tensorflow.keras.applications
中使用的模型而不同
我正在使用ImageDataGenerator
类来扩充我的数据。更具体地说,我使用的是CustomDataGenerator
,它从ImageDataGenerator
类扩展而来,并添加了颜色转换
这就是它的样子:
class CustomDataGenerator(ImageDataGenerator):
def __init__(self, color=False, **kwargs):
super().__init__(preprocessing_function=self.augment_color, **kwargs)
self.hue = None
if color:
self.hue = random.random()
def augment_color(self, img):
if not self.hue or random.random() < 1/3:
return img
img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
img_hsv[:, :, 0] = self.hue
return cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR)
我在ImageDataGenerator
上使用了rescale=1./255
,但是有些模型需要不同的预处理
所以当我尝试
CustomDataGenerator(preprocessing_function=tf.keras.applications.xception.preprocess_input)
我得到这个错误:
__init__() got multiple values for keyword argument 'preprocessing_function'
问题是,您已经在这里传递了
preprocessing_function
然后又把它从
所以现在就像
去掉其中一个,你就可以走了
编辑1: 如果您想保留这两种方法,最好将它们合并到一个预处理方法中,并将其作为预处理函数传递
将以下方法添加到CustomDataGenerator
将其用作预处理函数
相关问题 更多 >
编程相关推荐