How do I set up a custom test step in Keras?

Published 2024-04-25 21:29:15


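I trained my model with an input (image) size of [None,400,400,3], but I want to test it with a different input size, such as [None,512,512,3]. Here is my custom training implementation: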

    my_model = customModel(rgb_mean=self.args.rgbn_mean)        
    ckpt_manager = tf.train.Checkpoint(optimizer=optimizer,model=my_model)        
    for epoch in range(self.args.max_epochs):

        # training
        for step, (x,y) in enumerate(train_data):

            with tf.GradientTape() as tape:
                pred = my_model(x, training=True)

                preds, last_logits, loss = pre_process_binary_cross_entropy(
                    loss_bc,pred, y,self.args, use_tf_loss=False)
            if (step)%100==0 and loss < global_loss:
                # tfk.Model.save_weights(my_model,os.path.join(checkpoint_dir,"saved_model.h5"),
                #                        save_format=ckpt_save_mode)
                # # tfk.models.save_model(my_model,os.path.join(checkpoint_dir,"1saved_model.h5"),
                # #                        save_format=ckpt_save_mode)
                # tfk.models.save_model(my_model,checkpoint_dir)
                ckpt_manager.save(checkpoint_dir)

Now, here is my custom test implementation:

    root = tf.train.Checkpoint(optimizer=optimizer,
                               model=my_model)
    ckpt_manager = tf.train.CheckpointManager(root, checkpoint_dir, max_to_keep=10)
    root.restore(ckpt_manager.latest_checkpoint)
    for step, x in enumerate(test_data):
        preds = my_model(x, training=False)

When I test the model with 400x400 inputs it works fine, but when I test it with other sizes (such as 512x512 or 720x1280), I get the following log:

Traceback (most recent call last):
  File "C:/Users/xavie/Documents/Codes/GitHub/efge/main.py", line 76, in <module>
    main(args=arg)
  File "C:/Users/xavie/Documents/Codes/GitHub/efge/main.py", line 70, in main
    model.test()
  File "C:\Users\xavie\Documents\Codes\GitHub\efge\run_model.py", line 198, in test
    preds = my_model(x,training=False)
  File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 968, in __call__
    outputs = self.call(cast_inputs, *args, **kwargs)
  File "C:\Users\xavie\Documents\Codes\GitHub\efge\model.py", line 90, in call
    output = self.batchnorm1(output, training=training)
  File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 964, in __call__
    self._maybe_build(inputs)
  File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 2416, in _maybe_build
    self.build(input_shapes)  # pylint:disable=not-callable
  File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\keras\layers\normalization.py", line 401, in build
    experimental_autocast=False)
  File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 577, in add_weight
    caching_device=caching_device)
  File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\training\tracking\base.py", line 724, in _add_variable_with_custom_getter
    name=name, shape=shape)
  File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\training\tracking\base.py", line 791, in _preload_simple_restoration
    checkpoint_position=checkpoint_position, shape=shape)
  File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\training\tracking\base.py", line 75, in __init__
    self.wrapped_value.set_shape(shape)
  File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 1107, in set_shape
    (self.shape, shape))
ValueError: Tensor's shape (200,) is not compatible with supplied shape (256,)
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm1.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm1.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm1.moving_mean
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm1.moving_variance
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.conv2.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm2.axis
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm2.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm2.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm2.moving_mean
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm2.moving_variance
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.conv3.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm3.axis
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm3.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm3.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm3.moving_mean
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm3.moving_variance
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.dconv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm4.axis
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm4.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm4.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm4.moving_mean
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm4.moving_variance
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.dconv2.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm5.axis
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm5.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm5.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm5.moving_mean
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm5.moving_variance
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.dconv3.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.predConv1.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.predConv1.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.predConv2.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.predConv2.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.predConv3.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.predConv3.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm1.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm1.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.conv2.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm2.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm2.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.conv3.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm3.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm3.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.dconv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm4.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm4.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.dconv2.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm5.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm5.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.dconv3.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.predConv1.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.predConv1.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.predConv2.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.predConv2.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.predConv3.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.predConv3.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm1.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm1.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.conv2.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm2.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm2.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.conv3.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm3.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm3.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.dconv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm4.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm4.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.dconv2.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm5.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm5.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.dconv3.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.predConv1.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.predConv1.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.predConv2.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.predConv2.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.predConv3.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.predConv3.conv1.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
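The two sizes in the `ValueError`, (200,) and (256,), line up with the spatial size of `conv1`'s output at index 1, which is the dimension `batchnorm1` normalizes over when its `axis` is 1. Below is a plain-Python sketch of that arithmetic. It assumes `conv1` runs in Keras's default `channels_last` layout (its `data_format` argument is commented out) and that `padding="same"` with stride 2 gives `ceil(size / 2)` per spatial dimension; `same_conv_output_shape` is a hypothetical helper, not a Keras function, and the batch size 8 comes from the comment in the model code.

```python
import math


def same_conv_output_shape(input_shape, filters, stride):
    """Output shape of a channels_last (NHWC) Conv2D with padding="same".

    With "same" padding each spatial size becomes ceil(size / stride);
    the channel dimension becomes the number of filters.
    """
    batch, height, width, _channels = input_shape
    return (batch,
            math.ceil(height / stride),
            math.ceil(width / stride),
            filters)


# conv1 in the question: 16 filters, strides=(2, 2), padding="same"
for size in [(400, 400), (512, 512), (720, 1280)]:
    out = same_conv_output_shape((8, *size, 3), filters=16, stride=2)
    print(size, "->", out, "| dim at index 1:", out[1])
```

Under these assumptions, index 1 holds 200 for the training size and 256 for 512x512, the same pair of numbers reported by the failing `set_shape` call.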

The model is as follows:

    import tensorflow as tf
    from tensorflow import keras as tfk
    from tensorflow.keras.regularizers import l2

    # weight_init is defined elsewhere in the original code (not shown)

    class customModel(tfk.Model):

        def __init__(self, data_format='channels_first', weight_decay=1e4, rgb_mean=None):

            super(customModel, self).__init__()
            self.weight_decay = weight_decay
            self.rgbn_mean = rgb_mean

            axis = -1 if data_format == "channels_last" else 1
            #  data_format=data_format,
            self.conv1 = tfk.layers.Conv2D(filters=16, kernel_size=(7, 7),
                                           padding="same", use_bias=False,
                                           kernel_initializer=weight_init,
                                           kernel_regularizer=l2(weight_decay),
                                           strides=(2, 2))  # [8,200,200,16] when the input is 400
            self.batchnorm1 = tfk.layers.BatchNormalization(axis=axis)

        def call(self, x, training=False):
            x = x - self.rgbn_mean[:-1]
            output = self.conv1(x, training=training)
            output = self.batchnorm1(output, training=training)
            output = tf.nn.relu(output)
            return output
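To see how `batchnorm1`'s parameter length can end up tied to the image size, here is a plain-Python sketch of the axis rule in `__init__`. It assumes, as the commented-out `data_format=data_format` line suggests, that `conv1` actually produces NHWC output (e.g. `[8,200,200,16]`) while the default `data_format='channels_first'` selects `axis=1`; `bn_axis` and `bn_param_shape` are hypothetical helpers, not Keras APIs.

```python
def bn_axis(data_format):
    # Same rule as customModel.__init__: -1 for channels_last, otherwise 1.
    return -1 if data_format == "channels_last" else 1


def bn_param_shape(feature_map_shape, axis):
    # gamma, beta, moving_mean and moving_variance each have one entry
    # per index along the normalized axis.
    return (feature_map_shape[axis],)


# Assumed conv1 outputs (NHWC) for 400x400 and 512x512 inputs
trained = (8, 200, 200, 16)
tested = (8, 256, 256, 16)

# Default data_format='channels_first' -> axis=1, i.e. the height dimension
axis = bn_axis("channels_first")
print(bn_param_shape(trained, axis), bn_param_shape(tested, axis))  # (200,) (256,)

# data_format='channels_last' -> axis=-1, i.e. the 16 channels
axis = bn_axis("channels_last")
print(bn_param_shape(trained, axis), bn_param_shape(tested, axis))  # (16,) (16,)
```

Under that assumption, only the `channels_last` branch gives parameters whose length does not depend on the input image size.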

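What am I doing wrong, and how can I fix it? Please help, I'm new to Keras :(

P.S. I have tried different ways of saving the Keras model, but I still can't test with a different image size.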

