我正在将TensorFlow代码迁移到TensorFlow 2.1.0
以下是原始代码:
conv = tf.layers.conv2d(inputs, out_channels, kernel_size=3, padding='SAME')
conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
conv = tf.nn.relu(conv)
conv = tf.contrib.layers.max_pool2d(conv, 2)
这就是我所做的:
conv1 = Conv2D(out_channels, (3, 3), activation='relu', padding='same', data_format='channels_last', name=name)(inputs)
conv1 = Conv2D(64, (5, 5), activation='relu', padding='same', data_format="channels_last")(conv1)
#conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
pool1 = MaxPooling2D(pool_size=(2, 2), data_format="channels_last")(conv1)
我的问题是我不知道如何处理tf.contrib.layers.batch_norm
如何将tf.contrib.layers.batch_norm
迁移到Tensorflow 2.x
更新:
使用评论建议,我认为我已正确迁移:
conv1 = BatchNormalization(momentum=0.99, scale=True, center=True)(conv1)
但是我不确定decay
是否像momentum
,我不知道如何在BatchNormalization
方法中设置updates_collections
我在使用一个经过训练的模型时遇到了这个问题,我打算对这个模型进行微调。就像OP那样用
tf.keras.layers.BatchNormalization
替换tf.contrib.layers.batch_norm
给了我一个错误,其修复方法如下所述旧代码如下所示:
更新后的工作代码如下所示:
我不确定我删除的所有关键字参数是否都会有问题,但似乎一切都正常。请注意
name="BatchNorm"
参数。这些层使用不同的命名模式,因此我必须使用inspect_checkpoint.py
工具查看模型,并找到恰好是BatchNorm
的层名称相关问题 更多 >
编程相关推荐