如何从tensorflow和yoloV3中的协议缓冲区文件恢复训练?

2024-04-19 10:23:45 发布

您现在位置:Python中文网/ 问答频道 /正文

我可以保存ckpt文件,冻结pb文件中的图形,并使用它来做测试形象。但是现在,我想从PB文件恢复训练,对于C++。所以我有两个子问题,

  1. 如何获取用于训练的所有操作节点名称。你知道吗
  2. 如何从pbfile恢复训练。你知道吗

使用C++中的PB文件,可以保存PB文件进行测试。你知道吗

我所有的代码都在这里:https://github.com/YunYang1994/tensorflow-yolov3

使用转换中的代码_重量.py获取用于保存pb文件的节点名。你知道吗

但它打印的错误提示是这样的“***不在图表中”

for var in tf.global_variables():
var_name = var.op.name
var_name_mess = str(var_name).split('/')
var_shape = var.shape
print("111111111111111111112222222222222222222=> ")
print(var_name_mess[0])
if flag.train_from_coco:
    if var_name_mess[0] in preserve_cur_names: continue
cur_weights_mess.append([var_name, var_shape])
org_weights_num = len(org_weights_mess)
cur_weights_num = len(cur_weights_mess)
if cur_weights_num != org_weights_num:
raise RuntimeError

print('=> Number of weights that will rename:\t%d' % cur_weights_num)
cur_to_org_dict = {}
for index in range(org_weights_num):
org_name, org_shape = org_weights_mess[index]
cur_name, cur_shape = cur_weights_mess[index]
if cur_shape != org_shape:
    print(org_weights_mess[index])
    print(cur_weights_mess[index])
    raise RuntimeError
cur_to_org_dict[cur_name] = org_name
print("3333=> " + str(cur_name).ljust(50) + ' : ' + org_name)

with tf.name_scope('load_save'):
    name_to_var_dict = {var.op.name: var for var in         
tf.global_variables()}
restore_dict = {cur_to_org_dict[cur_name]: name_to_var_dict[cur_name] for         
cur_name in cur_to_org_dict}
load = tf.train.Saver(restore_dict)
save = tf.train.Saver(tf.global_variables())
for var in tf.global_variables():
    print("44444=> " + var.op.name)

你能帮我在python代码中筛选有用的节点名以及如何恢复训练吗。你知道吗


Tags: 文件tonameinorgforvartf