如何恢复和执行两个独立的TensorFlow模型?

2024-04-28 06:34:17 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在开发一个基于TensorFlow的程序,需要两个不同的模型。根据第一个模型的输出,我执行了一些计算,并得到了第二个模型的输入。以下是部分代码:

key_pose_model = None
gesture_model = None


# Key Poses
with tf.device('/cpu:0'):
    if key_pose_model_type == 'ramdon_forest':
        key_pose_model = RandomForest(key_pose_ramdon_forest_num_steps,
                                key_pose_ramdon_forest_num_classes,
                                key_pose_ramdon_forest_num_trees,
                                key_pose_ramdon_forest_max_nodes,
                                key_pose_ramdon_forest_num_features)
        key_pose_model.read_model(key_pose_model_name)


# Gestures
with tf.device('/cpu:1'):
    if gesture_model_type == 'ramdon_forest':
        gesture_model = RandomForest(gesture_ramdon_forest_num_steps,
                                gesture_ramdon_forest_num_classes,
                                gesture_ramdon_forest_num_trees,
                                gesture_ramdon_forest_max_nodes,
                                gesture_ramdon_forest_num_features)
        gesture_model.read_model(gesture_model_name)

之后,我的代码中有以下调用(输入数据来自传感器):

while(True):
......
......
key_pose_model.prediction(input_data_x)
......
......
......
......
gesture_model.prediction(input_data_x_1)
.......

它对第一个模型很有效,然后当我恢复第二个模型时,我有重复变量的错误,所以我认为我没有使用另一个图。我正在阅读TensorFlow文档,我试图复制关于不同会话的示例,但我做不到

g_1 = tf.Graph()
with g_1.as_default():
  # Operations created in this scope will be added to `g_1`.
  c = tf.constant("Node in g_1")

  # Sessions created in this scope will run operations from `g_1`.
  sess_1 = tf.Session()

g_2 = tf.Graph()
with g_2.as_default():
  # Operations created in this scope will be added to `g_2`.
  d = tf.constant("Node in g_2")

# Alternatively, you can pass a graph when constructing a <a href="./../api_docs/python/tf/Session"><code>tf.Session</code></a>:
# `sess_2` will run operations from `g_2`.
sess_2 = tf.Session(graph=g_2)

assert c.graph is g_1
assert sess_1.graph is g_1

assert d.graph is g_2
assert sess_2.graph is g_2

先谢谢你


Tags: keyin模型modelsessiontfwithwill