如何使用tf.estim导入meta图来训练/优化模型构建图形

2024-06-09 01:40:07 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在调整一个由tf.估计量,但是用import\u meta\u graph来构建图形,而不是用代码。 如何使用多GPU进行微调?你知道吗

我想要的是:

  • 我有一个由tf.estimator训练的模型。你知道吗
def model_fn_original(feature, labels, mode, params):
  input = tf.layers.xxx
  input = tf.layers.xxx
  logits = tf.layers.dense(input)
  if mode == tf.estimator.ModeKeys.TRAIN:
    return tf.estimator.EstimatorSpec()
  • 2-使用export_meta_graph保存正向图
def export_meta_graph():
  input = tf.layers.xxx
  input = tf.layers.xxx
  logits = tf.layers.dense(input)
  tf.export_meta_graph(meta_file_name)

  • 3-使用import_meta_graph方法替换model\u fn中的build graph代码,用多个gpu对模型进行微调
def model_fn_imported(feature, labels, mode, params):
  tf.import_meta_graph(meta_file_name)
  logits = tf.get_defalt_graph().get_tensor_by_name("dense/BiasAdd:0")
  if mode == tf.estimator.ModeKeys.TRAIN:
    return tf.estimator.EstimatorSpec()

estimator = tf.estimator.Estiator(model_fn=model_fn_imported, ...)
estimator.train()

我已经测试了export meta\u graph(mode=eval)first build graph byimport_meta_graph。 评价成功,准确度与原准确度一致。你知道吗

当微调它,如果我只使用一个GPU,一切正常。你知道吗

但当我使用多GPU时:

  • 1-使用tf.contrib.estimator.replicate_model_fn

    出现以下错误: 在检查点中找不到1号塔/型号/xxx/xxx/xxx

在网上调查之后,这应该是因为import_meta_graph的变量不能被设置重用。所以当运行replicate_model_fnmodel_fn many次时,每次都要创建新变量。你知道吗

  • 2-使用tf.contrib.分配.镜像策略:

    出现一些错误:目标必须是DistributedValues对象中的一个tf.变量对象设备字符串设备字符串列表

所以问题是,在使用import\u meta\u graph构建图形时,有没有一种方法可以使用多个GPU对模型进行微调/训练。

理论上可以吗?


Tags: importinputmodelgpumodelayerstfdef