在训练/使用张量索引期间将非传感器参数传递给Keras模型

2024-05-29 10:14:22 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试训练一个Keras模型,该模型将数据增强纳入模型本身。模型的输入是不同类的图像,模型应该为每个类生成一个扩充模型,用于扩充过程。我的代码大致如下所示:

from keras.models import Model
from keras.layers import Input
...further imports...

def get_main_model(input_shape, n_classes):
    encoder_model = get_encoder_model()
    input = Input(input_shape, name="input")
    label_input = Input((1,), name="label_input")
    aug_models = [get_augmentation_model() for i in range(n_classes)]
    augmentation = aug_models[label_input](input)
    x = encoder_model(input)
    y = encoder_model(augmentation)
    model = Model(inputs=[input, label_input], outputs=[x, y])
    model.add_loss(custom_loss_function(x, y))
    return model 

然后,我希望通过模型传递成批数据,该模型由一个图像数组(传递给input)和一个相应的标签数组(传递给label_input)组成。但是,这不起作用,因为任何输入到label_input中的内容都会被Tensorflow转换为张量,并且不能用于下面的索引。我尝试了以下几点:

  • augmentation = aug_models[int(label_input)](input)-->;不起作用 因为label_input is a tensor
  • augmentation = aug_models[tf.make_ndarray(label_input)](input)-->;强制转换不起作用(我想是因为label_输入是一个符号张量)
  • tf.gather(aug_models, label_input)-->;不起作用,因为操作的结果是Tensorflow试图将其转换为张量的Keras模型实例(显然失败)

Tensorflow中是否有任何技巧可以让我在训练期间向模型传递一个未转换为张量的参数,或者以不同的方式告诉模型选择哪个增强模型?提前谢谢


Tags: 数据模型图像gtencoderinputgetmodel
1条回答
网友
1楼 · 发布于 2024-05-29 10:14:22

要对input张量的每个元素应用不同的增广(例如,以label_input为条件),您需要:

  1. 首先,为批处理的每个元素计算每个可能的扩充
  2. 其次,根据标签选择所需的增强

不幸的是,索引是不可能的,因为inputlabel_input张量都是多维的(例如,如果要对批处理的每个元素应用相同的扩充,那么就可以使用任何条件tensorflow语句,例如tf.case


下面是一个简单的工作示例,展示了如何实现这一点:

input = tf.ones((3, 1))  # Shape=(bs, 1)
label_input = tf.constant([3, 2, 1])  # Shape=(bs, 1)
aug_models = [lambda x: x, lambda x: x * 2, lambda x: x * 3, lambda x: x * 4]
nb_classes = len(aug_models)

augmented_data = tf.stack([aug_model(input) for aug_model in aug_models])  # Shape=(nb_classes, bs, 1)
selector = tf.transpose(tf.one_hot(label_input, depth=nb_classes))  # Shape=(nb_classes, bs)
augmentation = tf.reduce_sum(selector[..., None] * augmented_data, axis=0)  # Shape=(bs, 1) 
print(augmentation)

# prints:
# tf.Tensor(
# [[4.]
#  [3.]
#  [2.]], shape=(3, 1), dtype=float32)

注意:您可能需要将这些操作包装到KerasLambda layer

相关问题 更多 >

    热门问题