路缘石替换输入层

2024-04-25 19:49:24 发布

您现在位置:Python中文网/ 问答频道 /正文

我拥有的代码(我无法更改)使用Resnet和my_input_tensor作为输入张量。

model1 = keras.applications.resnet50.ResNet50(input_tensor=my_input_tensor, weights='imagenet')

ResNet50函数研究source code,用my_input_tensor创建一个新的keras输入层,然后创建模型的其余部分。这是我想用我自己的模型复制的行为。我从h5文件加载我的模型。

model2 = keras.models.load_model('my_model.h5')

由于这个模型已经有了一个输入层,我想用一个新的输入层替换它,这个输入层是用my_input_tensor定义的。

如何替换输入层?


Tags: 代码模型inputmodelmykerash5tensor
3条回答

使用以下命令保存模型时:

old_model.save('my_model.h5')

它将保存以下内容:

  1. 模型的体系结构,允许创建模型。
  2. 模型的权重。
  3. 模型的训练配置(loss,optimizer)。
  4. 优化器的状态,允许从以前离开的位置恢复训练。

因此,加载模型时:

res50_model = load_model('my_model.h5')

您应该取回相同的模型,可以使用以下方法验证相同的模型:

res50_model.summary()
res50_model.get_weights()

现在,您可以弹出输入层并使用以下命令添加您自己的:

res50_model.layers.pop(0)
res50_model.summary()

添加新输入层:

newInput = Input(batch_shape=(0,299,299,3))    # let us say this new InputLayer
newOutputs = res50_model(newInput)
newModel = Model(newInput, newOutputs)

newModel.summary()
res50_model.summary()

不幸的是,@MilindDeore的解决方案对我不起作用。虽然我可以打印新模型的摘要,但我在预测时收到一个“矩阵大小不兼容”的错误。我想这是有意义的,因为密集层的新输入形状与旧密集层权重的形状不匹配。

因此,这里有另一个解决方案。对我来说,关键是用“层”而不是“层”。后者似乎只返回一个副本。

import keras
import numpy as np

def get_model():
    old_input_shape = (20, 20, 3)
    model = keras.models.Sequential()
    model.add(keras.layers.Conv2D(9, (3, 3), padding="same", input_shape=old_input_shape))
    model.add(keras.layers.MaxPooling2D((2, 2)))
    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(1, activation="sigmoid"))
    model.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(lr=0.0001), metrics=['acc'], )
    model.summary()
    return model

def change_model(model, new_input_shape=(None, 40, 40, 3)):
    # replace input shape of first layer
    model._layers[1].batch_input_shape = new_input_shape

    # feel free to modify additional parameters of other layers, for example...
    model._layers[2].pool_size = (8, 8)
    model._layers[2].strides = (8, 8)

    # rebuild model architecture by exporting and importing via json
    new_model = keras.models.model_from_json(model.to_json())
    new_model.summary()

    # copy weights from old model to new one
    for layer in new_model.layers:
        try:
            layer.set_weights(model.get_layer(name=layer.name).get_weights())
        except:
            print("Could not transfer weights for layer {}".format(layer.name))

    # test new model on a random input image
    X = np.random.rand(10, 40, 40, 3)
    y_pred = new_model.predict(X)
    print(y_pred)

    return new_model

if __name__ == '__main__':
    model = get_model()
    new_model = change_model(model)

层数。pop(0)或类似的东西不起作用。

您可以尝试两个选项:

1.

可以使用所需图层创建新模型。

一个相对简单的方法是i)提取模型json配置,ii)适当地更改它,iii)从中创建一个新模型,然后iv)复制权重。我只展示基本的想法。

i)提取配置

model_config = model.get_config()

ii)更改配置

input_layer_name = model_config['layers'][0]['name']
model_config['layers'][0] = {
                      'name': 'new_input',
                      'class_name': 'InputLayer',
                      'config': {
                          'batch_input_shape': (None, 300, 300),
                          'dtype': 'float32',
                          'sparse': False,
                          'name': 'new_input'
                      },
                      'inbound_nodes': []
                  }
model_config['layers'][1]['inbound_nodes'] = [[['new_input', 0, 0, {}]]]
model_config['input_layers'] = [['new_input', 0, 0]]

ii)创建新模型

new_model = model.__class__.from_config(model_config, custom_objects={})  # change custom objects if necessary

ii)复制权重

# iterate over all the layers that we want to get weights from
weights = [layer.get_weights() for layer in model.layers[1:]]
for layer, weight in zip(new_model.layers[1:], weights):
    layer.set_weights(weight)

2.

您可以尝试像kerassurgeon这样的库(我链接到一个使用tensorflow keras版本的fork)。请注意,插入和删除操作仅在某些条件下工作,例如兼容的尺寸。

from kerassurgeon.operations import delete_layer, insert_layer

model = delete_layer(model, layer_1)
# insert new_layer_1 before layer_2 in a model
model = insert_layer(model, layer_2, new_layer_3)

相关问题 更多 >