如何保存包含keras模型的对象?

2024-04-26 10:40:18 发布

您现在位置:Python中文网/ 问答频道 /正文

下面是我的代码框架:

def build_model(x, y):
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Dense(1, activation='relu'))
    model.compile(loss='mean_squared_error', optimizer='adam')
    model.fit(x, y)
    return model

class MultiModel(object):
    def __init__(self, x1, y1, x2, y2):
        self.model1 = build_model(x1, y1)
        self.model2 = build_model(x2, y2)

if __name__ == '__main__':
    # code that builds x1, x2, y1, y2
    mm = MultiModel(x1, y1, x2, y2)  # How to save mm ?

问题是我不知道如何保存包含多个Keras模型的mm对象。在

Keras内置的save方法只允许保存Keras模型,因此在这种情况下它不可用。 pickle模块无法保存_螺纹锁紧对象,因此它也不可用。在

也许可以使用Keras save方法独立保存每个模型,然后重新组合它们并将它们作为一个整体保存。但我不知道该怎么办。在

有什么想法吗?在


Tags: 模型buildselfmodelsavetfdefkeras