我正在尝试训练一个Keras模型,该模型将数据增强纳入模型本身。模型的输入是不同类的图像,模型应该为每个类生成一个扩充模型,用于扩充过程。我的代码大致如下所示:
from keras.models import Model
from keras.layers import Input
...further imports...
def get_main_model(input_shape, n_classes):
encoder_model = get_encoder_model()
input = Input(input_shape, name="input")
label_input = Input((1,), name="label_input")
aug_models = [get_augmentation_model() for i in range(n_classes)]
augmentation = aug_models[label_input](input)
x = encoder_model(input)
y = encoder_model(augmentation)
model = Model(inputs=[input, label_input], outputs=[x, y])
model.add_loss(custom_loss_function(x, y))
return model
然后,我希望通过模型传递成批数据,该模型由一个图像数组(传递给input
)和一个相应的标签数组(传递给label_input
)组成。但是,这不起作用,因为任何输入到label_input中的内容都会被Tensorflow转换为张量,并且不能用于下面的索引。我尝试了以下几点:
augmentation = aug_models[int(label_input)](input)
-->;不起作用
因为label_input is a tensor
augmentation = aug_models[tf.make_ndarray(label_input)](input)
-->;强制转换不起作用(我想是因为label_输入是一个符号张量)tf.gather(aug_models, label_input)
-->;不起作用,因为操作的结果是Tensorflow试图将其转换为张量的Keras模型实例(显然失败)Tensorflow中是否有任何技巧可以让我在训练期间向模型传递一个未转换为张量的参数,或者以不同的方式告诉模型选择哪个增强模型?提前谢谢
要对
input
张量的每个元素应用不同的增广(例如,以label_input
为条件),您需要:不幸的是,索引是不可能的,因为
input
和label_input
张量都是多维的(例如,如果要对批处理的每个元素应用相同的扩充,那么就可以使用任何条件tensorflow语句,例如tf.case)下面是一个简单的工作示例,展示了如何实现这一点:
注意:您可能需要将这些操作包装到KerasLambda layer中
相关问题 更多 >
编程相关推荐