如何将此代码从tensorflow 1升级到tensorflow 2

2024-05-16 07:10:21 发布

您现在位置:Python中文网/ 问答频道 /正文

我有以下代码:


    def _separable_conv(features, depth, kernel_size, depth_multiplier,
                        regularize_depthwise, rate, stride, scope):
      if activation_fn_in_separable_conv:
        activation_fn = tf.nn.relu
      else:
        activation_fn = None
        features = tf.nn.relu(features)
      return separable_conv2d_same(features,
                                   depth,
                                   kernel_size,
                                   depth_multiplier=depth_multiplier,
                                   stride=stride,
                                   rate=rate,
                                   activation_fn=activation_fn,
                                   regularize_depthwise=regularize_depthwise,
                                   scope=scope)
    for i in range(3):
      residual = _separable_conv(residual,
                                 depth_list[i],
                                 kernel_size=3,
                                 depth_multiplier=1,
                                 regularize_depthwise=regularize_depthwise,
                                 rate=rate*unit_rate_list[i],
                                 stride=stride if i == 2 else 1,
                                 scope='separable_conv' + str(i+1))
    if skip_connection_type == 'conv':
      shortcut = tf.Conv2D(inputs,
                             depth_list[-1],
                             [1, 1],
                             stride=stride,
                             activation_fn=None,
                             scope='shortcut')
      outputs = residual + shortcut
    elif skip_connection_type == 'sum':
      outputs = residual + inputs
    elif skip_connection_type == 'none':
      outputs = residual
    else:
      raise ValueError('Unsupported skip connection type.')

    return slim.utils.collect_named_outputs(outputs_collections,
                                            sc.name,
                                            outputs)

在最后一行中,我们使用了tf.contrib中的slim模块,它在tensorflow 2中被弃用。tensorflow 2中存在哪些函数或其他函数来执行与slim.utils.collect\u命名的\u输出行相同的功能


Tags: ratetfoutputsactivationfnfeaturesscopedepth
1条回答
网友
1楼 · 发布于 2024-05-16 07:10:21

TF slim现在在Github上作为一个外部包提供,它支持Tensorflow 2。该库具有相同的功能(包括此方法!),它只是有一个新的主页和不同的安装方式

核心库中没有可以直接替换代码的Tensorflow 2代码

相关问题 更多 >