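I trained my model with an input (image) size of [None,400,400,3], but I want to test it with a different input size, e.g. [None,512,512,3]. Here is my custom training implementation: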
```python
my_model = customModel(rgb_mean=self.args.rgbn_mean)
ckpt_manager = tf.train.Checkpoint(optimizer=optimizer, model=my_model)
for epoch in range(self.args.max_epochs):
    # training
    for step, (x, y) in enumerate(train_data):
        with tf.GradientTape() as tape:
            pred = my_model(x, training=True)
            preds, last_logits, loss = pre_process_binary_cross_entropy(
                loss_bc, pred, y, self.args, use_tf_loss=False)
        if (step) % 100 == 0 and loss < global_loss:
            # tfk.Model.save_weights(my_model,os.path.join(checkpoint_dir,"saved_model.h5"),
            #                        save_format=ckpt_save_mode)
            # # tfk.models.save_model(my_model,os.path.join(checkpoint_dir,"1saved_model.h5"),
            # #                       save_format=ckpt_save_mode)
            # tfk.models.save_model(my_model,checkpoint_dir)
            ckpt_manager.save(checkpoint_dir)
```
Now, here is my custom test implementation:
```python
root = tf.train.Checkpoint(optimizer=optimizer,
                           model=my_model)
ckpt_manager = tf.train.CheckpointManager(root, checkpoint_dir, max_to_keep=10)
root.restore(ckpt_manager.latest_checkpoint)
for step, x in enumerate(test_data):
    preds = my_model(x, training=False)
```
When I test the model with 400x400 inputs it works fine, but when I test it with other sizes (such as 512x512 or 720x1280) I get the following log:
```
Traceback (most recent call last):
  File "C:/Users/xavie/Documents/Codes/GitHub/efge/main.py", line 76, in <module>
    main(args=arg)
  File "C:/Users/xavie/Documents/Codes/GitHub/efge/main.py", line 70, in main
    model.test()
  File "C:\Users\xavie\Documents\Codes\GitHub\efge\run_model.py", line 198, in test
    preds = my_model(x,training=False)
  File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 968, in __call__
    outputs = self.call(cast_inputs, *args, **kwargs)
  File "C:\Users\xavie\Documents\Codes\GitHub\efge\model.py", line 90, in call
    output = self.batchnorm1(output, training=training)
  File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 964, in __call__
    self._maybe_build(inputs)
  File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 2416, in _maybe_build
    self.build(input_shapes)  # pylint:disable=not-callable
  File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\keras\layers\normalization.py", line 401, in build
    experimental_autocast=False)
  File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 577, in add_weight
    caching_device=caching_device)
  File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\training\tracking\base.py", line 724, in _add_variable_with_custom_getter
    name=name, shape=shape)
  File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\training\tracking\base.py", line 791, in _preload_simple_restoration
    checkpoint_position=checkpoint_position, shape=shape)
  File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\training\tracking\base.py", line 75, in __init__
    self.wrapped_value.set_shape(shape)
  File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 1107, in set_shape
    (self.shape, shape))
ValueError: Tensor's shape (200,) is not compatible with supplied shape (256,)
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm1.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm1.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm1.moving_mean
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm1.moving_variance
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.conv2.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm2.axis
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm2.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm2.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm2.moving_mean
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm2.moving_variance
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.conv3.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm3.axis
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm3.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm3.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm3.moving_mean
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm3.moving_variance
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.dconv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm4.axis
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm4.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm4.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm4.moving_mean
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm4.moving_variance
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.dconv2.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm5.axis
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm5.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm5.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm5.moving_mean
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm5.moving_variance
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.dconv3.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.predConv1.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.predConv1.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.predConv2.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.predConv2.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.predConv3.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.predConv3.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm1.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm1.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.conv2.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm2.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm2.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.conv3.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm3.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm3.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.dconv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm4.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm4.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.dconv2.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm5.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm5.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.dconv3.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.predConv1.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.predConv1.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.predConv2.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.predConv2.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.predConv3.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.predConv3.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm1.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm1.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.conv2.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm2.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm2.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.conv3.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm3.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm3.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.dconv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm4.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm4.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.dconv2.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm5.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm5.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.dconv3.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.predConv1.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.predConv1.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.predConv2.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.predConv2.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.predConv3.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.predConv3.conv1.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
```
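The two numbers in the `ValueError` can be reproduced from `conv1` alone. With `padding="same"` and `strides=(2, 2)`, Keras produces a spatial output size of `ceil(input / stride)`, so a 400-pixel-high image gives a 200-row feature map and a 512-pixel-high image gives 256 rows, while the channel count stays at 16 in every case. A shape of (200,) versus (256,) therefore points at a layer whose parameters are sized by the image height rather than by the number of channels. A short check of that arithmetic (`same_conv_output_size` is a hypothetical helper written for this illustration):

```python
import math


def same_conv_output_size(size: int, stride: int) -> int:
    """Spatial output size of a Conv2D with padding="same": ceil(size / stride)."""
    return math.ceil(size / stride)


# conv1 in the question uses strides=(2, 2), padding="same" and 16 filters.
for height, width in [(400, 400), (512, 512), (720, 1280)]:
    out_h = same_conv_output_size(height, 2)
    out_w = same_conv_output_size(width, 2)
    print(f"input {height}x{width} -> conv1 output {out_h}x{out_w}x16")
```

Only the spatial dimensions change with the input; the last dimension (16) does not.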
The model is as follows:
```python
import tensorflow as tf
from tensorflow import keras as tfk
from tensorflow.keras.regularizers import l2

# weight_init is assumed to be defined elsewhere in the original code (not shown).


class customModel(tfk.Model):
    def __init__(self, data_format='channels_first', weight_decay=1e4, rgb_mean=None):
        super(customModel, self).__init__()
        self.weight_decay = weight_decay
        self.rgbn_mean = rgb_mean
        axis = -1 if data_format == "channels_last" else 1
        # data_format=data_format,
        self.conv1 = tfk.layers.Conv2D(filters=16, kernel_size=(7, 7),
                                       padding="same", use_bias=False,
                                       kernel_initializer=weight_init,
                                       kernel_regularizer=l2(weight_decay),
                                       strides=(2, 2))  # [8,200,200,16] when the input is 400
        self.batchnorm1 = tfk.layers.BatchNormalization(axis=axis)

    def call(self, x, training=False):
        x = x - self.rgbn_mean[:-1]
        output = self.conv1(x, training=training)
        output = self.batchnorm1(output, training=training)
        output = tf.nn.relu(output)
        return output
```
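A likely culprit is the `axis` given to `batchnorm1`. The model is constructed as `customModel(rgb_mean=...)`, so `data_format` keeps its default `'channels_first'` and `axis` becomes 1. However, the `data_format` argument of `Conv2D` is commented out, so the convolution uses the Keras default `channels_last` and outputs `[batch, height, width, channels]`. `BatchNormalization(axis=1)` then creates one `gamma`/`beta`/`moving_mean`/`moving_variance` entry per row of the feature map. The sketch below assumes TensorFlow 2.x and uses a zero tensor shaped like `conv1`'s output for a 400x400 input to compare the parameter shapes for the two axis choices:

```python
import tensorflow as tf

# A channels_last tensor shaped like conv1's output for one 400x400 image:
# [batch, height, width, channels] = [1, 200, 200, 16].
features = tf.zeros([1, 200, 200, 16])

# axis=1 is what customModel picks with its default data_format='channels_first'.
bn_axis1 = tf.keras.layers.BatchNormalization(axis=1)
bn_axis1(features, training=False)  # first call builds the layer's variables
print("axis=1  gamma shape:", tuple(bn_axis1.gamma.shape))  # one entry per feature-map row

# axis=-1 normalizes over the channel dimension of channels_last data.
bn_last = tf.keras.layers.BatchNormalization(axis=-1)
bn_last(features, training=False)
print("axis=-1 gamma shape:", tuple(bn_last.gamma.shape))  # one entry per channel
```

With a 256-row tensor (the 512x512 case), the `axis=1` layer would instead create variables of shape (256,), which cannot be reconciled with the (200,) variables stored in the checkpoint.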
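What am I doing wrong, and how can I fix it? Please help, I'm new to Keras :( P.S. I have tried different Keras model-saving methods, but I still can't test with a different image size.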
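If the cause is `batchnorm1` normalizing over `axis=1` (a spatial axis for channels_last data), one possible fix is to make every `BatchNormalization` layer normalize over the channel axis, for example by constructing the model with `data_format='channels_last'` or by using `axis=-1`, and then retraining, since the existing checkpoint stores `batchnorm1` variables of shape (200,) that will not fit a channel-sized (16,) layer. The sketch below uses a hypothetical, cut-down `FixedModel` with only `conv1` and `batchnorm1`, saves a checkpoint after building it on 400x400 inputs, restores it into a fresh instance built on 512x512 inputs, and runs 400x400, 512x512 and 720x1280 images. It assumes TensorFlow 2.x with `tf.keras`; the temporary directory stands in for `checkpoint_dir`.

```python
import tempfile

import tensorflow as tf
from tensorflow import keras as tfk


class FixedModel(tfk.Model):
    """Hypothetical, cut-down stand-in for customModel (conv1 + batchnorm1 only)."""

    def __init__(self):
        super().__init__()
        self.conv1 = tfk.layers.Conv2D(filters=16, kernel_size=(7, 7), strides=(2, 2),
                                       padding="same", use_bias=False)
        # Inputs are [batch, height, width, 3] (channels_last): normalize the channel axis.
        self.batchnorm1 = tfk.layers.BatchNormalization(axis=-1)

    def call(self, x, training=False):
        output = self.conv1(x)
        output = self.batchnorm1(output, training=training)
        return tf.nn.relu(output)


ckpt_dir = tempfile.mkdtemp()  # stands in for checkpoint_dir

# Training side: create the variables with 400x400 inputs and save a checkpoint.
train_model = FixedModel()
train_model(tf.zeros([1, 400, 400, 3]), training=True)
train_ckpt = tf.train.Checkpoint(model=train_model)
tf.train.CheckpointManager(train_ckpt, ckpt_dir, max_to_keep=10).save()

# Test side: a fresh model whose variables are first created from a 512x512 input.
test_model = FixedModel()
test_model(tf.zeros([1, 512, 512, 3]), training=False)
test_ckpt = tf.train.Checkpoint(model=test_model)
manager = tf.train.CheckpointManager(test_ckpt, ckpt_dir, max_to_keep=10)
# expect_partial() silences warnings about any checkpoint values this program does not use.
test_ckpt.restore(manager.latest_checkpoint).expect_partial()

output_shapes = {}
for height, width in [(400, 400), (512, 512), (720, 1280)]:
    preds = test_model(tf.zeros([1, height, width, 3]), training=False)
    output_shapes[(height, width)] = tuple(preds.shape)
    print((height, width), "->", output_shapes[(height, width)])
```

With a channel-sized `BatchNormalization`, nothing in `conv1` or `batchnorm1` depends on the image size. Whether the full model accepts arbitrary sizes also depends on layers the question does not show (for example how `dconv1` to `dconv3` upsample), so that part remains an assumption.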