不使用训练cod恢复TF-Eager模型

2024-06-08 09:47:51 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在以急切模式训练(并保存)一个非常简单的模型,如下所示:

import os
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

NUM_EXAMPLES = 2000

training_inputs = tf.random_normal([NUM_EXAMPLES])
noise = tf.random_normal([NUM_EXAMPLES])
outputs = training_inputs * 3 + 2 + noise


class Model(tf.keras.Model):
    def __init__(self):
        super(Model, self).__init__()
        self.W = tfe.Variable(5., name="weight")
        self.b = tfe.Variable(0., name="bias")

    def predict(self, input):
        return self.W * input + self.b


def loss(model, inputs, outputs):
    error = model.predict(inputs) - outputs
    return tf.reduce_mean(tf.square(error))


def grad(model, inputs, outputs):
    with tf.GradientTape() as tape:
        loss_value = loss(model, inputs, outputs)
    return tape.gradient(loss_value, [model.W, model.b])


if __name__ == "__main__":
    model = Model()
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

    for i in range(300):
        gradients = grad(model, training_inputs, outputs)
        optimizer.apply_gradients(zip(gradients, [model.W, model.b]),
                                  global_step=tf.train.get_or_create_global_step())

    checkpoint_dir = './checkpoints'
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

    root = tfe.Checkpoint(optimizer=optimizer,
                          model=model,
                          optimizer_step=tf.train.get_or_create_global_step())
    root.save(file_prefix=checkpoint_prefix)

我发现保存/还原的唯一方法(使用CheckpointSaver)意味着可以访问Model类以将其加载到其他地方,例如:

^{pr2}$

来自tf.keras.Modelsave方法似乎还没有在急切模式下实现:

model.save("keras_model")
>>> NotImplementedError

是否有另一种方法可以在不实例化新的Model对象的情况下保存和加载模型?在


Tags: importselfmodeltfdefasstepoutputs